diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 68b0b7efb..43961e747 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -97,6 +97,9 @@ func initContext(ctx wrapper.HttpContext) { value, _ := proxywasm.GetHttpRequestHeader(header) ctx.SetContext(ctxKey, value) } + for _, originHeader := range headerToOriginalHeaderMapping { + proxywasm.RemoveHttpRequestHeader(originHeader) + } } func saveContextsToHeaders(ctx wrapper.HttpContext) { @@ -127,6 +130,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType()) + // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. + ctx.DisableReroute() + initContext(ctx) rawPath := ctx.Path() @@ -156,8 +162,6 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf } ctx.SetContext(provider.CtxKeyApiName, apiName) - // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. - ctx.DisableReroute() // Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses, // allowing plugins to inspect or modify the response correctly diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index aa3def679..0cba0cbfc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -19,10 +19,9 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -591,23 +590,6 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body) } -func (b *bedrockProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) { - request := &bedrockTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return nil, fmt.Errorf("unable to unmarshal request: %v", err) - } - - if len(request.System) > 0 { - request.System = append(request.System, systemContentBlock{Text: content}) - } else { - request.System = []systemContentBlock{{Text: content}} - } - - requestBytes, err := json.Marshal(request) - b.setAuthHeaders(requestBytes, nil) - return requestBytes, err -} - func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { if gjson.GetBytes(body, "model").Exists() { rawModel := gjson.GetBytes(body, "model").String() @@ -906,18 +888,10 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) { t := time.Now().UTC() amzDate := t.Format("20060102T150405Z") dateStamp := t.Format("20060102") - path, _ := proxywasm.GetHttpRequestHeader(":path") - if headers != nil { - path = headers.Get(":path") - } + path := headers.Get(":path") signature := b.generateSignature(path, amzDate, dateStamp, body) - if headers != nil { - headers.Set("X-Amz-Date", amzDate) - headers.Set("Authorization", fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature)) - } else { - _ = proxywasm.ReplaceHttpRequestHeader("X-Amz-Date", amzDate) - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature)) - } + headers.Set("X-Amz-Date", amzDate) + util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature)) } func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string { diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 70db42fb0..1645f1794 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -59,18 +59,10 @@ func OverwriteRequestPath(path string) error { } func OverwriteRequestAuthorization(credential string) error { - if exist, _ := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth); exist == "" { - if originAuth, err := proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil { - _ = proxywasm.AddHttpRequestHeader(HeaderOriginalPath, originAuth) - } - } return proxywasm.ReplaceHttpRequestHeader(HeaderAuthorization, credential) } func OverwriteRequestHostHeader(headers http.Header, host string) { - if originHost, err := proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil { - headers.Set(HeaderOriginalHost, originHost) - } headers.Set(HeaderAuthority, host) }