修复 ai-proxy 插件 Bedrock Provider 在 AWS AK/SK 鉴权模式下仅对部分 API 进行 SigV4 签名的问题 || Fixed the problem of ai-proxy plug-in Bedrock Provider only performing SigV4 signature on some APIs in AWS AK/SK authentication mode (#3549)

This commit is contained in:
woody
2026-03-02 09:55:31 +08:00
committed by GitHub
parent e2a22d1171
commit c12183cae5
2 changed files with 325 additions and 17 deletions

View File

@@ -40,6 +40,11 @@ const (
requestIdHeader = "X-Amzn-Requestid"
)
var (
bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
bedrockInvokePathPattern = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
)
type bedrockProviderInitializer struct{}
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
@@ -630,13 +635,24 @@ func (b *bedrockProvider) GetProviderType() string {
return providerTypeBedrock
}
func (b *bedrockProvider) GetApiName(path string) ApiName {
switch {
case bedrockConversePathPattern.MatchString(path):
return ApiNameChatCompletion
case bedrockInvokePathPattern.MatchString(path):
return ApiNameImageGeneration
default:
return ""
}
}
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
b.config.handleRequestHeaders(b, ctx, apiName)
return nil
}
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion)))
// If apiTokens is configured, set Bearer token authentication here
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
@@ -647,6 +663,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
}
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
// In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched
// and only apply auth headers.
if b.config.IsOriginal() {
headers := util.GetRequestHeaders()
b.setAuthHeaders(body, headers)
util.ReplaceRequestHeaders(headers)
return types.ActionContinue, replaceRequestBody(body)
}
if !b.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
@@ -654,14 +679,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
}
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
var transformedBody []byte
var err error
switch apiName {
case ApiNameChatCompletion:
return b.onChatCompletionRequestBody(ctx, body, headers)
transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return b.onImageGenerationRequestBody(ctx, body, headers)
transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers)
default:
return b.config.defaultTransformRequestBody(ctx, apiName, body)
transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body)
}
if err != nil {
return nil, err
}
// Always apply auth after request body/path are finalized.
// For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers.
b.setAuthHeaders(transformedBody, headers)
return transformedBody, nil
}
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
@@ -715,9 +751,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
Quality: origRequest.Quality,
},
}
requestBytes, err := json.Marshal(request)
b.setAuthHeaders(requestBytes, headers)
return requestBytes, err
return json.Marshal(request)
}
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
@@ -847,9 +881,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
request.AdditionalModelRequestFields[key] = value
}
requestBytes, err := json.Marshal(request)
b.setAuthHeaders(requestBytes, headers)
return requestBytes, err
return json.Marshal(request)
}
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
@@ -1163,45 +1195,88 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
}
// Use AWS Signature V4 authentication
accessKey := strings.TrimSpace(b.config.awsAccessKey)
region := strings.TrimSpace(b.config.awsRegion)
t := time.Now().UTC()
amzDate := t.Format("20060102T150405Z")
dateStamp := t.Format("20060102")
path := headers.Get(":path")
signature := b.generateSignature(path, amzDate, dateStamp, body)
headers.Set("X-Amz-Date", amzDate)
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature))
}
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
path = encodeSigV4Path(path)
canonicalURI := encodeSigV4Path(path)
hashedPayload := sha256Hex(body)
region := strings.TrimSpace(b.config.awsRegion)
secretKey := strings.TrimSpace(b.config.awsSecretKey)
endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
endpoint := fmt.Sprintf(bedrockDefaultDomain, region)
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService)
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService)
hashedCanonReq := sha256Hex([]byte(canonicalRequest))
stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
amzDate, credentialScope, hashedCanonReq)
signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService)
signingKey := getSignatureKey(secretKey, dateStamp, region, awsService)
signature := hmacHex(signingKey, stringToSign)
return signature
}
func encodeSigV4Path(path string) string {
// Keep only the URI path for canonical URI. Query string is handled separately in SigV4,
// and this implementation uses an empty canonical query string.
if queryIndex := strings.Index(path, "?"); queryIndex >= 0 {
path = path[:queryIndex]
}
segments := strings.Split(path, "/")
for i, seg := range segments {
if seg == "" {
continue
}
segments[i] = url.PathEscape(seg)
// Normalize to "single-encoded" form:
// - raw ":" -> %3A
// - already encoded "%3A" -> still %3A (not %253A)
decoded, err := url.PathUnescape(seg)
if err == nil {
segments[i] = sigV4EscapePathSegment(decoded)
} else {
// If segment has invalid escape sequence, fall back to escaping raw segment.
segments[i] = sigV4EscapePathSegment(seg)
}
}
return strings.Join(segments, "/")
}
func sigV4EscapePathSegment(segment string) string {
const upperHex = "0123456789ABCDEF"
var b strings.Builder
b.Grow(len(segment) * 3)
for i := 0; i < len(segment); i++ {
c := segment[i]
if isSigV4Unreserved(c) {
b.WriteByte(c)
continue
}
b.WriteByte('%')
b.WriteByte(upperHex[c>>4])
b.WriteByte(upperHex[c&0x0F])
}
return b.String()
}
func isSigV4Unreserved(c byte) bool {
return (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '_' || c == '.' || c == '~'
}
func getSignatureKey(key, dateStamp, region, service string) []byte {
kDate := hmacSha256([]byte("AWS4"+key), dateStamp)
kRegion := hmacSha256(kDate, region)