diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index fae78d390..125098f52 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -86,10 +86,6 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf return types.ActionContinue } - if _, needHandleBody := activeProvider.(provider.RequestBodyHandler); needHandleBody { - ctx.DontReadRequestBody() - } - return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 36135910f..c05561a78 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -58,13 +58,7 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam _ = util.OverwriteRequestPath(m.serviceUrl.RequestURI()) _ = util.OverwriteRequestHost(m.serviceUrl.Host) _ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0]) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index 354c879be..486d7400e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -46,13 +46,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api _ = util.OverwriteRequestPath(baichuanChatCompletionPath) _ = util.OverwriteRequestHost(baichuanDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 9cb12977f..84af754a5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -48,9 +48,6 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A _ = util.OverwriteRequestHost(cloudflareDomain) _ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken()) - if c.config.context == nil && c.config.protocol == protocolOriginal { - ctx.DontReadRequestBody() - } _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 0b914f587..2680377eb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -46,13 +46,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api _ = util.OverwriteRequestPath(deepseekChatCompletionPath) _ = util.OverwriteRequestHost(deepseekDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index dd11fe147..17cf086e2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -44,13 +44,7 @@ func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName _ = util.OverwriteRequestPath(groqChatCompletionPath) _ = util.OverwriteRequestHost(groqDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 2f02ba2a4..170990add 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -41,24 +41,16 @@ func (m *openaiProvider) GetProviderType() string { } func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - skipRequestBody := true switch apiName { case ApiNameChatCompletion: _ = util.OverwriteRequestPath(openaiChatCompletionPath) - skipRequestBody = m.contextCache == nil break case ApiNameEmbeddings: _ = util.OverwriteRequestPath(openaiEmbeddingsPath) break } _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if skipRequestBody { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } @@ -67,13 +59,31 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, // We don't need to process the request body for other APIs. return types.ActionContinue, nil } - if m.contextCache == nil { - return types.ActionContinue, nil - } request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err } + bodyAltered := false + if request.Stream { + // For stream requests, we need to include usage in the response. + if request.StreamOptions == nil { + request.StreamOptions = &streamOptions{IncludeUsage: true} + bodyAltered = true + } else if !request.StreamOptions.IncludeUsage { + request.StreamOptions.IncludeUsage = true + bodyAltered = true + } + } + if m.contextCache == nil { + if bodyAltered { + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + } + return types.ActionContinue, nil + } else { + // If context cache is configured and body has been altered, the new body will be replaced when inserting the context data. + } err := m.contextCache.GetContent(func(content string, err error) { defer func() { _ = proxywasm.ResumeHttpRequest() diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 0d4935be3..1625d5557 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -60,10 +60,8 @@ func (m *qwenProvider) GetProviderType() string { } func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - needRequestBody := false if apiName == ApiNameChatCompletion { _ = util.OverwriteRequestPath(qwenChatCompletionPath) - needRequestBody = m.config.context != nil } else if apiName == ApiNameEmbeddings { _ = util.OverwriteRequestPath(qwenTextEmbeddingPath) } else { @@ -72,8 +70,7 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName _ = util.OverwriteRequestHost(qwenDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - if m.config.protocol == protocolOriginal && !needRequestBody { - ctx.DontReadRequestBody() + if m.config.protocol == protocolOriginal { return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index 5d2961278..1b66eeddb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -44,13 +44,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN _ = util.OverwriteRequestPath(stepfunChatCompletionPath) _ = util.OverwriteRequestHost(stepfunDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 626d07498..1839b27d2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -44,13 +44,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, _ = util.OverwriteRequestPath(yiChatCompletionPath) _ = util.OverwriteRequestHost(yiDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 5b1e7a597..eeb0412b1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -43,13 +43,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath) _ = util.OverwriteRequestHost(zhipuAiDomain) _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - - if m.contextCache == nil { - ctx.DontReadRequestBody() - } else { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } - + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") return types.ActionContinue, nil }