feat(ai-proxy): 添加Amazon Bedrock Prompt Cache保留策略配置及优化缓存处理逻辑 (#3609)

This commit is contained in:
woody
2026-03-18 20:37:04 +08:00
committed by GitHub
parent 8961db2e90
commit 62df71aadf
4 changed files with 562 additions and 23 deletions

View File

@@ -35,12 +35,14 @@ const (
// converse-stream path: /model/{modelId}/converse-stream
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
// invoke_model path: /model/{modelId}/invoke
bedrockInvokeModelPath = "/model/%s/invoke"
bedrockSignedHeaders = "host;x-amz-date"
requestIdHeader = "X-Amzn-Requestid"
bedrockCacheTypeDefault = "default"
bedrockCacheTTL5m = "5m"
bedrockCacheTTL1h = "1h"
bedrockInvokeModelPath = "/model/%s/invoke"
bedrockSignedHeaders = "host;x-amz-date"
requestIdHeader = "X-Amzn-Requestid"
bedrockCacheTypeDefault = "default"
bedrockCacheTTL5m = "5m"
bedrockCacheTTL1h = "1h"
bedrockPromptCacheNova = "amazon.nova"
bedrockPromptCacheClaude = "anthropic.claude"
bedrockCachePointPositionSystemPrompt = "systemPrompt"
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
@@ -179,7 +181,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
CompletionTokens: bedrockEvent.Usage.OutputTokens,
PromptTokens: bedrockEvent.Usage.InputTokens,
TotalTokens: bedrockEvent.Usage.TotalTokens,
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens),
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens, bedrockEvent.Usage.CacheWriteInputTokens),
}
}
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
@@ -839,11 +841,17 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
},
}
effectivePromptCacheRetention := b.resolvePromptCacheRetention(origRequest.PromptCacheRetention)
if origRequest.PromptCacheKey != "" {
log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
}
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(origRequest.PromptCacheRetention); ok {
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
if isPromptCacheSupportedModel(origRequest.Model) {
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(effectivePromptCacheRetention); ok {
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
}
} else if effectivePromptCacheRetention != "" {
log.Warnf("skip prompt cache injection for unsupported model: %s", origRequest.Model)
}
if origRequest.ReasoningEffort != "" {
@@ -950,7 +958,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
PromptTokens: bedrockResponse.Usage.InputTokens,
CompletionTokens: bedrockResponse.Usage.OutputTokens,
TotalTokens: bedrockResponse.Usage.TotalTokens,
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens),
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens, bedrockResponse.Usage.CacheWriteInputTokens),
},
}
}
@@ -982,11 +990,14 @@ func stopReasonBedrock2OpenAI(reason string) string {
}
func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
switch retention {
normalizedRetention := normalizePromptCacheRetention(retention)
switch normalizedRetention {
case "":
return "", false
case "in_memory":
return bedrockCacheTTL5m, true
// For the default 5-minute cache, omit ttl and let Bedrock apply its default.
// This is more robust for models that are strict about explicit ttl fields.
return "", true
case "24h":
return bedrockCacheTTL1h, true
default:
@@ -995,6 +1006,32 @@ func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
}
}
// normalizePromptCacheRetention canonicalizes a prompt-cache retention value:
// surrounding whitespace is trimmed, the text is lower-cased, and every '-'
// or ' ' becomes '_'. The compact spelling "inmemory" is mapped onto
// "in_memory"; any other value is returned in its canonicalized form.
func normalizePromptCacheRetention(retention string) string {
	// The replacement character is neither '-' nor ' ', so a single replacer
	// pass yields the same result as two sequential ReplaceAll calls.
	separators := strings.NewReplacer("-", "_", " ", "_")
	canonical := separators.Replace(strings.ToLower(strings.TrimSpace(retention)))

	switch canonical {
	case "inmemory":
		return "in_memory"
	default:
		return canonical
	}
}
// isPromptCacheSupportedModel reports whether the model identifier, after
// trimming surrounding whitespace and lower-casing, contains one of the
// model-family markers bedrockPromptCacheNova or bedrockPromptCacheClaude.
func isPromptCacheSupportedModel(model string) bool {
	id := strings.ToLower(strings.TrimSpace(model))
	for _, family := range [...]string{bedrockPromptCacheNova, bedrockPromptCacheClaude} {
		// Substring match, so the marker may appear anywhere in the identifier.
		if strings.Contains(id, family) {
			return true
		}
	}
	return false
}
// resolvePromptCacheRetention picks the effective prompt-cache retention
// value. A non-empty value supplied on the request takes precedence;
// otherwise the provider-level b.config.promptCacheRetention is used, which
// may itself be empty.
func (b *bedrockProvider) resolvePromptCacheRetention(requestPromptCacheRetention string) string {
	if requestPromptCacheRetention == "" {
		// No per-request value: fall back to the configured value (possibly "").
		return b.config.promptCacheRetention
	}
	return requestPromptCacheRetention
}
func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
if b.config.bedrockPromptCachePointPositions == nil {
return map[string]bool{
@@ -1070,6 +1107,9 @@ func findLastMessageIndexByRole(messages []bedrockMessage, role string) int {
}
func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
if messageIndex < 0 || messageIndex >= len(request.Messages) {
return
}
request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
CachePoint: &bedrockCachePoint{
Type: bedrockCacheTypeDefault,
@@ -1078,12 +1118,13 @@ func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageInd
})
}
func buildPromptTokensDetails(cacheReadInputTokens int) *promptTokensDetails {
if cacheReadInputTokens <= 0 {
func buildPromptTokensDetails(cacheReadInputTokens int, cacheWriteInputTokens int) *promptTokensDetails {
totalCachedTokens := cacheReadInputTokens + cacheWriteInputTokens
if totalCachedTokens <= 0 {
return nil
}
return &promptTokensDetails{
CachedTokens: cacheReadInputTokens,
CachedTokens: totalCachedTokens,
}
}