From 178755329472ad5d04ad64c173133c93da73f863 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BE=84=E6=BD=AD?= Date: Wed, 26 Feb 2025 16:49:16 +0800 Subject: [PATCH] set include_usage by default for all model providers (#1818) --- plugins/wasm-go/extensions/ai-proxy/main.go | 27 +++++++++++++------ .../extensions/ai-proxy/provider/openai.go | 19 +++++-------- .../extensions/ai-proxy/provider/provider.go | 4 +++ 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 4a2d1fb98..60773f71c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -15,6 +15,7 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) const ( @@ -140,16 +141,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) - - newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body) + providerConfig := pluginConfig.GetProviderConfig() + newBody, settingErr := providerConfig.ReplaceByCustomSettings(body) if settingErr != nil { - _ = util.ErrorHandler( - "ai-proxy.proc_req_body_failed", - fmt.Errorf("failed to replace request body by custom settings: %v", settingErr), - ) - return types.ActionContinue + log.Errorf("failed to replace request body by custom settings: %v", settingErr) + } + if providerConfig.IsOpenAIProtocol() { + newBody = normalizeOpenAiRequestBody(newBody, log) } - log.Debugf("[onHttpRequestBody] newBody=%s", newBody) body = newBody action, err := handler.OnRequestBody(ctx, apiName, body, log) @@ -297,6 +296,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi return types.ActionContinue } +func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte { + var err error + // Default setting include_usage. + if gjson.GetBytes(body, "stream").Bool() { + body, err = sjson.SetBytes(body, "stream_options.include_usage", true) + if err != nil { + log.Errorf("set include_usage failed, err:%s", err) + } + } + return body +} + func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 2def57aa6..f875dbaa4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -127,21 +127,14 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return nil, err - } if m.config.responseJsonSchema != nil { + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return nil, err + } log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema) request.ResponseFormat = m.config.responseJsonSchema + body, _ = json.Marshal(request) } - if request.Stream { - // For stream requests, we need to include usage in the response. - if request.StreamOptions == nil { - request.StreamOptions = &streamOptions{IncludeUsage: true} - } else if !request.StreamOptions.IncludeUsage { - request.StreamOptions.IncludeUsage = true - } - } - return json.Marshal(request) + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index e3c54bfce..c5ec8ce2d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -292,6 +292,10 @@ func (c *ProviderConfig) GetProtocol() string { return c.protocol } +func (c *ProviderConfig) IsOpenAIProtocol() bool { + return c.protocol == protocolOpenAI +} + func (c *ProviderConfig) FromJson(json gjson.Result) { c.id = json.Get("id").String() c.typ = json.Get("type").String()