Mirror of https://github.com/alibaba/higress.git, synced 2026-02-06 23:21:08 +08:00
@@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
 	c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
 	if c.StreamResponseTemplate == "" {
-		c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
+		c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
 	}
 	c.ResponseTemplate = json.Get("responseTemplate").String()
 	if c.ResponseTemplate == "" {
-		c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
+		c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
 	}
 
 	if json.Get("enableSemanticCache").Exists() {
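Both default templates carry a single %s placeholder for the cached completion, so rendering a cache hit is a plain fmt.Sprintf. A minimal standalone sketch (the main() harness is illustrative; only the template literal comes from the hunk above):

package main

import "fmt"

const responseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`

func main() {
	// The cached answer fills the single %s verb in the template.
	fmt.Printf(responseTemplate+"\n", "Hello from cache")
}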
@@ -34,7 +34,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.
 
 func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
 	templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
-	if templateEnable != "true" {
+	if templateEnable == "false" {
 		ctx.DontReadRequestBody()
 		return types.ActionContinue
 	}
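The flipped condition turns the gate from opt-in into opt-out: previously a missing template-enable header skipped the plugin, now only an explicit "false" does. A small sketch of the behavior change, with both predicates extracted into helpers for comparison:

package main

import "fmt"

func enabledOld(v string) bool { return v == "true" }  // before: opt-in
func enabledNew(v string) bool { return v != "false" } // after: opt-out

func main() {
	// A missing header ("") now enables templating instead of skipping it.
	for _, v := range []string{"", "true", "false"} {
		fmt.Printf("template-enable=%q old=%v new=%v\n", v, enabledOld(v), enabledNew(v))
	}
}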
@@ -4,14 +4,14 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
-	"github.com/google/uuid"
 	"math/rand"
 	"net/http"
 	"strings"
 	"time"
 
+	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
 	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/google/uuid"
 	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
 	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
 	"github.com/tidwall/gjson"
@@ -551,7 +551,8 @@ func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.Ht
 }
 
 func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
-	return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
+	token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
+	return token
 }
 
 func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
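The one-line fix swaps a bare type assertion, which panics when the context key was never set, for the comma-ok form that degrades to the empty string. A self-contained illustration of the difference, with the context value passed in directly instead of through the SDK:

package main

import "fmt"

// tokenFrom mimics the fixed GetApiTokenInUse: the comma-ok assertion
// yields "" instead of panicking when the stored value is absent.
func tokenFrom(v interface{}) string {
	token, _ := v.(string)
	return token
}

func main() {
	fmt.Printf("%q\n", tokenFrom("sk-123")) // "sk-123"
	fmt.Printf("%q\n", tokenFrom(nil))      // "" rather than a runtime panic
}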
@@ -41,9 +41,9 @@ const (
 	LowRisk = "low"
 	NoRisk = "none"
 
-	OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}`
-	OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
-	OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
+	OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
+	OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
+	OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
 	OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
 
 	DefaultRequestCheckService = "llm_query_moderation"
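After this change the combined stream format carries three %s verbs in total: the id and the denial content in the chunk event, then the id again in the end event; the model field is baked into the constants. A sketch of how the pieces compose (constants copied from the hunk, the main() harness is illustrative):

package main

import "fmt"

const (
	chunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
	end   = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	full  = chunk + "\n\n" + end + "\n\n" + `data: [DONE]`
)

func main() {
	// id, content, id: three verbs, the model is no longer an argument.
	fmt.Println(fmt.Sprintf(full, "id-1", "request denied", "id-1"))
}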
@@ -262,8 +262,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
 	log.Debugf("checking request body...")
 	startTime := time.Now().UnixMilli()
 	content := gjson.GetBytes(body, config.requestContentJsonPath).String()
-	model := gjson.GetBytes(body, "model").String()
-	ctx.SetContext("requestModel", model)
 	log.Debugf("Raw request content is: %s", content)
 	if len(content) == 0 {
 		log.Info("request content is empty. skip")
@@ -308,11 +306,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
 		proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
 	} else if gjson.GetBytes(body, "stream").Bool() {
 		randomID := generateRandomID()
-		jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
+		jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
 		proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
 	} else {
 		randomID := generateRandomID()
-		jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
+		jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
 		proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
 	}
 	ctx.DontReadResponseBody()
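Because marshalledDenyMessage is spliced into the %s slot of a JSON template, it must be JSON-escaped beforehand. A minimal sketch of one way to produce such a value, assuming the deny message is a plain string; the plugin's actual helper may differ:

package main

import (
	"encoding/json"
	"fmt"
)

// marshalDenyMessage JSON-escapes the deny text and strips the outer
// quotes so it can be embedded inside an already-quoted template field.
func marshalDenyMessage(msg string) string {
	b, _ := json.Marshal(msg) // marshaling a string cannot fail
	return string(b[1 : len(b)-1])
}

func main() {
	fmt.Println(marshalDenyMessage(`contains "quotes" and a` + "\n" + `newline`))
}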
@@ -369,15 +367,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
 	return types.ActionPause
 }
 
-func convertHeaders(hs [][2]string) map[string][]string {
-	ret := make(map[string][]string)
-	for _, h := range hs {
-		k, v := strings.ToLower(h[0]), h[1]
-		ret[k] = append(ret[k], v)
-	}
-	return ret
-}
-
 func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
 	if !config.checkResponse {
 		log.Debugf("response checking is disabled")
@@ -398,7 +387,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
 	startTime := time.Now().UnixMilli()
 	contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
 	isStreamingResponse := strings.Contains(contentType, "event-stream")
-	model := ctx.GetStringContext("requestModel", "unknown")
 	var content string
 	if isStreamingResponse {
 		content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
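extractMessageFromStreamingBody is referenced but not shown in this diff. A hedged sketch of what such a helper could look like, assuming the body is SSE-framed JSON chunks and the path is a gjson expression; this is an illustration, not the plugin's implementation:

package main

import (
	"fmt"
	"strings"

	"github.com/tidwall/gjson"
)

// extractFromStream walks the SSE lines, pulls the configured JSON path
// from each data chunk, and concatenates the fragments.
func extractFromStream(body []byte, jsonPath string) string {
	var sb strings.Builder
	for _, line := range strings.Split(string(body), "\n") {
		data := strings.TrimPrefix(strings.TrimSpace(line), "data:")
		if data == "" || strings.TrimSpace(data) == "[DONE]" {
			continue
		}
		sb.WriteString(gjson.Get(data, jsonPath).String())
	}
	return sb.String()
}

func main() {
	body := []byte("data:{\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\n" +
		"data:{\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n\ndata:[DONE]\n")
	fmt.Println(extractFromStream(body, "choices.0.delta.content")) // Hello
}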
@@ -449,11 +437,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
 		proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
 	} else if isStreamingResponse {
 		randomID := generateRandomID()
-		jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
+		jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
 		proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
 	} else {
 		randomID := generateRandomID()
-		jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
+		jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
 		proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
 	}
 	config.incrementCounter("ai_sec_response_deny", 1)
@@ -36,6 +36,7 @@ const (
 	RouteName = "route"
 	ClusterName = "cluster"
 	APIName = "api"
+	ConsumerKey = "x-mse-consumer"
 
 	// Source Type
 	FixedValue = "fixed_value"
@@ -81,8 +82,8 @@ type AIStatisticsConfig struct {
 	shouldBufferStreamingBody bool
 }
 
-func generateMetricName(route, cluster, model, metricName string) string {
-	return fmt.Sprintf("route.%s.upstream.%s.model.%s.metric.%s", route, cluster, model, metricName)
+func generateMetricName(route, cluster, model, consumer, metricName string) string {
+	return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
 }
 
 func getRouteName() (string, error) {
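With the extra dimension, a metric name now embeds the consumer between model and metric. A quick illustration using the function from the hunk; the route, cluster, model, and metric values below are placeholders:

package main

import "fmt"

func generateMetricName(route, cluster, model, consumer, metricName string) string {
	return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
}

func main() {
	// consumer falls back to "none" when the x-mse-consumer header is absent.
	fmt.Println(generateMetricName("llm-route", "llm-cluster", "gpt-4o", "none", "input_token"))
}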
@@ -115,6 +116,9 @@ func getClusterName() (string, error) {
 }
 
 func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
+	if inc == 0 {
+		return
+	}
 	counter, ok := config.counterMetrics[metricName]
 	if !ok {
 		counter = proxywasm.DefineCounterMetric(metricName)
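The new guard returns before the lazy define-and-cache logic, so zero-valued increments no longer reach the host at all. A self-contained sketch of the pattern with the proxywasm counter stubbed out so it runs anywhere:

package main

import "fmt"

type counter struct {
	value uint64
}

func (c *counter) Increment(inc uint64) { c.value += inc }

var registry = map[string]*counter{} // stand-in for config.counterMetrics

func incrementCounter(metricName string, inc uint64) {
	if inc == 0 {
		return // new guard: skip the host call for no-op increments
	}
	c, ok := registry[metricName]
	if !ok {
		c = &counter{} // stand-in for proxywasm.DefineCounterMetric(metricName)
		registry[metricName] = c
	}
	c.Increment(inc)
}

func main() {
	incrementCounter("demo.metric", 0) // dropped by the guard
	incrementCounter("demo.metric", 3)
	fmt.Println(registry["demo.metric"].value) // 3
}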
@@ -158,6 +162,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
 	ctx.SetContext(ClusterName, cluster)
 	ctx.SetUserAttribute(APIName, api)
 	ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
+	if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
+		ctx.SetContext(ConsumerKey, consumer)
+	}
 
 	// Set user defined log & span attributes which type is fixed_value
 	setAttributeBySource(ctx, config, FixedValue, nil, log)
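The consumer is recorded only when the x-mse-consumer header is present; writeMetric (next hunk) then falls back to "none" via GetStringContext. A stubbed sketch of that fallback path, with the SDK context replaced by a map:

package main

import "fmt"

var httpCtx = map[string]interface{}{} // stand-in for the wrapper.HttpContext store

func getStringContext(key, def string) string {
	if v, ok := httpCtx[key].(string); ok && v != "" {
		return v
	}
	return def
}

func main() {
	const ConsumerKey = "x-mse-consumer"
	consumerHeader := "" // as if proxywasm.GetHttpRequestHeader returned nothing
	if consumerHeader != "" {
		httpCtx[ConsumerKey] = consumerHeader
	}
	fmt.Println(getStringContext(ConsumerKey, "none")) // none
}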
@@ -388,6 +395,7 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
 	var ok bool
 	var route, cluster, model string
 	var inputToken, outputToken uint64
+	consumer := ctx.GetStringContext(ConsumerKey, "none")
 	route, ok = ctx.GetContext(RouteName).(string)
 	if !ok {
 		log.Warnf("RouteName typd assert failed, skip metric record")
@@ -421,8 +429,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
 		log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
 		return
 	}
-	config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
-	config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
+	config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
+	config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
 
 	// Generate duration metrics
 	var llmFirstTokenDuration, llmServiceDuration uint64
@@ -433,8 +441,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
 			log.Warnf("LLMFirstTokenDuration typd assert failed")
 			return
 		}
-		config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
-		config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
+		config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration)
+		config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1)
 	}
 	if ctx.GetUserAttribute(LLMServiceDuration) != nil {
 		llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
@@ -442,8 +450,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
 			log.Warnf("LLMServiceDuration typd assert failed")
 			return
 		}
-		config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
-		config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
+		config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration)
+		config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1)
 	}
 }