mirror of
https://github.com/alibaba/higress.git
synced 2026-03-09 03:00:54 +08:00
修复 ai-proxy 插件 Bedrock Provider 在 AWS AK/SK 鉴权模式下仅对部分 API 进行 SigV4 签名的问题 || Fixed the problem of ai-proxy plug-in Bedrock Provider only performing SigV4 signature on some APIs in AWS AK/SK authentication mode (#3549)
This commit is contained in:
@@ -40,6 +40,11 @@ const (
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
)
|
||||
|
||||
var (
|
||||
bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
|
||||
bedrockInvokePathPattern = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
|
||||
)
|
||||
|
||||
type bedrockProviderInitializer struct{}
|
||||
|
||||
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
@@ -630,13 +635,24 @@ func (b *bedrockProvider) GetProviderType() string {
|
||||
return providerTypeBedrock
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) GetApiName(path string) ApiName {
|
||||
switch {
|
||||
case bedrockConversePathPattern.MatchString(path):
|
||||
return ApiNameChatCompletion
|
||||
case bedrockInvokePathPattern.MatchString(path):
|
||||
return ApiNameImageGeneration
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
b.config.handleRequestHeaders(b, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
|
||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion)))
|
||||
|
||||
// If apiTokens is configured, set Bearer token authentication here
|
||||
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
|
||||
@@ -647,6 +663,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
// In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched
|
||||
// and only apply auth headers.
|
||||
if b.config.IsOriginal() {
|
||||
headers := util.GetRequestHeaders()
|
||||
b.setAuthHeaders(body, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
|
||||
if !b.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
@@ -654,14 +679,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
var transformedBody []byte
|
||||
var err error
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return b.onImageGenerationRequestBody(ctx, body, headers)
|
||||
transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Always apply auth after request body/path are finalized.
|
||||
// For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers.
|
||||
b.setAuthHeaders(transformedBody, headers)
|
||||
return transformedBody, nil
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
@@ -715,9 +751,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
|
||||
Quality: origRequest.Quality,
|
||||
},
|
||||
}
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
@@ -847,9 +881,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
request.AdditionalModelRequestFields[key] = value
|
||||
}
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
|
||||
@@ -1163,45 +1195,88 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
}
|
||||
|
||||
// Use AWS Signature V4 authentication
|
||||
accessKey := strings.TrimSpace(b.config.awsAccessKey)
|
||||
region := strings.TrimSpace(b.config.awsRegion)
|
||||
t := time.Now().UTC()
|
||||
amzDate := t.Format("20060102T150405Z")
|
||||
dateStamp := t.Format("20060102")
|
||||
path := headers.Get(":path")
|
||||
signature := b.generateSignature(path, amzDate, dateStamp, body)
|
||||
headers.Set("X-Amz-Date", amzDate)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
|
||||
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature))
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
|
||||
path = encodeSigV4Path(path)
|
||||
canonicalURI := encodeSigV4Path(path)
|
||||
hashedPayload := sha256Hex(body)
|
||||
region := strings.TrimSpace(b.config.awsRegion)
|
||||
secretKey := strings.TrimSpace(b.config.awsSecretKey)
|
||||
|
||||
endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
|
||||
endpoint := fmt.Sprintf(bedrockDefaultDomain, region)
|
||||
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
|
||||
httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
|
||||
httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
|
||||
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService)
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService)
|
||||
hashedCanonReq := sha256Hex([]byte(canonicalRequest))
|
||||
stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
|
||||
amzDate, credentialScope, hashedCanonReq)
|
||||
|
||||
signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService)
|
||||
signingKey := getSignatureKey(secretKey, dateStamp, region, awsService)
|
||||
signature := hmacHex(signingKey, stringToSign)
|
||||
return signature
|
||||
}
|
||||
|
||||
func encodeSigV4Path(path string) string {
|
||||
// Keep only the URI path for canonical URI. Query string is handled separately in SigV4,
|
||||
// and this implementation uses an empty canonical query string.
|
||||
if queryIndex := strings.Index(path, "?"); queryIndex >= 0 {
|
||||
path = path[:queryIndex]
|
||||
}
|
||||
|
||||
segments := strings.Split(path, "/")
|
||||
for i, seg := range segments {
|
||||
if seg == "" {
|
||||
continue
|
||||
}
|
||||
segments[i] = url.PathEscape(seg)
|
||||
// Normalize to "single-encoded" form:
|
||||
// - raw ":" -> %3A
|
||||
// - already encoded "%3A" -> still %3A (not %253A)
|
||||
decoded, err := url.PathUnescape(seg)
|
||||
if err == nil {
|
||||
segments[i] = sigV4EscapePathSegment(decoded)
|
||||
} else {
|
||||
// If segment has invalid escape sequence, fall back to escaping raw segment.
|
||||
segments[i] = sigV4EscapePathSegment(seg)
|
||||
}
|
||||
}
|
||||
return strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
func sigV4EscapePathSegment(segment string) string {
|
||||
const upperHex = "0123456789ABCDEF"
|
||||
var b strings.Builder
|
||||
b.Grow(len(segment) * 3)
|
||||
for i := 0; i < len(segment); i++ {
|
||||
c := segment[i]
|
||||
if isSigV4Unreserved(c) {
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
b.WriteByte('%')
|
||||
b.WriteByte(upperHex[c>>4])
|
||||
b.WriteByte(upperHex[c&0x0F])
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isSigV4Unreserved(c byte) bool {
|
||||
return (c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '-' || c == '_' || c == '.' || c == '~'
|
||||
}
|
||||
|
||||
func getSignatureKey(key, dateStamp, region, service string) []byte {
|
||||
kDate := hmacSha256([]byte("AWS4"+key), dateStamp)
|
||||
kRegion := hmacSha256(kDate, region)
|
||||
|
||||
@@ -25,6 +25,76 @@ var basicBedrockConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock original protocol config with AWS Access Key/Secret Key
|
||||
var bedrockOriginalAkSkConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"protocol": "original",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock original protocol config with api token
|
||||
var bedrockOriginalApiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"protocol": "original",
|
||||
"awsRegion": "us-east-1",
|
||||
"apiTokens": []string{
|
||||
"test-token-for-unit-test",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock original protocol config with AWS Access Key/Secret Key and custom settings
|
||||
var bedrockOriginalAkSkWithCustomSettingsConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"protocol": "original",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"customSettings": []map[string]interface{}{
|
||||
{
|
||||
"name": "foo",
|
||||
"value": "\"bar\"",
|
||||
"mode": "raw",
|
||||
"overwrite": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with embeddings capability to verify generic SigV4 flow
|
||||
var bedrockEmbeddingsCapabilityConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"capabilities": map[string]string{
|
||||
"openai/v1/embeddings": "/model/amazon.titan-embed-text-v2:0/invoke",
|
||||
},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "amazon.titan-embed-text-v2:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with Bearer Token authentication
|
||||
var bedrockApiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
@@ -352,6 +422,169 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||
})
|
||||
|
||||
// Test Bedrock generic request body processing with AWS Signature V4 authentication
|
||||
t.Run("bedrock embeddings request body with ak/sk should use sigv4", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockEmbeddingsCapabilityConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "text-embedding-3-small",
|
||||
"input": "Hello from embeddings"
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
|
||||
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
|
||||
|
||||
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
|
||||
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
|
||||
})
|
||||
|
||||
// Test Bedrock original converse-stream path with AWS Signature V4 authentication
|
||||
t.Run("bedrock original converse-stream with ak/sk should use sigv4", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockOriginalAkSkConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", originalPath},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "Hello from original bedrock path"}]
|
||||
}
|
||||
],
|
||||
"inferenceConfig": {
|
||||
"maxTokens": 64
|
||||
}
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
|
||||
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
|
||||
|
||||
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
|
||||
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged")
|
||||
})
|
||||
|
||||
// Test Bedrock original converse-stream path with Bearer Token authentication
|
||||
t.Run("bedrock original converse-stream with api token should pass bearer auth", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockOriginalApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalPath := "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse-stream"
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", originalPath},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "Hello from original bedrock path"}]
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
|
||||
require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain configured token")
|
||||
|
||||
_, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||
require.False(t, hasDate, "X-Amz-Date should not be set in Bearer token mode")
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Equal(t, originalPath, pathValue, "Original Bedrock path should be kept unchanged")
|
||||
})
|
||||
|
||||
// Test Bedrock original converse-stream path keeps signed body consistent with custom settings
|
||||
t.Run("bedrock original converse-stream with custom settings should replace body before forwarding", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockOriginalAkSkWithCustomSettingsConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
originalPath := "/model/amazon.nova-2-lite-v1:0/converse-stream"
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", originalPath},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "Hello"}]
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "\"bar\"", bodyMap["foo"], "Custom settings should be applied to forwarded body")
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(host.GetRequestHeaders(), "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||
})
|
||||
|
||||
// Test Bedrock streaming request
|
||||
t.Run("bedrock streaming request", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
|
||||
Reference in New Issue
Block a user