diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 19a9b2b85..b46fd28e8 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -74,6 +74,9 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) + ctx.SetUserAttribute("cache_status", "hit") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + if stream { proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1) } else { diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index e4aae265e..56bea605f 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -8,14 +8,14 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../.. 
require ( github.com/alibaba/higress/plugins/wasm-go v1.4.2 - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f + github.com/google/uuid v1.6.0 + github.com/higress-group/proxy-wasm-go-sdk v1.0.0 github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 // github.com/weaviate/weaviate-go-client/v4 v4.15.1 ) require ( - github.com/google/uuid v1.6.0 // indirect github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect github.com/stretchr/testify v1.9.0 // indirect diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum index 7ada0c8b7..0a3635868 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.sum +++ b/plugins/wasm-go/extensions/ai-cache/go.sum @@ -3,8 +3,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git 
a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 1aca29f0e..62edb80dc 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -128,9 +128,15 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { skipCache := ctx.GetContext(SKIP_CACHE_HEADER) if skipCache != nil { + ctx.SetUserAttribute("cache_status", "skip") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) ctx.DontReadResponseBody() return types.ActionContinue } + if ctx.GetContext(CACHE_KEY_CONTEXT_KEY) != nil { + ctx.SetUserAttribute("cache_status", "miss") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 68eeeae20..a005299da 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -31,6 +31,7 @@ description: 阿里云内容安全检测 | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `protocol` | string | optional | openai | 协议格式,非openai协议填`original` | | `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low | +| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 | 补充说明一下 `denyMessage`,对非法请求的处理逻辑为: - 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应 diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index f4aee5632..4fa6e07c6 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -53,6 +53,7 @@ const ( 
DefaultStreamingResponseJsonPath = "choices.0.delta.content" DefaultDenyCode = 200 DefaultDenyMessage = "很抱歉,我无法回答您的问题" + DefaultTimeout = 2000 AliyunUserAgent = "CIPFrom/AIGateway" LengthLimit = 1800 @@ -100,6 +101,7 @@ type AISecurityConfig struct { denyMessage string protocolOriginal bool riskLevelBar string + timeout uint32 metrics map[string]proxywasm.MetricCounter } @@ -225,6 +227,11 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e } else { config.riskLevelBar = HighRisk } + if obj := json.Get("timeout"); obj.Exists() { + config.timeout = uint32(obj.Int()) + } else { + config.timeout = DefaultTimeout + } config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: serviceName, Port: servicePort, @@ -253,6 +260,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking request body...") + startTime := time.Now().UnixMilli() content := gjson.GetBytes(body, config.requestContentJsonPath).String() model := gjson.GetBytes(body, "model").String() ctx.SetContext("requestModel", model) @@ -279,6 +287,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) proxywasm.ResumeHttpRequest() } else { singleCall() @@ -305,7 +317,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] } ctx.DontReadResponseBody() config.incrementCounter("ai_sec_request_deny", 1) - ctx.SetUserAttribute("safecheck_status", "request deny") + endTime := time.Now().UnixMilli() + 
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request deny") if response.Data.Advice != nil { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) @@ -345,7 +359,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] reqParams.Add(k, v) } reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) proxywasm.ResumeHttpRequest() @@ -364,20 +378,6 @@ func convertHeaders(hs [][2]string) map[string][]string { return ret } -// headers: map[string][]string -> [][2]string -func reconvertHeaders(hs map[string][]string) [][2]string { - var ret [][2]string - for k, vs := range hs { - for _, v := range vs { - ret = append(ret, [2]string{k, v}) - } - } - sort.SliceStable(ret, func(i, j int) bool { - return ret[i][0] < ret[j][0] - }) - return ret -} - func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { if !config.checkResponse { log.Debugf("response checking is disabled") @@ -401,6 +401,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking response body...") + startTime := time.Now().UnixMilli() hdsMap := ctx.GetContext("headers").(map[string][]string) isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") model := ctx.GetStringContext("requestModel", "unknown") @@ -433,6 +434,10 @@ func 
onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ } if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "response pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) proxywasm.ResumeHttpResponse() } else { singleCall() @@ -458,6 +463,8 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } config.incrementCounter("ai_sec_response_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) ctx.SetUserAttribute("safecheck_status", "response deny") if response.Data.Advice != nil { ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) @@ -498,7 +505,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ reqParams.Add(k, v) } reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) + err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) if err != nil { log.Errorf("failed call the safe check service: %v", err) proxywasm.ResumeHttpResponse() diff --git a/plugins/wasm-go/extensions/ai-statistics/go.sum b/plugins/wasm-go/extensions/ai-statistics/go.sum index 6b1c2c3cd..b4ab172fe 100644 --- a/plugins/wasm-go/extensions/ai-statistics/go.sum +++ b/plugins/wasm-go/extensions/ai-statistics/go.sum @@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index 14fcc4d2a..1c8765638 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "strconv" "strings" "time" @@ -28,14 +27,15 @@ func main() { } const ( - // Trace span prefix - TracePrefix = "trace_span_tag." 
// Context consts StatisticsRequestStartTime = "ai-statistics-request-start-time" StatisticsFirstTokenTime = "ai-statistics-first-token-time" CtxGeneralAtrribute = "attributes" CtxLogAtrribute = "logAttributes" CtxStreamingBodyBuffer = "streamingBodyBuffer" + RouteName = "route" + ClusterName = "cluster" + APIName = "api" // Source Type FixedValue = "fixed_value" @@ -46,12 +46,14 @@ const ( ResponseBody = "response_body" // Inner metric & log attributes name - Model = "model" - InputToken = "input_token" - OutputToken = "output_token" - LLMFirstTokenDuration = "llm_first_token_duration" - LLMServiceDuration = "llm_service_duration" - LLMDurationCount = "llm_duration_count" + Model = "model" + InputToken = "input_token" + OutputToken = "output_token" + LLMFirstTokenDuration = "llm_first_token_duration" + LLMServiceDuration = "llm_service_duration" + LLMDurationCount = "llm_duration_count" + LLMStreamDurationCount = "llm_stream_duration_count" + ResponseType = "response_type" // Extract Rule RuleFirst = "first" @@ -91,6 +93,19 @@ func getRouteName() (string, error) { } } +func getAPIName() (string, error) { + if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil { + return "-", err + } else { + parts := strings.Split(string(raw), "@") + if len(parts) != 5 { + return "-", errors.New("not api type") + } else { + return strings.Join(parts[:3], "@"), nil + } + } +} + func getClusterName() (string, error) { if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil { return "-", err @@ -133,8 +148,15 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrappe } func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action { - ctx.SetContext(CtxGeneralAtrribute, map[string]string{}) - ctx.SetContext(CtxLogAtrribute, map[string]string{}) + route, _ := getRouteName() + cluster, _ := getClusterName() + api, api_error := getAPIName() + if api_error == nil { + route = api + 
} + ctx.SetContext(RouteName, route) + ctx.SetContext(ClusterName, cluster) + ctx.SetUserAttribute(APIName, api) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) // Set user defined log & span attributes which type is fixed_value @@ -149,6 +171,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { // Set user defined log & span attributes. setAttributeBySource(ctx, config, RequestBody, body, log) + + // Write log + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) return types.ActionContinue } @@ -177,6 +202,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer) } + ctx.SetUserAttribute(ResponseType, "stream") + // Get requestStartTime from http context requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64) if !ok { @@ -188,28 +215,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat if ctx.GetContext(StatisticsFirstTokenTime) == nil { firstTokenTime := time.Now().UnixMilli() ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime) - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - attributes[LLMFirstTokenDuration] = fmt.Sprint(firstTokenTime - requestStartTime) - ctx.SetContext(CtxGeneralAtrribute, attributes) + ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime) } // Set information about this request - if model, inputToken, outputToken, ok := getUsage(data); ok { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - // Record Log Attributes - attributes[Model] = model - attributes[InputToken] = fmt.Sprint(inputToken) - attributes[OutputToken] = fmt.Sprint(outputToken) - // Set attributes to http context - ctx.SetContext(CtxGeneralAtrribute, attributes) + 
ctx.SetUserAttribute(Model, model) + ctx.SetUserAttribute(InputToken, inputToken) + ctx.SetUserAttribute(OutputToken, outputToken) } // If the end of the stream is reached, record metrics/logs/spans. if endOfStream { responseEndTime := time.Now().UnixMilli() - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime) - ctx.SetContext(CtxGeneralAtrribute, attributes) + ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) // Set user defined log & span attributes. if config.shouldBufferStreamingBody { @@ -220,11 +238,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log) } - // Write inner filter states which can be used by other plugins such as ai-token-ratelimit - writeFilterStates(ctx, log) - // Write log - writeLog(ctx, log) + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics writeMetric(ctx, config, log) @@ -233,33 +248,26 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat } func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { - // Get attributes from http context - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - // Get requestStartTime from http context requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64) responseEndTime := time.Now().UnixMilli() - attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime) + ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime) + + ctx.SetUserAttribute(ResponseType, "normal") // Set information about this request - model, inputToken, outputToken, ok := getUsage(body) - if ok { - attributes[Model] = model - attributes[InputToken] = fmt.Sprint(inputToken) - attributes[OutputToken] = 
fmt.Sprint(outputToken) - // Update attributes - ctx.SetContext(CtxGeneralAtrribute, attributes) + if model, inputToken, outputToken, ok := getUsage(body); ok { + ctx.SetUserAttribute(Model, model) + ctx.SetUserAttribute(InputToken, inputToken) + ctx.SetUserAttribute(OutputToken, outputToken) } // Set user defined log & span attributes. setAttributeBySource(ctx, config, ResponseBody, body, log) - // Write inner filter states which can be used by other plugins such as ai-token-ratelimit - writeFilterStates(ctx, log) - // Write log - writeLog(ctx, log) + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) // Write metrics writeMetric(ctx, config, log) @@ -294,57 +302,45 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag // fetches the tracing span value from the specified source. func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) { - attributes, ok := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - if !ok { - log.Error("failed to get attributes from http context") - return - } for _, attribute := range config.attributes { + var key, value string + var err error if source == attribute.ValueSource { + key = attribute.Key switch source { case FixedValue: log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value) - attributes[attribute.Key] = attribute.Value + value = attribute.Value case RequestHeader: - if value, err := proxywasm.GetHttpRequestHeader(attribute.Value); err == nil { + if value, err = proxywasm.GetHttpRequestHeader(attribute.Value); err == nil { log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value } case RequestBody: raw := gjson.GetBytes(body, attribute.Value).Raw - var value string if len(raw) > 2 { value = raw[1 : len(raw)-1] } log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - 
attributes[attribute.Key] = value case ResponseHeader: - if value, err := proxywasm.GetHttpResponseHeader(attribute.Value); err == nil { + if value, err = proxywasm.GetHttpResponseHeader(attribute.Value); err == nil { log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value } case ResponseStreamingBody: - value := extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) + value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value case ResponseBody: - value := gjson.GetBytes(body, attribute.Value).Raw - if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' { - value = value[1 : len(value)-1] - } + value = gjson.GetBytes(body, attribute.Value).String() log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value) - attributes[attribute.Key] = value default: } - } - if attribute.ApplyToLog { - setLogAttribute(ctx, attribute.Key, attributes[attribute.Key], log) - } - if attribute.ApplyToSpan { - setSpanAttribute(attribute.Key, attributes[attribute.Key], log) + if attribute.ApplyToLog { + ctx.SetUserAttribute(key, value) + } + if attribute.ApplyToSpan { + setSpanAttribute(key, value, log) + } } } - ctx.SetContext(CtxGeneralAtrribute, attributes) } func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string { @@ -368,9 +364,9 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l } else if rule == RuleAppend { // extract llm response for _, chunk := range chunks { - raw := gjson.GetBytes(chunk, jsonPath).Raw - if len(raw) > 2 && raw[0] == '"' && raw[len(raw)-1] == '"' { - value += raw[1 : len(raw)-1] + jsonObj := gjson.GetBytes(chunk, jsonPath) + if jsonObj.Exists() { + value += jsonObj.String() } } } else { 
@@ -379,123 +375,49 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l return value } -func setFilterState(key, value string, log wrapper.Log) { - if value != "" { - if e := proxywasm.SetProperty([]string{key}, []byte(fmt.Sprint(value))); e != nil { - log.Errorf("failed to set %s in filter state: %v", key, e) - } - } else { - log.Debugf("failed to write filter state [%s], because it's value is empty") - } -} - // Set the tracing span with value. func setSpanAttribute(key, value string, log wrapper.Log) { if value != "" { - traceSpanTag := TracePrefix + key + traceSpanTag := wrapper.TraceSpanTagPrefix + key if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil { - log.Errorf("failed to set %s in filter state: %v", traceSpanTag, e) + log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e) } } else { log.Debugf("failed to write span attribute [%s], because it's value is empty") } } -// fetches the tracing span value from the specified source. 
-func setLogAttribute(ctx wrapper.HttpContext, key string, value interface{}, log wrapper.Log) { - logAttributes, ok := ctx.GetContext(CtxLogAtrribute).(map[string]string) - if !ok { - log.Error("failed to get logAttributes from http context") - return - } - logAttributes[key] = fmt.Sprint(value) - ctx.SetContext(CtxLogAtrribute, logAttributes) -} - -func writeFilterStates(ctx wrapper.HttpContext, log wrapper.Log) { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - setFilterState(Model, attributes[Model], log) - setFilterState(InputToken, attributes[InputToken], log) - setFilterState(OutputToken, attributes[OutputToken], log) -} - func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - route, _ := getRouteName() - cluster, _ := getClusterName() - model, ok := attributes["model"] - if !ok { - log.Errorf("Get model failed") + route := ctx.GetContext(RouteName).(string) + cluster := ctx.GetContext(ClusterName).(string) + // Generate usage metrics + var model string + var inputToken, outputToken int64 + if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { + log.Warnf("get usage information failed, skip metric record") return } - if inputToken, ok := attributes[InputToken]; ok { - inputTokenUint64, err := strconv.ParseUint(inputToken, 10, 0) - if err != nil || inputTokenUint64 == 0 { - log.Errorf("inputToken convert failed, value is %d, err msg is [%v]", inputTokenUint64, err) - return - } - config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputTokenUint64) + model = ctx.GetUserAttribute(Model).(string) + inputToken = ctx.GetUserAttribute(InputToken).(int64) + outputToken = ctx.GetUserAttribute(OutputToken).(int64) + if inputToken == 0 || outputToken == 0 { + log.Warnf("inputToken and outputToken cannot equal to 0, skip metric 
record") + return } - if outputToken, ok := attributes[OutputToken]; ok { - outputTokenUint64, err := strconv.ParseUint(outputToken, 10, 0) - if err != nil || outputTokenUint64 == 0 { - log.Errorf("outputToken convert failed, value is %d, err msg is [%v]", outputTokenUint64, err) - return - } - config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputTokenUint64) - } - if llmFirstTokenDuration, ok := attributes[LLMFirstTokenDuration]; ok { - llmFirstTokenDurationUint64, err := strconv.ParseUint(llmFirstTokenDuration, 10, 0) - if err != nil || llmFirstTokenDurationUint64 == 0 { - log.Errorf("llmFirstTokenDuration convert failed, value is %d, err msg is [%v]", llmFirstTokenDurationUint64, err) - return - } - config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDurationUint64) - } - if llmServiceDuration, ok := attributes[LLMServiceDuration]; ok { - llmServiceDurationUint64, err := strconv.ParseUint(llmServiceDuration, 10, 0) - if err != nil || llmServiceDurationUint64 == 0 { - log.Errorf("llmServiceDuration convert failed, value is %d, err msg is [%v]", llmServiceDurationUint64, err) - return - } - config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDurationUint64) - } - config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) -} + config.incrementCounter(generateMetricName(route, cluster, model, InputToken), uint64(inputToken)) + config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), uint64(outputToken)) -func writeLog(ctx wrapper.HttpContext, log wrapper.Log) { - attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) - logAttributes, _ := ctx.GetContext(CtxLogAtrribute).(map[string]string) - // Set inner log fields - if attributes[Model] != "" { - logAttributes[Model] = attributes[Model] + // Generate duration metrics + var llmFirstTokenDuration, llmServiceDuration int64 
+ // Is stream response + if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil { + llmFirstTokenDuration = ctx.GetUserAttribute(LLMFirstTokenDuration).(int64) + config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), uint64(llmFirstTokenDuration)) + config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1) } - if attributes[InputToken] != "" { - logAttributes[InputToken] = attributes[InputToken] - } - if attributes[OutputToken] != "" { - logAttributes[OutputToken] = attributes[OutputToken] - } - if attributes[LLMFirstTokenDuration] != "" { - logAttributes[LLMFirstTokenDuration] = attributes[LLMFirstTokenDuration] - } - if attributes[LLMServiceDuration] != "" { - logAttributes[LLMServiceDuration] = attributes[LLMServiceDuration] - } - // Traverse log fields - items := []string{} - for k, v := range logAttributes { - items = append(items, fmt.Sprintf(`"%s":"%s"`, k, v)) - } - aiLogField := fmt.Sprintf(`{%s}`, strings.Join(items, ",")) - // log.Infof("ai request json log: %s", aiLogField) - jsonMap := map[string]string{ - "ai_log": aiLogField, - } - serialized, _ := json.Marshal(jsonMap) - jsonLogRaw := gjson.GetBytes(serialized, "ai_log").Raw - jsonLog := jsonLogRaw[1 : len(jsonLogRaw)-1] - if err := proxywasm.SetProperty([]string{"ai_log"}, []byte(jsonLog)); err != nil { - log.Errorf("failed to set ai_log in filter state: %v", err) + if ctx.GetUserAttribute(LLMServiceDuration) != nil { + llmServiceDuration = ctx.GetUserAttribute(LLMServiceDuration).(int64) + config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), uint64(llmServiceDuration)) + config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) } } diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum index 4bc7bb752..7b8c22894 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum +++ 
b/plugins/wasm-go/extensions/ai-token-ratelimit/go.sum @@ -5,8 +5,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= @@ -14,8 +13,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8= github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git 
a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go
index afe463a12..6877ae5c2 100644
--- a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go
+++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go
@@ -15,6 +15,7 @@
 package main
 
 import (
+	"bytes"
 	"fmt"
 	"net"
 	"net/url"
@@ -61,9 +62,9 @@ const (
 	ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
 	CookieHeader   string = "cookie"
 
-	RateLimitLimitHeader     string = "X-RateLimit-Limit"     // 限制的总请求数
-	RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数
-	RateLimitResetHeader     string = "X-RateLimit-Reset"     // 限流重置时间(触发限流时返回)
+	RateLimitLimitHeader     string = "X-TokenRateLimit-Limit"     // total quota allowed by the matched limit rule
+	RateLimitRemainingHeader string = "X-TokenRateLimit-Remaining" // quota that is still available
+	RateLimitResetHeader     string = "X-TokenRateLimit-Reset"     // rate-limit reset time (returned when the limit is triggered)
 )
 
 type LimitContext struct {
@@ -124,6 +125,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
 		}
 		if context.remaining < 0 {
 			// 触发限流
+			ctx.SetUserAttribute("token_ratelimit_status", "limited")
+			ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
 			rejected(config, context)
 		} else {
 			proxywasm.ResumeHttpRequest()
@@ -137,39 +140,55 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
 }
 
+// onHttpStreamingBody remembers the token usage reported in response chunks and,
+// at end of stream, passes the total to ResponsePhaseFixedWindowScript via Redis.
 func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
+	// Usage may show up in any chunk, so keep the most recent values seen.
+	if inputToken, outputToken, ok := getUsage(data); ok {
+		ctx.SetContext("input_token", inputToken)
+		ctx.SetContext("output_token", outputToken)
+	}
 	if !endOfStream {
 		return data
 	}
-	inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
-	if err != nil {
-		return data
-	}
-	outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
-	if err != nil {
-		return data
-	}
-	inputToken, err := strconv.Atoi(string(inputTokenStr))
-	if err != nil {
-		return data
-	}
-	outputToken, err := strconv.Atoi(string(outputTokenStr))
-	if err != nil {
-		return data
-	}
+	// Comma-ok assertions also cover the case where no usage was ever recorded.
+	inputToken, ok := ctx.GetContext("input_token").(int64)
+	if !ok {
+		return data
+	}
+	outputToken, ok := ctx.GetContext("output_token").(int64)
+	if !ok {
+		return data
+	}
 	limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
 	if !ok {
 		return data
 	}
 	keys := []interface{}{limitRedisContext.key}
 	args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
-
-	err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil)
-	if err != nil {
+	if err := config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil); err != nil {
 		log.Errorf("redis call failed: %v", err)
-		return data
-	} else {
-		return data
 	}
+	return data
+}
+
+// getUsage splits data on blank lines and returns the prompt/completion token
+// counts of the first chunk that carries both usage fields. ok is false when
+// no chunk in data reports usage.
+func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
+	chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
+	for _, chunk := range chunks {
+		// the feature strings are used to identify the usage data, like:
+		// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
+		if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
+			continue
+		}
+		inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
+		outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
+		if inputTokenObj.Exists() && outputTokenObj.Exists() {
+			return inputTokenObj.Int(), outputTokenObj.Int(), true
+		}
+	}
+	return 0, 0, false
 }
 
 func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
index be9144adf..8b342d57b 100644
--- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
+++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go
@@ -45,6 +45,10 @@ type HttpContext interface {
 	GetStringContext(key, defaultValue string) string
 	GetUserAttribute(key string) interface{}
 	SetUserAttribute(key string, value interface{})
+	// You can call this function to replace all user attributes with the given map
+	SetUserAttributeMap(kvmap map[string]interface{})
+	// You can call this function to get the map holding all user attributes
+	GetUserAttributeMap() map[string]interface{}
 	// You can call this function to set custom log
 	WriteUserAttributeToLog() error
 	// You can call this function to set custom log with your specific key
@@ -403,6 +407,21 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{}
 	return ctx.userAttribute[key]
 }
 
+// SetUserAttributeMap replaces the whole user-attribute map with kvmap.
+// A nil kvmap is stored as an empty map, because a later write into a nil
+// map would panic.
+func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttributeMap(kvmap map[string]interface{}) {
+	if kvmap == nil {
+		kvmap = make(map[string]interface{})
+	}
+	ctx.userAttribute = kvmap
+}
+
+// GetUserAttributeMap returns the context's user-attribute map itself, not a copy.
+func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttributeMap() map[string]interface{} {
+	return ctx.userAttribute
+}
+
 func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error {
 	return ctx.WriteUserAttributeToLogWithKey(CustomLogKey)
 }