mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat(ai-proxy): 添加Amazon Bedrock Prompt Cache保留策略配置及优化缓存处理逻辑 (#3609)
This commit is contained in:
@@ -35,12 +35,14 @@ const (
|
||||
// converseStream路径 /model/{modelId}/converse-stream
|
||||
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
||||
// invoke_model 路径 /model/{modelId}/invoke
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
bedrockCacheTypeDefault = "default"
|
||||
bedrockCacheTTL5m = "5m"
|
||||
bedrockCacheTTL1h = "1h"
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
bedrockCacheTypeDefault = "default"
|
||||
bedrockCacheTTL5m = "5m"
|
||||
bedrockCacheTTL1h = "1h"
|
||||
bedrockPromptCacheNova = "amazon.nova"
|
||||
bedrockPromptCacheClaude = "anthropic.claude"
|
||||
|
||||
bedrockCachePointPositionSystemPrompt = "systemPrompt"
|
||||
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
|
||||
@@ -179,7 +181,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens),
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens, bedrockEvent.Usage.CacheWriteInputTokens),
|
||||
}
|
||||
}
|
||||
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
||||
@@ -839,11 +841,17 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
},
|
||||
}
|
||||
|
||||
effectivePromptCacheRetention := b.resolvePromptCacheRetention(origRequest.PromptCacheRetention)
|
||||
|
||||
if origRequest.PromptCacheKey != "" {
|
||||
log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
|
||||
}
|
||||
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(origRequest.PromptCacheRetention); ok {
|
||||
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
|
||||
if isPromptCacheSupportedModel(origRequest.Model) {
|
||||
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(effectivePromptCacheRetention); ok {
|
||||
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
|
||||
}
|
||||
} else if effectivePromptCacheRetention != "" {
|
||||
log.Warnf("skip prompt cache injection for unsupported model: %s", origRequest.Model)
|
||||
}
|
||||
|
||||
if origRequest.ReasoningEffort != "" {
|
||||
@@ -950,7 +958,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens),
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens, bedrockResponse.Usage.CacheWriteInputTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -982,11 +990,14 @@ func stopReasonBedrock2OpenAI(reason string) string {
|
||||
}
|
||||
|
||||
func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
||||
switch retention {
|
||||
normalizedRetention := normalizePromptCacheRetention(retention)
|
||||
switch normalizedRetention {
|
||||
case "":
|
||||
return "", false
|
||||
case "in_memory":
|
||||
return bedrockCacheTTL5m, true
|
||||
// For the default 5-minute cache, omit ttl and let Bedrock apply its default.
|
||||
// This is more robust for models that are strict about explicit ttl fields.
|
||||
return "", true
|
||||
case "24h":
|
||||
return bedrockCacheTTL1h, true
|
||||
default:
|
||||
@@ -995,6 +1006,32 @@ func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePromptCacheRetention(retention string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(retention))
|
||||
normalized = strings.ReplaceAll(normalized, "-", "_")
|
||||
normalized = strings.ReplaceAll(normalized, " ", "_")
|
||||
if normalized == "inmemory" {
|
||||
return "in_memory"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isPromptCacheSupportedModel(model string) bool {
|
||||
normalizedModel := strings.ToLower(strings.TrimSpace(model))
|
||||
return strings.Contains(normalizedModel, bedrockPromptCacheNova) ||
|
||||
strings.Contains(normalizedModel, bedrockPromptCacheClaude)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) resolvePromptCacheRetention(requestPromptCacheRetention string) string {
|
||||
if requestPromptCacheRetention != "" {
|
||||
return requestPromptCacheRetention
|
||||
}
|
||||
if b.config.promptCacheRetention != "" {
|
||||
return b.config.promptCacheRetention
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
|
||||
if b.config.bedrockPromptCachePointPositions == nil {
|
||||
return map[string]bool{
|
||||
@@ -1070,6 +1107,9 @@ func findLastMessageIndexByRole(messages []bedrockMessage, role string) int {
|
||||
}
|
||||
|
||||
func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
|
||||
if messageIndex < 0 || messageIndex >= len(request.Messages) {
|
||||
return
|
||||
}
|
||||
request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
|
||||
CachePoint: &bedrockCachePoint{
|
||||
Type: bedrockCacheTypeDefault,
|
||||
@@ -1078,12 +1118,13 @@ func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageInd
|
||||
})
|
||||
}
|
||||
|
||||
func buildPromptTokensDetails(cacheReadInputTokens int) *promptTokensDetails {
|
||||
if cacheReadInputTokens <= 0 {
|
||||
func buildPromptTokensDetails(cacheReadInputTokens int, cacheWriteInputTokens int) *promptTokensDetails {
|
||||
totalCachedTokens := cacheReadInputTokens + cacheWriteInputTokens
|
||||
if totalCachedTokens <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &promptTokensDetails{
|
||||
CachedTokens: cacheReadInputTokens,
|
||||
CachedTokens: totalCachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,3 +102,90 @@ func TestGenerateSignatureDiffersForRawAndPreEncodedModelPath(t *testing.T) {
|
||||
preEncodedSignature := p.generateSignature(preEncodedPath, "20260312T142942Z", "20260312", body)
|
||||
assert.NotEqual(t, rawSignature, preEncodedSignature)
|
||||
}
|
||||
|
||||
func TestNormalizePromptCacheRetention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
retention string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "inmemory alias maps to in_memory",
|
||||
retention: "inmemory",
|
||||
want: "in_memory",
|
||||
},
|
||||
{
|
||||
name: "dash style maps to in_memory",
|
||||
retention: "in-memory",
|
||||
want: "in_memory",
|
||||
},
|
||||
{
|
||||
name: "space style with trim maps to in_memory",
|
||||
retention: " in memory ",
|
||||
want: "in_memory",
|
||||
},
|
||||
{
|
||||
name: "already normalized remains unchanged",
|
||||
retention: "in_memory",
|
||||
want: "in_memory",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, normalizePromptCacheRetention(tt.retention))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendCachePointToBedrockMessageInvalidIndexNoop(t *testing.T) {
|
||||
request := &bedrockTextGenRequest{
|
||||
Messages: []bedrockMessage{
|
||||
{
|
||||
Role: roleUser,
|
||||
Content: []bedrockMessageContent{
|
||||
{Text: "hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
appendCachePointToBedrockMessage(request, -1, bedrockCacheTTL5m)
|
||||
appendCachePointToBedrockMessage(request, len(request.Messages), bedrockCacheTTL5m)
|
||||
|
||||
assert.Len(t, request.Messages[0].Content, 1)
|
||||
|
||||
appendCachePointToBedrockMessage(request, 0, bedrockCacheTTL5m)
|
||||
assert.Len(t, request.Messages[0].Content, 2)
|
||||
assert.NotNil(t, request.Messages[0].Content[1].CachePoint)
|
||||
}
|
||||
|
||||
func TestIsPromptCacheSupportedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "anthropic claude model is supported",
|
||||
model: "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "amazon nova inference profile is supported",
|
||||
model: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other model is not supported",
|
||||
model: "meta.llama3-70b-instruct-v1:0",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, isPromptCacheSupportedModel(tt.model))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,6 +357,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。
|
||||
bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"`
|
||||
// @Title zh-CN Amazon Bedrock Prompt Cache 保留策略(默认值)
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。作为请求中 prompt_cache_retention 缺省时的默认值,支持 in_memory 和 24h。
|
||||
promptCacheRetention string `required:"false" yaml:"promptCacheRetention" json:"promptCacheRetention"`
|
||||
// @Title zh-CN minimax API type
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||
@@ -558,6 +561,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
||||
c.bedrockAdditionalFields[k] = v.Value()
|
||||
}
|
||||
c.promptCacheRetention = json.Get("promptCacheRetention").String()
|
||||
if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() {
|
||||
c.bedrockPromptCachePointPositions = make(map[string]bool)
|
||||
for k, v := range rawPositions.Map() {
|
||||
|
||||
Reference in New Issue
Block a user