feat(ai-proxy): 添加Amazon Bedrock Prompt Cache保留策略配置及优化缓存处理逻辑 (#3609)

This commit is contained in:
woody
2026-03-18 20:37:04 +08:00
committed by GitHub
parent 8961db2e90
commit 62df71aadf
4 changed files with 562 additions and 23 deletions

View File

@@ -133,6 +133,41 @@ func bedrockApiTokenConfigWithCachePointPositions(positions map[string]bool) jso
return data
}
// bedrockApiTokenConfigWithPromptCacheRetention builds a raw-JSON Bedrock
// provider config for tests, using a fixed token/region/model mapping and the
// given promptCacheRetention value.
func bedrockApiTokenConfigWithPromptCacheRetention(promptCacheRetention string) json.RawMessage {
	provider := map[string]interface{}{
		"type":      "bedrock",
		"apiTokens": []string{"test-token-for-unit-test"},
		"awsRegion": "us-east-1",
		"modelMapping": map[string]string{
			"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
		},
		"promptCacheRetention": promptCacheRetention,
	}
	// Marshal cannot fail for this fixed map shape; the error is ignored on purpose.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}
// bedrockApiTokenConfigWithModelAndPromptCache builds a raw-JSON Bedrock
// provider config for tests with a caller-chosen mapped model,
// promptCacheRetention value, and cache-point position flags.
func bedrockApiTokenConfigWithModelAndPromptCache(mappedModel, promptCacheRetention string, positions map[string]bool) json.RawMessage {
	cfg := map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "bedrock",
			"apiTokens": []string{"test-token-for-unit-test"},
			"awsRegion": "us-east-1",
			"modelMapping": map[string]string{
				"*": mappedModel,
			},
			"promptCacheRetention":             promptCacheRetention,
			"bedrockPromptCachePointPositions": positions,
		},
	}
	// Marshal cannot fail for this fixed map shape; the error is ignored on purpose.
	raw, _ := json.Marshal(cfg)
	return raw
}
// Test config: Bedrock config with multiple Bearer Tokens
var bedrockMultiTokenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
@@ -390,7 +425,7 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
})
t.Run("bedrock request body prompt cache in_memory should inject system cache point only by default", func(t *testing.T) {
t.Run("bedrock request body prompt cache in-memory should inject system cache point only by default", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
@@ -405,7 +440,7 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
requestBody := `{
"model": "gpt-4",
"prompt_cache_retention": "in_memory",
"prompt_cache_retention": "in-memory",
"prompt_cache_key": "session-001",
"messages": [
{
@@ -440,7 +475,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
systemCachePoint, ok := systemCachePointBlock["cachePoint"].(map[string]interface{})
require.True(t, ok, "system tail block should contain cachePoint")
require.Equal(t, "default", systemCachePoint["type"])
require.Equal(t, "5m", systemCachePoint["ttl"])
_, hasTTL := systemCachePoint["ttl"]
require.False(t, hasTTL, "ttl should be omitted for in_memory to use Bedrock default 5m")
messages := bodyMap["messages"].([]interface{})
require.NotEmpty(t, messages, "messages should not be empty")
@@ -451,6 +487,91 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default")
})
t.Run("bedrock request body should use provider promptCacheRetention in-memory when request omits prompt_cache_retention", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfigWithPromptCacheRetention("in-memory"))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
systemBlocks := bodyMap["system"].([]interface{})
require.Len(t, systemBlocks, 2, "provider promptCacheRetention should trigger cachePoint injection")
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
_, hasTTL := systemCachePoint["ttl"]
require.False(t, hasTTL, "provider promptCacheRetention=in-memory should omit ttl and use Bedrock default 5m")
})
t.Run("bedrock request body prompt_cache_retention should override provider promptCacheRetention", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfigWithPromptCacheRetention("in_memory"))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"prompt_cache_retention": "24h",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
systemBlocks := bodyMap["system"].([]interface{})
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
require.Equal(t, "1h", systemCachePoint["ttl"], "request prompt_cache_retention should override provider promptCacheRetention")
})
t.Run("bedrock request body prompt cache 24h should map to 1h ttl on system cache point by default", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
@@ -549,7 +670,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
systemBlocks := bodyMap["system"].([]interface{})
require.Len(t, systemBlocks, 2, "system should include cachePoint due to systemPrompt=true")
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
require.Equal(t, "5m", systemCachePoint["ttl"])
_, hasSystemTTL := systemCachePoint["ttl"]
require.False(t, hasSystemTTL, "ttl should be omitted for in_memory cachePoint")
messages := bodyMap["messages"].([]interface{})
require.Len(t, messages, 2, "system message should not be in messages array")
@@ -557,7 +679,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
lastUserMessageContent := messages[0].(map[string]interface{})["content"].([]interface{})
require.Len(t, lastUserMessageContent, 2, "last user message should include one cachePoint")
lastUserMessageCachePoint := lastUserMessageContent[len(lastUserMessageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
require.Equal(t, "5m", lastUserMessageCachePoint["ttl"])
_, hasLastUserTTL := lastUserMessageCachePoint["ttl"]
require.False(t, hasLastUserTTL, "ttl should be omitted for in_memory cachePoint")
lastMessageContent := messages[1].(map[string]interface{})["content"].([]interface{})
require.Len(t, lastMessageContent, 1, "last message should not include cachePoint when lastMessage=false")
@@ -608,7 +731,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
messageContent := messages[0].(map[string]interface{})["content"].([]interface{})
require.Len(t, messageContent, 2, "overlap positions should still insert only one cachePoint")
cachePoint := messageContent[len(messageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
require.Equal(t, "5m", cachePoint["ttl"])
_, hasTTL := cachePoint["ttl"]
require.False(t, hasTTL, "ttl should be omitted for in_memory cachePoint")
})
t.Run("bedrock request body with empty prompt cache retention should not inject cache points", func(t *testing.T) {
@@ -711,6 +835,63 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is unsupported")
})
t.Run("bedrock request body should skip prompt cache for unsupported model even when enabled", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfigWithModelAndPromptCache(
"meta.llama3-70b-instruct-v1:0",
"in_memory",
map[string]bool{
"systemPrompt": true,
"lastMessage": true,
},
))
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"prompt_cache_retention": "24h",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(processedBody, &bodyMap)
require.NoError(t, err)
systemBlocks := bodyMap["system"].([]interface{})
require.Len(t, systemBlocks, 1, "unsupported model should skip system cachePoint injection")
_, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"]
require.False(t, hasSystemCachePoint, "unsupported model should not contain system cachePoint")
messages := bodyMap["messages"].([]interface{})
require.Len(t, messages, 1, "system message should not be in messages array")
lastMessageContent := messages[0].(map[string]interface{})["content"].([]interface{})
require.Len(t, lastMessageContent, 1, "unsupported model should skip message cachePoint injection")
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
require.False(t, hasMessageCachePoint, "unsupported model should not contain message cachePoint")
})
t.Run("bedrock request body without system should not inject cache point by default", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
@@ -1327,7 +1508,9 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
usageMap := usage.(map[string]interface{})
promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheReadInputTokens is present")
require.Equal(t, float64(6), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens")
require.Equal(t, float64(18), promptTokensDetails["cached_tokens"], "cached_tokens should sum cacheReadInputTokens and cacheWriteInputTokens")
_, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"]
require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible usage")
})
t.Run("bedrock response body with zero cache read tokens should omit prompt_tokens_details", func(t *testing.T) {
@@ -1397,11 +1580,95 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
_, hasPromptTokensDetails := usageMap["prompt_tokens_details"]
require.False(t, hasPromptTokensDetails, "prompt_tokens_details should be omitted when cacheReadInputTokens is zero")
})
t.Run("bedrock response body with only cache write tokens should map to cached_tokens", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
action = host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.ActionContinue, action)
responseBody := `{
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "Hello! How can I help you today?"
}
]
}
},
"stopReason": "end_turn",
"usage": {
"inputTokens": 10,
"outputTokens": 15,
"totalTokens": 25,
"cacheReadInputTokens": 0,
"cacheWriteInputTokens": 9
}
}`
action = host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
transformedResponseBody := host.GetResponseBody()
require.NotNil(t, transformedResponseBody)
var responseMap map[string]interface{}
err := json.Unmarshal(transformedResponseBody, &responseMap)
require.NoError(t, err)
usageMap := responseMap["usage"].(map[string]interface{})
promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheWriteInputTokens is present")
require.Equal(t, float64(9), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheWriteInputTokens when cacheReadInputTokens is zero")
_, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"]
require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible usage")
})
})
}
func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
extractFirstDataPayload := func(body []byte) string {
for _, line := range strings.Split(string(body), "\n") {
if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" {
return strings.TrimPrefix(line, "data: ")
}
}
return ""
}
t.Run("extract first data payload should return empty when no data line", func(t *testing.T) {
payload := extractFirstDataPayload([]byte("event: ping\n\n"))
require.Equal(t, "", payload)
})
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
@@ -1465,7 +1732,147 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
require.NoError(t, err)
usageMap := responseMap["usage"].(map[string]interface{})
promptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
require.Equal(t, float64(7), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens in streaming usage event")
require.Equal(t, float64(10), promptTokensDetails["cached_tokens"], "cached_tokens should sum cacheReadInputTokens and cacheWriteInputTokens in streaming usage event")
_, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"]
require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible streaming usage")
})
t.Run("bedrock streaming text chunk then usage chunk format is stable", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello"
}
],
"stream": true
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
action = host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/vnd.amazon.eventstream"},
})
require.Equal(t, types.ActionContinue, action)
textChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
"delta": map[string]interface{}{
"text": "Hello from Bedrock",
},
})
action = host.CallOnHttpStreamingResponseBody(textChunk, false)
require.Equal(t, types.ActionContinue, action)
firstResponseBody := host.GetResponseBody()
require.NotNil(t, firstResponseBody)
firstDataPayload := extractFirstDataPayload(firstResponseBody)
require.NotEmpty(t, firstDataPayload, "first chunk should contain one SSE data payload")
var firstResponseMap map[string]interface{}
err := json.Unmarshal([]byte(firstDataPayload), &firstResponseMap)
require.NoError(t, err)
firstChoices := firstResponseMap["choices"].([]interface{})
require.Len(t, firstChoices, 1, "text chunk should contain one choice")
usageChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
"usage": map[string]interface{}{
"inputTokens": 10,
"outputTokens": 2,
"totalTokens": 12,
},
})
action = host.CallOnHttpStreamingResponseBody(usageChunk, true)
require.Equal(t, types.ActionContinue, action)
secondResponseBody := host.GetResponseBody()
require.NotNil(t, secondResponseBody)
require.Contains(t, string(secondResponseBody), "data: [DONE]", "last chunk should append [DONE]")
secondDataPayload := extractFirstDataPayload(secondResponseBody)
require.NotEmpty(t, secondDataPayload, "usage chunk should contain one SSE data payload")
var secondResponseMap map[string]interface{}
err = json.Unmarshal([]byte(secondDataPayload), &secondResponseMap)
require.NoError(t, err)
secondChoices := secondResponseMap["choices"].([]interface{})
require.Len(t, secondChoices, 0, "usage chunk should contain empty choices by design")
_, hasUsage := secondResponseMap["usage"]
require.True(t, hasUsage, "usage chunk should include usage field")
})
t.Run("bedrock empty intermediate callback should not affect next usage event", func(t *testing.T) {
host, status := test.NewTestHost(bedrockApiTokenConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
requestBody := `{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello"
}
],
"stream": true
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
action = host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/vnd.amazon.eventstream"},
})
require.Equal(t, types.ActionContinue, action)
action = host.CallOnHttpStreamingResponseBody([]byte{}, false)
require.Equal(t, types.ActionContinue, action)
emptyResponseBody := host.GetResponseBody()
require.Equal(t, 0, len(emptyResponseBody), "empty intermediate callback should output empty payload")
usageChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
"usage": map[string]interface{}{
"inputTokens": 10,
"outputTokens": 2,
"totalTokens": 12,
},
})
action = host.CallOnHttpStreamingResponseBody(usageChunk, true)
require.Equal(t, types.ActionContinue, action)
finalResponseBody := host.GetResponseBody()
require.NotNil(t, finalResponseBody)
require.Contains(t, string(finalResponseBody), "data: [DONE]", "last chunk should append [DONE]")
finalDataPayload := extractFirstDataPayload(finalResponseBody)
require.NotEmpty(t, finalDataPayload, "final usage event should still be parsed")
var finalResponseMap map[string]interface{}
err := json.Unmarshal([]byte(finalDataPayload), &finalResponseMap)
require.NoError(t, err)
finalChoices := finalResponseMap["choices"].([]interface{})
require.Len(t, finalChoices, 0, "usage chunk should still keep empty choices")
})
})
}