diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index 84a03dee0..c8ff88559 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -40,6 +40,11 @@ const ( requestIdHeader = "X-Amzn-Requestid" ) +var ( + bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`) + bedrockInvokePathPattern = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`) +) + type bedrockProviderInitializer struct{} func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error { @@ -630,13 +635,24 @@ func (b *bedrockProvider) GetProviderType() string { return providerTypeBedrock } +func (b *bedrockProvider) GetApiName(path string) ApiName { + switch { + case bedrockConversePathPattern.MatchString(path): + return ApiNameChatCompletion + case bedrockInvokePathPattern.MatchString(path): + return ApiNameImageGeneration + default: + return "" + } +} + func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { b.config.handleRequestHeaders(b, ctx, apiName) return nil } func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { - util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)) + util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion))) // If apiTokens is configured, set Bearer token authentication here // This follows the same pattern as other providers (qwen, zhipuai, etc.) @@ -647,6 +663,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa } func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + // In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched + // and only apply auth headers. + if b.config.IsOriginal() { + headers := util.GetRequestHeaders() + b.setAuthHeaders(body, headers) + util.ReplaceRequestHeaders(headers) + return types.ActionContinue, replaceRequestBody(body) + } + if !b.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } @@ -654,14 +679,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName } func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { + var transformedBody []byte + var err error switch apiName { case ApiNameChatCompletion: - return b.onChatCompletionRequestBody(ctx, body, headers) + transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers) case ApiNameImageGeneration: - return b.onImageGenerationRequestBody(ctx, body, headers) + transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers) default: - return b.config.defaultTransformRequestBody(ctx, apiName, body) + transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body) } + + if err != nil { + return nil, err + } + + // Always apply auth after request body/path are finalized. + // For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers. + b.setAuthHeaders(transformedBody, headers) + return transformedBody, nil } func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { @@ -715,9 +751,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG Quality: origRequest.Quality, }, } - requestBytes, err := json.Marshal(request) - b.setAuthHeaders(requestBytes, headers) - return requestBytes, err + return json.Marshal(request) } func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse { @@ -847,9 +881,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom request.AdditionalModelRequestFields[key] = value } - requestBytes, err := json.Marshal(request) - b.setAuthHeaders(requestBytes, headers) - return requestBytes, err + return json.Marshal(request) } func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse { @@ -1163,45 +1195,88 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) { } // Use AWS Signature V4 authentication + accessKey := strings.TrimSpace(b.config.awsAccessKey) + region := strings.TrimSpace(b.config.awsRegion) t := time.Now().UTC() amzDate := t.Format("20060102T150405Z") dateStamp := t.Format("20060102") path := headers.Get(":path") signature := b.generateSignature(path, amzDate, dateStamp, body) headers.Set("X-Amz-Date", amzDate) - util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature)) + util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature)) } func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string { - path = encodeSigV4Path(path) + canonicalURI := encodeSigV4Path(path) hashedPayload := sha256Hex(body) + region := strings.TrimSpace(b.config.awsRegion) + secretKey := strings.TrimSpace(b.config.awsSecretKey) - endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion) + endpoint := fmt.Sprintf(bedrockDefaultDomain, region) canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate) canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s", - httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload) + httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload) - credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService) + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService) hashedCanonReq := sha256Hex([]byte(canonicalRequest)) stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s", amzDate, credentialScope, hashedCanonReq) - signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService) + signingKey := getSignatureKey(secretKey, dateStamp, region, awsService) signature := hmacHex(signingKey, stringToSign) return signature } func encodeSigV4Path(path string) string { + // Keep only the URI path for canonical URI. Query string is handled separately in SigV4, + // and this implementation uses an empty canonical query string. + if queryIndex := strings.Index(path, "?"); queryIndex >= 0 { + path = path[:queryIndex] + } + segments := strings.Split(path, "/") for i, seg := range segments { if seg == "" { continue } - segments[i] = url.PathEscape(seg) + // Normalize to "single-encoded" form: + // - raw ":" -> %3A + // - already encoded "%3A" -> still %3A (not %253A) + decoded, err := url.PathUnescape(seg) + if err == nil { + segments[i] = sigV4EscapePathSegment(decoded) + } else { + // If segment has invalid escape sequence, fall back to escaping raw segment. + segments[i] = sigV4EscapePathSegment(seg) + } } return strings.Join(segments, "/") } +func sigV4EscapePathSegment(segment string) string { + const upperHex = "0123456789ABCDEF" + var b strings.Builder + b.Grow(len(segment) * 3) + for i := 0; i < len(segment); i++ { + c := segment[i] + if isSigV4Unreserved(c) { + b.WriteByte(c) + continue + } + b.WriteByte('%') + b.WriteByte(upperHex[c>>4]) + b.WriteByte(upperHex[c&0x0F]) + } + return b.String() +} + +func isSigV4Unreserved(c byte) bool { + return (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || + c == '-' || c == '_' || c == '.' || c == '~' +} + func getSignatureKey(key, dateStamp, region, service string) []byte { kDate := hmacSha256([]byte("AWS4"+key), dateStamp) kRegion := hmacSha256(kDate, region) diff --git a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go index c4d9d4c23..6a4b17e62 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go @@ -25,6 +25,76 @@ var basicBedrockConfig = func() json.RawMessage { return data }() +// Test config: Bedrock original protocol config with AWS Access Key/Secret Key +var bedrockOriginalAkSkConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "protocol": "original", + "awsAccessKey": "test-ak-for-unit-test", + "awsSecretKey": "test-sk-for-unit-test", + "awsRegion": "us-east-1", + }, + }) + return data +}() + +// Test config: Bedrock original protocol config with api token +var bedrockOriginalApiTokenConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "protocol": "original", + "awsRegion": "us-east-1", + "apiTokens": []string{ + "test-token-for-unit-test", + }, + }, + }) + return data +}() + +// Test config: Bedrock original protocol config with AWS Access Key/Secret Key and custom settings +var bedrockOriginalAkSkWithCustomSettingsConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "protocol": "original", + "awsAccessKey": "test-ak-for-unit-test", + "awsSecretKey": "test-sk-for-unit-test", + "awsRegion": "us-east-1", + "customSettings": []map[string]interface{}{ + { + "name": "foo", + "value": "\"bar\"", + "mode": "raw", + "overwrite": true, + }, + }, + }, + }) + return data +}() + +// Test config: Bedrock config with embeddings capability to verify generic SigV4 flow +var bedrockEmbeddingsCapabilityConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "awsAccessKey": "test-ak-for-unit-test", + "awsSecretKey": "test-sk-for-unit-test", + "awsRegion": "us-east-1", + "capabilities": map[string]string{ + "openai/v1/embeddings": "/model/amazon.titan-embed-text-v2:0/invoke", + }, + "modelMapping": map[string]string{ + "*": "amazon.titan-embed-text-v2:0", + }, + }, + }) + return data +}() + // Test config: Bedrock config with Bearer Token authentication var bedrockApiTokenConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -352,6 +422,169 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint") }) + // Test Bedrock generic request body processing with AWS Signature V4 authentication + t.Run("bedrock embeddings request body with ak/sk should use sigv4", func(t *testing.T) { + host, status := test.NewTestHost(bedrockEmbeddingsCapabilityConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "text-embedding-3-small", + "input": "Hello from embeddings" + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature") + require.Contains(t, authValue, "Credential=", "Authorization should contain Credential") + require.Contains(t, authValue, "Signature=", "Authorization should contain Signature") + + dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date") + require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4") + require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty") + }) + + // Test Bedrock original converse-stream path with AWS Signature V4 authentication + t.Run("bedrock original converse-stream with ak/sk should use sigv4", func(t *testing.T) { + host, status := test.NewTestHost(bedrockOriginalAkSkConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream" + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", originalPath}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "messages": [ + { + "role": "user", + "content": [{"text": "Hello from original bedrock path"}] + } + ], + "inferenceConfig": { + "maxTokens": 64 + } + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature") + require.Contains(t, authValue, "Credential=", "Authorization should contain Credential") + require.Contains(t, authValue, "Signature=", "Authorization should contain Signature") + + dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date") + require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4") + require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty") + + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath, "Path header should exist") + require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged") + }) + + // Test Bedrock original converse-stream path with Bearer Token authentication + t.Run("bedrock original converse-stream with api token should pass bearer auth", func(t *testing.T) { + host, status := test.NewTestHost(bedrockOriginalApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream" + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", originalPath}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "messages": [ + { + "role": "user", + "content": [{"text": "Hello from original bedrock path"}] + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token") + require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain configured token") + + _, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date") + require.False(t, hasDate, "X-Amz-Date should not be set in Bearer token mode") + + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath, "Path header should exist") + require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged") + }) + + // Test Bedrock original converse-stream path keeps signed body consistent with custom settings + t.Run("bedrock original converse-stream with custom settings should replace body before forwarding", func(t *testing.T) { + host, status := test.NewTestHost(bedrockOriginalAkSkWithCustomSettingsConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + originalPath := "/model/amazon.nova-2-lite-v1:0/converse-stream" + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", originalPath}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "messages": [ + { + "role": "user", + "content": [{"text": "Hello"}] + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + require.Equal(t, "\"bar\"", bodyMap["foo"], "Custom settings should be applied to forwarded body") + + authValue, hasAuth := test.GetHeaderValue(host.GetRequestHeaders(), "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature") + }) + // Test Bedrock streaming request t.Run("bedrock streaming request", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig)