diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index d0a4505ab..d2caa2433 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -41,6 +41,7 @@ description: AI 代理插件配置参考 | `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | | `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | +| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | `context`的配置字段说明如下: @@ -78,14 +79,22 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字 `failover` 的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -|------------------|--------|------|-------|-----------------------------| -| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | -| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | -| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | -| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | -| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | -| healthCheckModel | string | 必填 | | 健康检测使用的模型 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------|--------|-----------------|-------|-----------------------------| +| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | +| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | +| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | +| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | +| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | +| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 | + +`retryOnFailure` 的配置字段说明如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------|--------|-----------------|-------|-------------| +| enabled | bool | 非必填 | false | 是否启用失败请求重试 | +| maxRetries | int | 非必填 | 1 | 最大重试次数 | +| retryTimeout | int | 非必填 | 5000 | 重试超时时间,单位毫秒 | ### 提供商特有配置 diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 3f4dc49ba..6c7756e45 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -20,8 +20,6 @@ import ( const ( pluginName = "ai-proxy" - ctxKeyApiName = "apiName" - defaultMaxBodyBytes uint32 = 10 * 1024 * 1024 ) @@ -92,14 +90,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path) return types.ActionContinue } + + ctx.SetContext(provider.CtxKeyApiName, apiName) // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() - ctx.SetContext(ctxKeyApiName, apiName) - - _, needHandleBody := activeProvider.(provider.ResponseBodyHandler) _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) - if needHandleBody || needHandleStreamingBody { + if needHandleStreamingBody { proxywasm.RemoveHttpRequestHeader("Accept-Encoding") } @@ -138,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType()) if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { - apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body) if settingErr != nil { @@ -186,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() - providerConfig.OnRequestFailed(ctx, apiTokenInUse, log) - - return types.ActionContinue + return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log) } // Reset ctxApiTokenRequestFailureCount if the request is successful, // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log) - if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { - apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) - action, err := handler.OnResponseHeaders(ctx, apiName, log) - if err == nil { - checkStream(&ctx, log) - return action - } - util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err)) - return types.ActionContinue + headers := util.GetOriginalResponseHeaders() + if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok { + apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) + handler.TransformResponseHeaders(ctx, apiName, headers, log) + } else { + providerConfig.DefaultTransformResponseHeaders(ctx, headers) } + util.ReplaceResponseHeaders(headers) checkStream(&ctx, log) - _, needHandleBody := activeProvider.(provider.ResponseBodyHandler) _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) - if !needHandleBody && !needHandleStreamingBody { - ctx.DontReadResponseBody() - } else if !needHandleStreamingBody { + if !needHandleStreamingBody { ctx.BufferResponseBody() } @@ -230,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok { - apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log) if err == nil && modifiedChunk != nil { return modifiedChunk @@ -249,16 +239,17 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi } log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType()) - //log.Debugf("response body: %s", string(body)) - if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok { - apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) - action, err := handler.OnResponseBody(ctx, apiName, body, log) - if err == nil { - return action + if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok { + apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) + body, err := handler.TransformResponseBody(ctx, apiName, body, log) + if err != nil { + util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err)) + return types.ActionContinue + } + if err = provider.ReplaceResponseBody(body, log); err != nil { + util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err)) } - util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err)) - return types.ActionContinue } return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 994346974..ac7322322 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -10,7 +10,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -139,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A return json.Marshal(claudeRequest) } -func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { claudeResponse := &claudeTextGenResponse{} if err := json.Unmarshal(body, claudeResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err) + return nil, fmt.Errorf("unable to unmarshal claude response: %v", err) } if claudeResponse.Error != nil { - return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message) + return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message) } response := c.responseClaude2OpenAI(ctx, claudeResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) -} - -func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - // use original protocol, skip OnStreamingResponseBody() and OnResponseBody() - if c.config.protocol == protocolOriginal { - ctx.DontReadResponseBody() - return types.ActionContinue, nil - } - - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil + return json.Marshal(response) } func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/context.go b/plugins/wasm-go/extensions/ai-proxy/provider/context.go index 9ba64ee5a..fb38f86c5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/context.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/context.go @@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo if err != nil { util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err)) } - if err := replaceHttpJsonRequestBody(body, log); err != nil { + if err := replaceRequestBody(body, log); err != nil { util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err)) } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 345a70c94..d7d10e75a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -10,7 +10,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -112,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api return json.Marshal(baiduRequest) } -func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil -} - -func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { deeplResponse := &deeplResponse{} if err := json.Unmarshal(body, deeplResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err) + return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err) } response := d.responseDeepl2OpenAI(ctx, deeplResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 56b03fbd7..e1b0b9819 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -19,7 +19,7 @@ import ( type failover struct { // @Title zh-CN 是否启用 apiToken 的 failover 机制 - enabled bool `required:"true" yaml:"enabled" json:"enabled"` + enabled bool `required:"false" yaml:"enabled" json:"enabled"` // @Title zh-CN 触发 failover 连续请求失败的阈值 failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"` // @Title zh-CN 健康检测的成功阈值 @@ -29,7 +29,7 @@ type failover struct { // @Title zh-CN 健康检测的超时时间,单位毫秒 healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` // @Title zh-CN 健康检测使用的模型 - healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` + healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"` // @Title zh-CN 本次请求使用的 apiToken ctxApiTokenInUse string // @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数 @@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, if handler, ok := activeProvider.(TransformRequestBodyHandler); ok { body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log) } else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok { - headers := util.GetOriginalHttpHeaders() + headers := util.GetOriginalRequestHeaders() body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log) - util.ReplaceOriginalHttpHeaders(headers) + util.ReplaceRequestHeaders(headers) } else { body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log) } @@ -539,10 +539,15 @@ func (c *ProviderConfig) resetSharedData() { _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) } -func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { +func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action { if c.isFailoverEnabled() { c.handleUnavailableApiToken(ctx, apiTokenInUse, log) } + if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) { + c.retryFailedRequest(activeProvider, ctx, log) + return types.HeaderStopAllIterationAndWatermark + } + return types.ActionContinue } func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { @@ -557,7 +562,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L } else { apiToken = c.GetRandomToken() } - log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken) + log.Debugf("Use apiToken %s to send request", apiToken) ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 7a9b0a3dd..526b55145 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -105,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ return json.Marshal(geminiRequest) } -func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - if g.config.protocol == protocolOriginal { - ctx.DontReadResponseBody() - return types.ActionContinue, nil - } - - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil -} - func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { log.Infof("chunk body:%s", string(chunk)) if isLastChunk || len(chunk) == 0 { @@ -148,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A return []byte(modifiedResponseChunk), nil } -func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { if apiName == ApiNameChatCompletion { return g.onChatCompletionResponseBody(ctx, body, log) - } else if apiName == ApiNameEmbeddings { + } else { return g.onEmbeddingsResponseBody(ctx, body, log) } - return types.ActionContinue, errUnsupportedApiName } -func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { +func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { geminiResponse := &geminiChatResponse{} if err := json.Unmarshal(body, geminiResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err) + return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err) } if geminiResponse.Error != nil { - return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s", + return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s", geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message) } response := g.buildChatCompletionResponse(ctx, geminiResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } -func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { +func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { geminiResponse := &geminiEmbeddingResponse{} if err := json.Unmarshal(body, geminiResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err) + return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err) } if geminiResponse.Error != nil { - return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s", + return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s", geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message) } response := g.buildEmbeddingsResponse(ctx, geminiResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 4b10a4d7c..7ed728a21 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -288,11 +288,6 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a return json.Marshal(hunyuanRequest) } -func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil -} - func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { if m.config.protocol == protocolOriginal { return chunk, nil @@ -409,21 +404,14 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex return []byte(openAIChunk.String()), nil } -func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - +func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body)) hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} if err := json.Unmarshal(body, hunyuanResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal hunyuan response: %v", err) + return nil, fmt.Errorf("unable to unmarshal hunyuan response: %v", err) } - - if m.config.protocol == protocolOriginal { - return types.ActionContinue, replaceJsonResponseBody(hunyuanResponse, log) - } - response := m.buildChatCompletionResponse(ctx, hunyuanResponse) - - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 9531edcf1..a57fe23c3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -144,19 +144,16 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, heade return sjson.SetBytes(body, "model", mappedModel) } -func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - // Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol. +// Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol. +func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { if m.config.protocol == protocolOriginal { ctx.DontReadResponseBody() - return types.ActionContinue, nil } + // Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface. if minimaxApiTypePro != m.config.minimaxApiType { ctx.DontReadResponseBody() - return types.ActionContinue, nil } - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil } // OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. @@ -196,16 +193,16 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name } // OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API. -func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { minimaxResp := &minimaxChatCompletionV2Resp{} if err := json.Unmarshal(body, minimaxResp); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err) + return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err) } if minimaxResp.BaseResp.StatusCode != 0 { - return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg) + return nil, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg) } response := m.responseV2ToOpenAI(minimaxResp) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } // minimaxChatCompletionV2Request represents the structure of a chat completion V2 request. diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 0dc70428f..ea0870e67 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -59,7 +59,9 @@ const ( finishReasonLength = "length" ctxKeyIncrementalStreaming = "incrementalStreaming" - ctxKeyApiName = "apiKey" + ctxKeyApiKey = "apiKey" + CtxKeyApiName = "apiName" + ctxKeyIsStreaming = "isStreaming" ctxKeyStreamingBody = "streamingBody" ctxKeyOriginalRequestModel = "originalRequestModel" ctxKeyFinalRequestModel = "finalRequestModel" @@ -115,22 +117,26 @@ type Provider interface { GetProviderType() string } -type ApiNameHandler interface { - GetApiName(path string) ApiName -} - type RequestHeadersHandler interface { OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error } -type TransformRequestHeadersHandler interface { - TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) -} - type RequestBodyHandler interface { OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) } +type StreamingResponseBodyHandler interface { + OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) +} + +type ApiNameHandler interface { + GetApiName(path string) ApiName +} + +type TransformRequestHeadersHandler interface { + TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) +} + type TransformRequestBodyHandler interface { TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) } @@ -141,16 +147,12 @@ type TransformRequestBodyHeadersHandler interface { TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) } -type ResponseHeadersHandler interface { - OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) +type TransformResponseHeadersHandler interface { + TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) } -type StreamingResponseBodyHandler interface { - OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) -} - -type ResponseBodyHandler interface { - OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) +type TransformResponseBodyHandler interface { + TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) } // TickFuncHandler allows the provider to execute a function periodically @@ -175,6 +177,9 @@ type ProviderConfig struct { // @Title zh-CN apiToken 故障切换 // @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表 failover *failover `required:"false" yaml:"failover" json:"failover"` + // @Title zh-CN 失败请求重试 + // @Description zh-CN 对失败的请求立即进行重试 + retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"` // @Title zh-CN 基于OpenAI协议的自定义后端URL // @Description zh-CN 仅适用于支持 openai 协议的服务。 openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"` @@ -352,6 +357,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.failover.FromJson(failoverJson) } + retryOnFailureJson := json.Get("retryOnFailure") + c.retryOnFailure = &retryOnFailure{ + enabled: false, + } + if retryOnFailureJson.Exists() { + c.retryOnFailure.FromJson(retryOnFailureJson) + } + for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() { c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String()) } @@ -399,10 +412,10 @@ func (c *ProviderConfig) Validate() error { } func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string { - ctxApiKey := ctx.GetContext(ctxKeyApiName) + ctxApiKey := ctx.GetContext(ctxKeyApiKey) if ctxApiKey == nil { ctxApiKey = c.GetRandomToken() - ctx.SetContext(ctxKeyApiName, ctxApiKey) + ctx.SetContext(ctxKeyApiKey, ctxApiKey) } return ctxApiKey.(string) } @@ -446,6 +459,9 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques streaming := req.Stream if streaming { _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") + ctx.SetContext(ctxKeyIsStreaming, true) + } else { + ctx.SetContext(ctxKeyIsStreaming, false) } return c.setRequestModel(ctx, req, log) @@ -540,9 +556,9 @@ func (c *ProviderConfig) handleRequestBody( if handler, ok := provider.(TransformRequestBodyHandler); ok { body, err = handler.TransformRequestBody(ctx, apiName, body, log) } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok { - headers := util.GetOriginalHttpHeaders() + headers := util.GetOriginalRequestHeaders() body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log) - util.ReplaceOriginalHttpHeaders(headers) + util.ReplaceRequestHeaders(headers) } else { body, err = c.defaultTransformRequestBody(ctx, apiName, body, log) } @@ -551,9 +567,14 @@ func (c *ProviderConfig) handleRequestBody( return types.ActionContinue, err } + // If retryOnFailure is enabled, save the transformed body to the context in case of retry + if c.isRetryOnFailureEnabled() { + ctx.SetContext(ctxRequestBody, body) + } + if apiName == ApiNameChatCompletion { if c.context == nil { - return types.ActionContinue, replaceHttpJsonRequestBody(body, log) + return types.ActionContinue, replaceRequestBody(body, log) } err = contextCache.GetContextFromFile(ctx, provider, body, log) @@ -562,14 +583,14 @@ func (c *ProviderConfig) handleRequestBody( } return types.ActionContinue, err } - return types.ActionContinue, replaceHttpJsonRequestBody(body, log) + return types.ActionContinue, replaceRequestBody(body, log) } func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) { + headers := util.GetOriginalRequestHeaders() if handler, ok := provider.(TransformRequestHeadersHandler); ok { - originalHeaders := util.GetOriginalHttpHeaders() - handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log) - util.ReplaceOriginalHttpHeaders(originalHeaders) + handler.TransformRequestHeaders(ctx, apiName, headers, log) + util.ReplaceRequestHeaders(headers) } } @@ -585,3 +606,11 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap } return json.Marshal(request) } + +func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { + if c.protocol == protocolOriginal { + ctx.DontReadResponseBody() + } else { + headers.Del("Content-Length") + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 95fe28e4b..fc3d5b77a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -183,16 +183,6 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b return json.Marshal(qwenRequest) } -func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - if m.config.protocol == protocolOriginal { - ctx.DontReadResponseBody() - return types.ActionContinue, nil - } - - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil -} - func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { if m.config.qwenEnableCompatible || name != ApiNameChatCompletion { return chunk, nil @@ -278,9 +268,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api return []byte(modifiedResponseChunk), nil } -func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { if m.config.qwenEnableCompatible { - return types.ActionContinue, nil + return body, nil } if apiName == ApiNameChatCompletion { return m.onChatCompletionResponseBody(ctx, body, log) @@ -288,25 +278,25 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, if apiName == ApiNameEmbeddings { return m.onEmbeddingsResponseBody(ctx, body, log) } - return types.ActionContinue, errUnsupportedApiName + return nil, errUnsupportedApiName } -func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { qwenResponse := &qwenTextGenResponse{} if err := json.Unmarshal(body, qwenResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err) + return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err) } response := m.buildChatCompletionResponse(ctx, qwenResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } -func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { qwenResponse := &qwenTextEmbeddingResponse{} if err := json.Unmarshal(body, qwenResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err) + return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err) } response := m.buildEmbeddingsResponse(ctx, qwenResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index dd9864702..0018f0bde 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -37,7 +37,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error { return err } -func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error { +func replaceRequestBody(body []byte, log wrapper.Log) error { log.Debugf("request body: %s", string(body)) err := proxywasm.ReplaceHttpRequestBody(body) if err != nil { @@ -65,15 +65,11 @@ func insertContextMessage(request *chatCompletionRequest, content string) { } } -func replaceJsonResponseBody(response interface{}, log wrapper.Log) error { - body, err := json.Marshal(response) - if err != nil { - return fmt.Errorf("unable to marshal response: %v", err) - } +func ReplaceResponseBody(body []byte, log wrapper.Log) error { log.Debugf("response body: %s", string(body)) - err = proxywasm.ReplaceHttpResponseBody(body) + err := proxywasm.ReplaceHttpResponseBody(body) if err != nil { return fmt.Errorf("unable to replace the original response body: %v", err) } - return err + return nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/retry.go b/plugins/wasm-go/extensions/ai-proxy/provider/retry.go new file mode 100644 index 000000000..e88b9cd3d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/retry.go @@ -0,0 +1,141 @@ +package provider + +import ( + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/gjson" + "net/http" +) + +const ( + ctxRequestBody = "requestBody" + ctxRetryCount = "retryCount" +) + +type retryOnFailure struct { + // @Title zh-CN 是否启用请求重试 + enabled bool `required:"false" yaml:"enabled" json:"enabled"` + // @Title zh-CN 重试次数 + maxRetries int64 `required:"false" yaml:"maxRetries" json:"maxRetries"` + // @Title zh-CN 重试超时时间 + retryTimeout int64 `required:"false" yaml:"retryTimeout" json:"retryTimeout"` +} + +func (r *retryOnFailure) FromJson(json gjson.Result) { + r.enabled = json.Get("enabled").Bool() + r.maxRetries = json.Get("maxRetries").Int() + if r.maxRetries == 0 { + r.maxRetries = 1 + } + r.retryTimeout = json.Get("retryTimeout").Int() + if r.retryTimeout == 0 { + r.retryTimeout = 5000 + } +} + +func (c *ProviderConfig) isRetryOnFailureEnabled() bool { + return c.retryOnFailure.enabled +} + +func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, log wrapper.Log) { + log.Debugf("Retry failed request: provider=%s", activeProvider.GetProviderType()) + retryClient := createRetryClient(ctx) + apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName) + ctx.SetContext(ctxRetryCount, 1) + c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log) +} + +func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) { + if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok { + handler.TransformResponseHeaders(ctx, apiName, headers, log) + } else { + c.DefaultTransformResponseHeaders(ctx, headers) + } + + if handler, ok := activeProvider.(TransformResponseBodyHandler); ok { + var err error + body, err = handler.TransformResponseBody(ctx, apiName, body, log) + if err != nil { + log.Errorf("Failed to transform response body: %v", err) + } + } + + return util.HeaderToSlice(headers), body +} + +func (c *ProviderConfig) retryCall( + ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider, + apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte, + retryClient *wrapper.ClusterClient[wrapper.RouteCluster]) { + + retryCount := ctx.GetContext(ctxRetryCount).(int) + log.Debugf("Sent retry request: %d/%d", retryCount, c.retryOnFailure.maxRetries) + + if statusCode == 200 { + log.Debugf("Retry request succeeded") + headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log) + proxywasm.SendHttpResponse(200, headers, body, -1) + } else { + log.Debugf("The retry request still failed, status: %d, responseHeaders: %v, responseBody: %s", statusCode, responseHeaders, string(responseBody)) + } + + retryCount++ + if retryCount <= int(c.retryOnFailure.maxRetries) { + ctx.SetContext(ctxRetryCount, retryCount) + c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log) + } else { + log.Debugf("Reached the maximum retry count: %d", c.retryOnFailure.maxRetries) + proxywasm.ResumeHttpResponse() + } +} + +func (c *ProviderConfig) sendRetryRequest( + ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider, + retryClient *wrapper.ClusterClient[wrapper.RouteCluster], log wrapper.Log) { + + requestHeaders, requestBody := c.getRetryRequestHeadersAndBody(ctx, activeProvider, apiName, log) + path := getRetryPath(ctx) + + err := retryClient.Post(path, util.HeaderToSlice(requestHeaders), requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient) + }, uint32(c.retryOnFailure.retryTimeout)) + if err != nil { + log.Errorf("Failed to send retry request: %v", err) + proxywasm.ResumeHttpResponse() + } +} + +func createRetryClient(ctx wrapper.HttpContext) *wrapper.ClusterClient[wrapper.RouteCluster] { + host := wrapper.GetRequestHost() + if host == "" { + host = ctx.GetContext(ctxRequestHost).(string) + } + retryClient := wrapper.NewClusterClient(wrapper.RouteCluster{ + Host: host, + }) + return retryClient +} + +func getRetryPath(ctx wrapper.HttpContext) string { + path := wrapper.GetRequestPath() + if path == "" { + path = ctx.GetContext(ctxRequestPath).(string) + } + return path +} + +func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, log wrapper.Log) (http.Header, []byte) { + // The retry request may be sent with different apiToken, so the header needs to be regenerated + c.SetApiTokenInUse(ctx, log) + + requestHeaders := http.Header{ + "Content-Type": []string{"application/json"}, + } + if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok { + handler.TransformRequestHeaders(ctx, apiName, requestHeaders, log) + } + requestBody := ctx.GetContext(ctxRequestBody).([]byte) + + return requestHeaders, requestBody +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index f44b9e3c0..6504fb7fb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -9,7 +9,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -82,21 +81,16 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } -func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil -} - -func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { sparkResponse := &sparkResponse{} if err := json.Unmarshal(body, sparkResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err) + return nil, fmt.Errorf("unable to unmarshal spark response: %v", err) } if sparkResponse.Code != 0 { - return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message) + return nil, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message) } response := p.responseSpark2OpenAI(ctx, sparkResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + return json.Marshal(response) } func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index fbc3c6a80..4f36871b7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -86,12 +86,22 @@ func SliceToHeader(slice [][2]string) http.Header { return header } -func GetOriginalHttpHeaders() http.Header { +func GetOriginalRequestHeaders() http.Header { originalHeaders, _ := proxywasm.GetHttpRequestHeaders() return SliceToHeader(originalHeaders) } -func ReplaceOriginalHttpHeaders(headers http.Header) { +func GetOriginalResponseHeaders() http.Header { + originalHeaders, _ := proxywasm.GetHttpResponseHeaders() + return SliceToHeader(originalHeaders) +} + +func ReplaceRequestHeaders(headers http.Header) { modifiedHeaders := HeaderToSlice(headers) _ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders) } + +func ReplaceResponseHeaders(headers http.Header) { + modifiedHeaders := HeaderToSlice(headers) + _ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders) +}