From 24dca0455ec09481827a1ba9919b20e2ebe13dc1 Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Fri, 15 Aug 2025 17:40:13 +0800 Subject: [PATCH] fix: Fix bugs in the bedrock model name escaping logic (#2663) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 澄潭 --- .../extensions/ai-proxy/provider/bedrock.go | 25 +++++++++++-------- .../wasm-go/extensions/ai-proxy/util/http.go | 5 ---- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index 0cba0cbfc..fee47046e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -22,8 +22,6 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) const ( @@ -591,11 +589,6 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName } func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { - if gjson.GetBytes(body, "model").Exists() { - rawModel := gjson.GetBytes(body, "model").String() - encodedModel := url.QueryEscape(rawModel) - body, _ = sjson.SetBytes(body, "model", encodedModel) - } switch apiName { case ApiNameChatCompletion: return b.onChatCompletionRequestBody(ctx, body, headers) @@ -633,7 +626,7 @@ func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, return nil, err } headers.Set("Accept", "*/*") - util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model)) + b.overwriteRequestPathHeader(headers, bedrockInvokeModelPath, request.Model) return b.buildBedrockImageGenerationRequest(request, headers) } @@ -657,7 +650,6 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG Quality: origRequest.Quality, }, } - util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model)) requestBytes, err := json.Marshal(request) b.setAuthHeaders(requestBytes, headers) return requestBytes, err @@ -696,9 +688,9 @@ func (b *bedrockProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, b streaming := request.Stream headers.Set("Accept", "*/*") if streaming { - util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockStreamChatCompletionPath, request.Model)) + b.overwriteRequestPathHeader(headers, bedrockStreamChatCompletionPath, request.Model) } else { - util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockChatCompletionPath, request.Model)) + b.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, request.Model) } return b.buildBedrockTextGenerationRequest(request, headers) } @@ -770,6 +762,17 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b } } +func (b *bedrockProvider) overwriteRequestPathHeader(headers http.Header, format, model string) { + modelInPath := model + // Just in case the model name has already been URL-escaped, we shouldn't escape it again. + if !strings.ContainsRune(model, '%') { + modelInPath = url.QueryEscape(model) + } + path := fmt.Sprintf(format, modelInPath) + log.Debugf("overwriting bedrock request path: %s", path) + util.OverwriteRequestPathHeader(headers, path) +} + func stopReasonBedrock2OpenAI(reason string) string { switch reason { case "end_turn": diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 1645f1794..0266f245d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -167,11 +167,6 @@ func SetOriginalRequestAuth(auth string) { } func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { - if exist := headers.Get(HeaderOriginalAuth); exist == "" { - if originAuth := headers.Get(HeaderAuthorization); originAuth != "" { - headers.Set(HeaderOriginalAuth, originAuth) - } - } headers.Set(HeaderAuthorization, credential) }