mirror of
https://github.com/alibaba/higress.git
synced 2026-06-03 17:47:25 +08:00
@@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
|
|||||||
|
|
||||||
c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
|
c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
|
||||||
if c.StreamResponseTemplate == "" {
|
if c.StreamResponseTemplate == "" {
|
||||||
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
|
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
|
||||||
}
|
}
|
||||||
c.ResponseTemplate = json.Get("responseTemplate").String()
|
c.ResponseTemplate = json.Get("responseTemplate").String()
|
||||||
if c.ResponseTemplate == "" {
|
if c.ResponseTemplate == "" {
|
||||||
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||||
}
|
}
|
||||||
|
|
||||||
if json.Get("enableSemanticCache").Exists() {
|
if json.Get("enableSemanticCache").Exists() {
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.
|
|||||||
|
|
||||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
|
||||||
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
|
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
|
||||||
if templateEnable != "true" {
|
if templateEnable == "false" {
|
||||||
ctx.DontReadRequestBody()
|
ctx.DontReadRequestBody()
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -551,7 +551,8 @@ func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.Ht
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||||
return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
|
token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
|
||||||
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||||
|
|||||||
@@ -41,9 +41,9 @@ const (
|
|||||||
LowRisk = "low"
|
LowRisk = "low"
|
||||||
NoRisk = "none"
|
NoRisk = "none"
|
||||||
|
|
||||||
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}`
|
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||||
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||||
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
|
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||||
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
||||||
|
|
||||||
DefaultRequestCheckService = "llm_query_moderation"
|
DefaultRequestCheckService = "llm_query_moderation"
|
||||||
@@ -262,8 +262,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|||||||
log.Debugf("checking request body...")
|
log.Debugf("checking request body...")
|
||||||
startTime := time.Now().UnixMilli()
|
startTime := time.Now().UnixMilli()
|
||||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||||
model := gjson.GetBytes(body, "model").String()
|
|
||||||
ctx.SetContext("requestModel", model)
|
|
||||||
log.Debugf("Raw request content is: %s", content)
|
log.Debugf("Raw request content is: %s", content)
|
||||||
if len(content) == 0 {
|
if len(content) == 0 {
|
||||||
log.Info("request content is empty. skip")
|
log.Info("request content is empty. skip")
|
||||||
@@ -308,11 +306,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|||||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||||
randomID := generateRandomID()
|
randomID := generateRandomID()
|
||||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
} else {
|
} else {
|
||||||
randomID := generateRandomID()
|
randomID := generateRandomID()
|
||||||
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
|
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
}
|
}
|
||||||
ctx.DontReadResponseBody()
|
ctx.DontReadResponseBody()
|
||||||
@@ -369,15 +367,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
|||||||
return types.ActionPause
|
return types.ActionPause
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertHeaders(hs [][2]string) map[string][]string {
|
|
||||||
ret := make(map[string][]string)
|
|
||||||
for _, h := range hs {
|
|
||||||
k, v := strings.ToLower(h[0]), h[1]
|
|
||||||
ret[k] = append(ret[k], v)
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||||
if !config.checkResponse {
|
if !config.checkResponse {
|
||||||
log.Debugf("response checking is disabled")
|
log.Debugf("response checking is disabled")
|
||||||
@@ -398,7 +387,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
startTime := time.Now().UnixMilli()
|
startTime := time.Now().UnixMilli()
|
||||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||||
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
||||||
model := ctx.GetStringContext("requestModel", "unknown")
|
|
||||||
var content string
|
var content string
|
||||||
if isStreamingResponse {
|
if isStreamingResponse {
|
||||||
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
|
||||||
@@ -449,11 +437,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
|||||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||||
} else if isStreamingResponse {
|
} else if isStreamingResponse {
|
||||||
randomID := generateRandomID()
|
randomID := generateRandomID()
|
||||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
|
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||||
} else {
|
} else {
|
||||||
randomID := generateRandomID()
|
randomID := generateRandomID()
|
||||||
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
|
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||||
}
|
}
|
||||||
config.incrementCounter("ai_sec_response_deny", 1)
|
config.incrementCounter("ai_sec_response_deny", 1)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ const (
|
|||||||
RouteName = "route"
|
RouteName = "route"
|
||||||
ClusterName = "cluster"
|
ClusterName = "cluster"
|
||||||
APIName = "api"
|
APIName = "api"
|
||||||
|
ConsumerKey = "x-mse-consumer"
|
||||||
|
|
||||||
// Source Type
|
// Source Type
|
||||||
FixedValue = "fixed_value"
|
FixedValue = "fixed_value"
|
||||||
@@ -81,8 +82,8 @@ type AIStatisticsConfig struct {
|
|||||||
shouldBufferStreamingBody bool
|
shouldBufferStreamingBody bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateMetricName(route, cluster, model, metricName string) string {
|
func generateMetricName(route, cluster, model, consumer, metricName string) string {
|
||||||
return fmt.Sprintf("route.%s.upstream.%s.model.%s.metric.%s", route, cluster, model, metricName)
|
return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRouteName() (string, error) {
|
func getRouteName() (string, error) {
|
||||||
@@ -115,6 +116,9 @@ func getClusterName() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
|
func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
|
||||||
|
if inc == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
counter, ok := config.counterMetrics[metricName]
|
counter, ok := config.counterMetrics[metricName]
|
||||||
if !ok {
|
if !ok {
|
||||||
counter = proxywasm.DefineCounterMetric(metricName)
|
counter = proxywasm.DefineCounterMetric(metricName)
|
||||||
@@ -158,6 +162,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
|
|||||||
ctx.SetContext(ClusterName, cluster)
|
ctx.SetContext(ClusterName, cluster)
|
||||||
ctx.SetUserAttribute(APIName, api)
|
ctx.SetUserAttribute(APIName, api)
|
||||||
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
|
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
|
||||||
|
if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
|
||||||
|
ctx.SetContext(ConsumerKey, consumer)
|
||||||
|
}
|
||||||
|
|
||||||
// Set user defined log & span attributes which type is fixed_value
|
// Set user defined log & span attributes which type is fixed_value
|
||||||
setAttributeBySource(ctx, config, FixedValue, nil, log)
|
setAttributeBySource(ctx, config, FixedValue, nil, log)
|
||||||
@@ -388,6 +395,7 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
|
|||||||
var ok bool
|
var ok bool
|
||||||
var route, cluster, model string
|
var route, cluster, model string
|
||||||
var inputToken, outputToken uint64
|
var inputToken, outputToken uint64
|
||||||
|
consumer := ctx.GetStringContext(ConsumerKey, "none")
|
||||||
route, ok = ctx.GetContext(RouteName).(string)
|
route, ok = ctx.GetContext(RouteName).(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warnf("RouteName typd assert failed, skip metric record")
|
log.Warnf("RouteName typd assert failed, skip metric record")
|
||||||
@@ -421,8 +429,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
|
|||||||
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
|
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
|
||||||
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
|
||||||
|
|
||||||
// Generate duration metrics
|
// Generate duration metrics
|
||||||
var llmFirstTokenDuration, llmServiceDuration uint64
|
var llmFirstTokenDuration, llmServiceDuration uint64
|
||||||
@@ -433,8 +441,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
|
|||||||
log.Warnf("LLMFirstTokenDuration typd assert failed")
|
log.Warnf("LLMFirstTokenDuration typd assert failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration)
|
||||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1)
|
||||||
}
|
}
|
||||||
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
|
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
|
||||||
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
|
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
|
||||||
@@ -442,8 +450,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper
|
|||||||
log.Warnf("LLMServiceDuration typd assert failed")
|
log.Warnf("LLMServiceDuration typd assert failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration)
|
||||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
|
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user