Improve ai plugins (#1657)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
rinfx
2025-01-09 22:04:51 +08:00
committed by GitHub
parent 2a89c3bb70
commit ea0d5e7564
5 changed files with 31 additions and 34 deletions

View File

@@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
if c.StreamResponseTemplate == "" { if c.StreamResponseTemplate == "" {
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
} }
c.ResponseTemplate = json.Get("responseTemplate").String() c.ResponseTemplate = json.Get("responseTemplate").String()
if c.ResponseTemplate == "" { if c.ResponseTemplate == "" {
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
} }
if json.Get("enableSemanticCache").Exists() { if json.Get("enableSemanticCache").Exists() {

View File

@@ -34,7 +34,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable") templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
if templateEnable != "true" { if templateEnable == "false" {
ctx.DontReadRequestBody() ctx.DontReadRequestBody()
return types.ActionContinue return types.ActionContinue
} }

View File

@@ -4,14 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -551,7 +551,8 @@ func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.Ht
} }
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
return ctx.GetContext(c.failover.ctxApiTokenInUse).(string) token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
return token
} }
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {

View File

@@ -41,9 +41,9 @@ const (
LowRisk = "low" LowRisk = "low"
NoRisk = "none" NoRisk = "none"
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}` OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
DefaultRequestCheckService = "llm_query_moderation" DefaultRequestCheckService = "llm_query_moderation"
@@ -262,8 +262,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
log.Debugf("checking request body...") log.Debugf("checking request body...")
startTime := time.Now().UnixMilli() startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.requestContentJsonPath).String() content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").String()
ctx.SetContext("requestModel", model)
log.Debugf("Raw request content is: %s", content) log.Debugf("Raw request content is: %s", content)
if len(content) == 0 { if len(content) == 0 {
log.Info("request content is empty. skip") log.Info("request content is empty. skip")
@@ -308,11 +306,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() { } else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID() randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else { } else {
randomID := generateRandomID() randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} }
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
@@ -369,15 +367,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
return types.ActionPause return types.ActionPause
} }
func convertHeaders(hs [][2]string) map[string][]string {
ret := make(map[string][]string)
for _, h := range hs {
k, v := strings.ToLower(h[0]), h[1]
ret[k] = append(ret[k], v)
}
return ret
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkResponse { if !config.checkResponse {
log.Debugf("response checking is disabled") log.Debugf("response checking is disabled")
@@ -398,7 +387,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
startTime := time.Now().UnixMilli() startTime := time.Now().UnixMilli()
contentType, _ := proxywasm.GetHttpResponseHeader("content-type") contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
isStreamingResponse := strings.Contains(contentType, "event-stream") isStreamingResponse := strings.Contains(contentType, "event-stream")
model := ctx.GetStringContext("requestModel", "unknown")
var content string var content string
if isStreamingResponse { if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
@@ -449,11 +437,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse { } else if isStreamingResponse {
randomID := generateRandomID() randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else { } else {
randomID := generateRandomID() randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} }
config.incrementCounter("ai_sec_response_deny", 1) config.incrementCounter("ai_sec_response_deny", 1)

View File

@@ -36,6 +36,7 @@ const (
RouteName = "route" RouteName = "route"
ClusterName = "cluster" ClusterName = "cluster"
APIName = "api" APIName = "api"
ConsumerKey = "x-mse-consumer"
// Source Type // Source Type
FixedValue = "fixed_value" FixedValue = "fixed_value"
@@ -81,8 +82,8 @@ type AIStatisticsConfig struct {
shouldBufferStreamingBody bool shouldBufferStreamingBody bool
} }
func generateMetricName(route, cluster, model, metricName string) string { func generateMetricName(route, cluster, model, consumer, metricName string) string {
return fmt.Sprintf("route.%s.upstream.%s.model.%s.metric.%s", route, cluster, model, metricName) return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
} }
func getRouteName() (string, error) { func getRouteName() (string, error) {
@@ -115,6 +116,9 @@ func getClusterName() (string, error) {
} }
func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) { func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
if inc == 0 {
return
}
counter, ok := config.counterMetrics[metricName] counter, ok := config.counterMetrics[metricName]
if !ok { if !ok {
counter = proxywasm.DefineCounterMetric(metricName) counter = proxywasm.DefineCounterMetric(metricName)
@@ -158,6 +162,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
ctx.SetContext(ClusterName, cluster) ctx.SetContext(ClusterName, cluster)
ctx.SetUserAttribute(APIName, api) ctx.SetUserAttribute(APIName, api)
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
ctx.SetContext(ConsumerKey, consumer)
}
// Set user defined log & span attributes which type is fixed_value // Set user defined log & span attributes which type is fixed_value
setAttributeBySource(ctx, config, FixedValue, nil, log) setAttributeBySource(ctx, config, FixedValue, nil, log)
@@ -388,6 +395,7 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
var ok bool var ok bool
var route, cluster, model string var route, cluster, model string
var inputToken, outputToken uint64 var inputToken, outputToken uint64
consumer := ctx.GetStringContext(ConsumerKey, "none")
route, ok = ctx.GetContext(RouteName).(string) route, ok = ctx.GetContext(RouteName).(string)
if !ok { if !ok {
log.Warnf("RouteName typd assert failed, skip metric record") log.Warnf("RouteName typd assert failed, skip metric record")
@@ -421,8 +429,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
return return
} }
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken) config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken) config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
// Generate duration metrics // Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration uint64 var llmFirstTokenDuration, llmServiceDuration uint64
@@ -433,8 +441,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
log.Warnf("LLMFirstTokenDuration typd assert failed") log.Warnf("LLMFirstTokenDuration typd assert failed")
return return
} }
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration)
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1)
} }
if ctx.GetUserAttribute(LLMServiceDuration) != nil { if ctx.GetUserAttribute(LLMServiceDuration) != nil {
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration)) llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
@@ -442,8 +450,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
log.Warnf("LLMServiceDuration typd assert failed") log.Warnf("LLMServiceDuration typd assert failed")
return return
} }
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration)
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1)
} }
} }