From 68d6090e36bde2b7c442e327a2d839a42fd490f1 Mon Sep 17 00:00:00 2001 From: woody Date: Thu, 12 Mar 2026 17:44:42 +0800 Subject: [PATCH] feat(bedrock): prompt caching params transform (#3563) --- .../wasm-go/extensions/ai-proxy/main_test.go | 1 + .../extensions/ai-proxy/provider/bedrock.go | 161 ++++- .../extensions/ai-proxy/provider/model.go | 58 +- .../extensions/ai-proxy/provider/provider.go | 9 + .../extensions/ai-proxy/test/bedrock.go | 561 +++++++++++++++++- 5 files changed, 747 insertions(+), 43 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index acd699c3..1b44bee5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -199,6 +199,7 @@ func TestBedrock(t *testing.T) { test.RunBedrockOnHttpRequestBodyTests(t) test.RunBedrockOnHttpResponseHeadersTests(t) test.RunBedrockOnHttpResponseBodyTests(t) + test.RunBedrockOnStreamingResponseBodyTests(t) test.RunBedrockToolCallTests(t) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index c8ff8855..146d0910 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -35,9 +35,16 @@ const ( // converseStream路径 /model/{modelId}/converse-stream bedrockStreamChatCompletionPath = "/model/%s/converse-stream" // invoke_model 路径 /model/{modelId}/invoke - bedrockInvokeModelPath = "/model/%s/invoke" - bedrockSignedHeaders = "host;x-amz-date" - requestIdHeader = "X-Amzn-Requestid" + bedrockInvokeModelPath = "/model/%s/invoke" + bedrockSignedHeaders = "host;x-amz-date" + requestIdHeader = "X-Amzn-Requestid" + bedrockCacheTypeDefault = "default" + bedrockCacheTTL5m = "5m" + bedrockCacheTTL1h = "1h" + + bedrockCachePointPositionSystemPrompt = "systemPrompt" + bedrockCachePointPositionLastUserMessage = "lastUserMessage" + bedrockCachePointPositionLastMessage = "lastMessage" ) var ( @@ -169,9 +176,10 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex if bedrockEvent.Usage != nil { openAIFormattedChunk.Choices = choices[:0] openAIFormattedChunk.Usage = &usage{ - CompletionTokens: bedrockEvent.Usage.OutputTokens, - PromptTokens: bedrockEvent.Usage.InputTokens, - TotalTokens: bedrockEvent.Usage.TotalTokens, + CompletionTokens: bedrockEvent.Usage.OutputTokens, + PromptTokens: bedrockEvent.Usage.InputTokens, + TotalTokens: bedrockEvent.Usage.TotalTokens, + PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens), } } openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk) @@ -831,6 +839,13 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom }, } + if origRequest.PromptCacheKey != "" { + log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field") + } + if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(origRequest.PromptCacheRetention); ok { + addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions()) + } + if origRequest.ReasoningEffort != "" { thinkingBudget := 1024 // default switch origRequest.ReasoningEffort { @@ -932,9 +947,10 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b Object: objectChatCompletion, Choices: choices, Usage: &usage{ - PromptTokens: bedrockResponse.Usage.InputTokens, - CompletionTokens: bedrockResponse.Usage.OutputTokens, - TotalTokens: bedrockResponse.Usage.TotalTokens, + PromptTokens: bedrockResponse.Usage.InputTokens, + CompletionTokens: bedrockResponse.Usage.OutputTokens, + TotalTokens: bedrockResponse.Usage.TotalTokens, + PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens), }, } } @@ -965,6 +981,112 @@ func stopReasonBedrock2OpenAI(reason string) string { } } +func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) { + switch retention { + case "": + return "", false + case "in_memory": + return bedrockCacheTTL5m, true + case "24h": + return bedrockCacheTTL1h, true + default: + log.Warnf("unsupported prompt_cache_retention for bedrock mapping: %s", retention) + return "", false + } +} + +func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool { + if b.config.bedrockPromptCachePointPositions == nil { + return map[string]bool{ + bedrockCachePointPositionSystemPrompt: true, + bedrockCachePointPositionLastMessage: false, + } + } + positions := map[string]bool{ + bedrockCachePointPositionSystemPrompt: false, + bedrockCachePointPositionLastUserMessage: false, + bedrockCachePointPositionLastMessage: false, + } + for rawKey, enabled := range b.config.bedrockPromptCachePointPositions { + key := normalizeBedrockCachePointPosition(rawKey) + switch key { + case bedrockCachePointPositionSystemPrompt, bedrockCachePointPositionLastUserMessage, bedrockCachePointPositionLastMessage: + positions[key] = enabled + default: + log.Warnf("unsupported bedrockPromptCachePointPositions key: %s", rawKey) + } + } + return positions +} + +func normalizeBedrockCachePointPosition(raw string) string { + key := strings.ToLower(raw) + key = strings.ReplaceAll(key, "_", "") + key = strings.ReplaceAll(key, "-", "") + switch key { + case "systemprompt": + return bedrockCachePointPositionSystemPrompt + case "lastusermessage": + return bedrockCachePointPositionLastUserMessage + case "lastmessage": + return bedrockCachePointPositionLastMessage + default: + return raw + } +} + +func addPromptCachePointsToBedrockRequest(request *bedrockTextGenRequest, cacheTTL string, positions map[string]bool) { + if positions[bedrockCachePointPositionSystemPrompt] && len(request.System) > 0 { + request.System = append(request.System, systemContentBlock{ + CachePoint: &bedrockCachePoint{ + Type: bedrockCacheTypeDefault, + TTL: cacheTTL, + }, + }) + } + + lastUserMessageIndex := -1 + if positions[bedrockCachePointPositionLastUserMessage] { + lastUserMessageIndex = findLastMessageIndexByRole(request.Messages, roleUser) + if lastUserMessageIndex >= 0 { + appendCachePointToBedrockMessage(request, lastUserMessageIndex, cacheTTL) + } + } + if positions[bedrockCachePointPositionLastMessage] && len(request.Messages) > 0 { + lastMessageIndex := len(request.Messages) - 1 + if lastMessageIndex != lastUserMessageIndex { + appendCachePointToBedrockMessage(request, lastMessageIndex, cacheTTL) + } + } +} + +func findLastMessageIndexByRole(messages []bedrockMessage, role string) int { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == role { + return i + } + } + return -1 +} + +func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) { + request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{ + CachePoint: &bedrockCachePoint{ + Type: bedrockCacheTypeDefault, + TTL: cacheTTL, + }, + }) +} + +func buildPromptTokensDetails(cacheReadInputTokens int) *promptTokensDetails { + if cacheReadInputTokens <= 0 { + return nil + } + return &promptTokensDetails{ + CachedTokens: cacheReadInputTokens, + } +} + type bedrockTextGenRequest struct { Messages []bedrockMessage `json:"messages"` System []systemContentBlock `json:"system,omitempty"` @@ -1009,14 +1131,21 @@ type bedrockMessage struct { } type bedrockMessageContent struct { - Text string `json:"text,omitempty"` - Image *imageBlock `json:"image,omitempty"` - ToolResult *toolResultBlock `json:"toolResult,omitempty"` - ToolUse *toolUseBlock `json:"toolUse,omitempty"` + Text string `json:"text,omitempty"` + Image *imageBlock `json:"image,omitempty"` + ToolResult *toolResultBlock `json:"toolResult,omitempty"` + ToolUse *toolUseBlock `json:"toolUse,omitempty"` + CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"` } type systemContentBlock struct { - Text string `json:"text,omitempty"` + Text string `json:"text,omitempty"` + CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"` +} + +type bedrockCachePoint struct { + Type string `json:"type"` + TTL string `json:"ttl,omitempty"` } type imageBlock struct { @@ -1098,6 +1227,10 @@ type tokenUsage struct { OutputTokens int `json:"outputTokens,omitempty"` TotalTokens int `json:"totalTokens"` + + CacheReadInputTokens int `json:"cacheReadInputTokens,omitempty"` + + CacheWriteInputTokens int `json:"cacheWriteInputTokens,omitempty"` } func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index b231c739..881e4cbc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -42,34 +42,36 @@ type thinkingParam struct { type chatCompletionRequest struct { NonOpenAIStyleOptions - Messages []chatMessage `json:"messages"` - Model string `json:"model"` - Store bool `json:"store,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - Logprobs bool `json:"logprobs,omitempty"` - TopLogprobs int `json:"top_logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` - N int `json:"n,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Prediction map[string]interface{} `json:"prediction,omitempty"` - Audio map[string]interface{} `json:"audio,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat map[string]interface{} `json:"response_format,omitempty"` - Seed int `json:"seed,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *streamOptions `json:"stream_options,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Tools []tool `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` - User string `json:"user,omitempty"` + Messages []chatMessage `json:"messages"` + Model string `json:"model"` + Store bool `json:"store,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + Logprobs bool `json:"logprobs,omitempty"` + TopLogprobs int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Prediction map[string]interface{} `json:"prediction,omitempty"` + Audio map[string]interface{} `json:"audio,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat map[string]interface{} `json:"response_format,omitempty"` + Seed int `json:"seed,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` + PromptCacheRetention string `json:"prompt_cache_retention,omitempty"` + PromptCacheKey string `json:"prompt_cache_key,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` } func (c *chatCompletionRequest) getMaxTokens() int { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 5ca5a658..bb1b9a63 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -354,6 +354,9 @@ type ProviderConfig struct { // @Title zh-CN Amazon Bedrock 额外模型请求参数 // @Description zh-CN 仅适用于Amazon Bedrock服务,用于设置模型特定的推理参数 bedrockAdditionalFields map[string]interface{} `required:"false" yaml:"bedrockAdditionalFields" json:"bedrockAdditionalFields"` + // @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置 + // @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。 + bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"` // @Title zh-CN minimax API type // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2 minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"` @@ -552,6 +555,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { for k, v := range json.Get("bedrockAdditionalFields").Map() { c.bedrockAdditionalFields[k] = v.Value() } + if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() { + c.bedrockPromptCachePointPositions = make(map[string]bool) + for k, v := range rawPositions.Map() { + c.bedrockPromptCachePointPositions[k] = v.Bool() + } + } } c.minimaxApiType = json.Get("minimaxApiType").String() c.minimaxGroupId = json.Get("minimaxGroupId").String() diff --git a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go index 6a4b17e6..49c2c946 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go @@ -1,7 +1,11 @@ package test import ( + "bytes" + "encoding/binary" "encoding/json" + "hash/crc32" + "strings" "testing" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -112,6 +116,23 @@ var bedrockApiTokenConfig = func() json.RawMessage { return data }() +func bedrockApiTokenConfigWithCachePointPositions(positions map[string]bool) json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "apiTokens": []string{ + "test-token-for-unit-test", + }, + "awsRegion": "us-east-1", + "modelMapping": map[string]string{ + "*": "anthropic.claude-3-5-haiku-20241022-v1:0", + }, + "bedrockPromptCachePointPositions": positions, + }, + }) + return data +} + // Test config: Bedrock config with multiple Bearer Tokens var bedrockMultiTokenConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -369,6 +390,372 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint") }) + t.Run("bedrock request body prompt cache in_memory should inject system cache point only by default", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "in_memory", + "prompt_cache_key": "session-001", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + _, hasPromptCacheRetention := bodyMap["prompt_cache_retention"] + require.False(t, hasPromptCacheRetention, "prompt_cache_retention should not be forwarded to Bedrock") + _, hasPromptCacheKey := bodyMap["prompt_cache_key"] + require.False(t, hasPromptCacheKey, "prompt_cache_key should not be forwarded to Bedrock") + + systemBlocks, ok := bodyMap["system"].([]interface{}) + require.True(t, ok, "system should be an array") + require.Len(t, systemBlocks, 2, "system should contain text block and cachePoint block") + systemCachePointBlock := systemBlocks[len(systemBlocks)-1].(map[string]interface{}) + systemCachePoint, ok := systemCachePointBlock["cachePoint"].(map[string]interface{}) + require.True(t, ok, "system tail block should contain cachePoint") + require.Equal(t, "default", systemCachePoint["type"]) + require.Equal(t, "5m", systemCachePoint["ttl"]) + + messages := bodyMap["messages"].([]interface{}) + require.NotEmpty(t, messages, "messages should not be empty") + lastMessage := messages[len(messages)-1].(map[string]interface{}) + lastMessageContent := lastMessage["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "last message should keep original content only by default") + _, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"] + require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default") + }) + + t.Run("bedrock request body prompt cache 24h should map to 1h ttl on system cache point by default", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "24h", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + systemCachePointBlock := systemBlocks[len(systemBlocks)-1].(map[string]interface{}) + systemCachePoint := systemCachePointBlock["cachePoint"].(map[string]interface{}) + require.Equal(t, "1h", systemCachePoint["ttl"]) + + messages := bodyMap["messages"].([]interface{}) + lastMessage := messages[len(messages)-1].(map[string]interface{}) + lastMessageContent := lastMessage["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "last message should keep original content only by default") + _, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"] + require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default") + }) + + t.Run("bedrock request body prompt cache should insert cache points based on configured positions", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfigWithCachePointPositions(map[string]bool{ + "systemPrompt": true, + "lastUserMessage": true, + "lastMessage": false, + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "in_memory", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Question from user" + }, + { + "role": "assistant", + "content": "Previous assistant answer" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + require.Len(t, systemBlocks, 2, "system should include cachePoint due to systemPrompt=true") + systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) + require.Equal(t, "5m", systemCachePoint["ttl"]) + + messages := bodyMap["messages"].([]interface{}) + require.Len(t, messages, 2, "system message should not be in messages array") + + lastUserMessageContent := messages[0].(map[string]interface{})["content"].([]interface{}) + require.Len(t, lastUserMessageContent, 2, "last user message should include one cachePoint") + lastUserMessageCachePoint := lastUserMessageContent[len(lastUserMessageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) + require.Equal(t, "5m", lastUserMessageCachePoint["ttl"]) + + lastMessageContent := messages[1].(map[string]interface{})["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "last message should not include cachePoint when lastMessage=false") + }) + + t.Run("bedrock request body prompt cache should avoid duplicate insertion when lastUserMessage and lastMessage overlap", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfigWithCachePointPositions(map[string]bool{ + "systemPrompt": false, + "lastUserMessage": true, + "lastMessage": true, + })) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "in_memory", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + _, hasSystem := bodyMap["system"] + require.False(t, hasSystem, "system should not include cachePoint when systemPrompt=false and no system messages") + + messages := bodyMap["messages"].([]interface{}) + require.Len(t, messages, 1, "only one message should exist") + messageContent := messages[0].(map[string]interface{})["content"].([]interface{}) + require.Len(t, messageContent, 2, "overlap positions should still insert only one cachePoint") + cachePoint := messageContent[len(messageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) + require.Equal(t, "5m", cachePoint["ttl"]) + }) + + t.Run("bedrock request body with empty prompt cache retention should not inject cache points", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + require.Len(t, systemBlocks, 1, "system should only contain the original text block") + _, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"] + require.False(t, hasSystemCachePoint, "system block should not include cachePoint when retention is empty") + + messages := bodyMap["messages"].([]interface{}) + lastMessage := messages[len(messages)-1].(map[string]interface{}) + lastMessageContent := lastMessage["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "message should only contain original text block") + _, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"] + require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is empty") + }) + + t.Run("bedrock request body with unsupported prompt cache retention should not inject cache points", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "2h", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + require.Len(t, systemBlocks, 1, "system should only contain the original text block") + _, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"] + require.False(t, hasSystemCachePoint, "system block should not include cachePoint when retention is unsupported") + + messages := bodyMap["messages"].([]interface{}) + lastMessage := messages[len(messages)-1].(map[string]interface{}) + lastMessageContent := lastMessage["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "message should only contain original text block") + _, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"] + require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is unsupported") + }) + + t.Run("bedrock request body without system should not inject cache point by default", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "in_memory", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + _, hasSystem := bodyMap["system"] + require.False(t, hasSystem, "system should be omitted when original request has no system prompts") + + messages := bodyMap["messages"].([]interface{}) + require.Len(t, messages, 1, "messages should keep original one user message") + lastMessage := messages[0].(map[string]interface{}) + lastMessageContent := lastMessage["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "message should keep original text block only by default") + _, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"] + require.False(t, hasMessageCachePoint, "message should not include cachePoint by default") + }) + // Test Bedrock request body processing with AWS Signature V4 authentication t.Run("bedrock chat completion request body with ak/sk", func(t *testing.T) { host, status := test.NewTestHost(basicBedrockConfig) @@ -911,7 +1298,9 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) { "usage": { "inputTokens": 10, "outputTokens": 15, - "totalTokens": 25 + "totalTokens": 25, + "cacheReadInputTokens": 6, + "cacheWriteInputTokens": 12 } }` @@ -935,6 +1324,176 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) { usage, exists := responseMap["usage"] require.True(t, exists, "Usage should exist in response body") require.NotNil(t, usage, "Usage should not be nil") + usageMap := usage.(map[string]interface{}) + promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{}) + require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheReadInputTokens is present") + require.Equal(t, float64(6), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens") + }) + + t.Run("bedrock response body with zero cache read tokens should omit prompt_tokens_details", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + responseBody := `{ + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Hello! How can I help you today?" + } + ] + } + }, + "stopReason": "end_turn", + "usage": { + "inputTokens": 10, + "outputTokens": 15, + "totalTokens": 25, + "cacheReadInputTokens": 0 + } + }` + + action = host.CallOnHttpResponseBody([]byte(responseBody)) + require.Equal(t, types.ActionContinue, action) + + transformedResponseBody := host.GetResponseBody() + require.NotNil(t, transformedResponseBody) + + var responseMap map[string]interface{} + err := json.Unmarshal(transformedResponseBody, &responseMap) + require.NoError(t, err) + + usageMap := responseMap["usage"].(map[string]interface{}) + _, hasPromptTokensDetails := usageMap["prompt_tokens_details"] + require.False(t, hasPromptTokensDetails, "prompt_tokens_details should be omitted when cacheReadInputTokens is zero") }) }) } + +func RunBedrockOnStreamingResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "stream": true + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/vnd.amazon.eventstream"}, + }) + require.Equal(t, types.ActionContinue, action) + + streamingChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{ + "usage": map[string]interface{}{ + "inputTokens": 10, + "outputTokens": 2, + "totalTokens": 12, + "cacheReadInputTokens": 7, + "cacheWriteInputTokens": 3, + }, + }) + action = host.CallOnHttpStreamingResponseBody(streamingChunk, true) + require.Equal(t, types.ActionContinue, action) + + transformedResponseBody := host.GetResponseBody() + require.NotNil(t, transformedResponseBody) + + var dataPayload string + for _, line := range strings.Split(string(transformedResponseBody), "\n") { + if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" { + dataPayload = strings.TrimPrefix(line, "data: ") + break + } + } + require.NotEmpty(t, dataPayload, "should have at least one SSE data payload") + + var responseMap map[string]interface{} + err := json.Unmarshal([]byte(dataPayload), &responseMap) + require.NoError(t, err) + usageMap := responseMap["usage"].(map[string]interface{}) + promptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{}) + require.Equal(t, float64(7), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens in streaming usage event") + }) + }) +} + +func buildBedrockEventStreamMessage(t *testing.T, payload map[string]interface{}) []byte { + payloadBytes, err := json.Marshal(payload) + require.NoError(t, err) + + totalLength := uint32(16 + len(payloadBytes)) + headersLength := uint32(0) + + var message bytes.Buffer + prelude := make([]byte, 8) + binary.BigEndian.PutUint32(prelude[0:4], totalLength) + binary.BigEndian.PutUint32(prelude[4:8], headersLength) + message.Write(prelude) + + preludeCRC := crc32.ChecksumIEEE(prelude) + preludeCRCBytes := make([]byte, 4) + binary.BigEndian.PutUint32(preludeCRCBytes, preludeCRC) + message.Write(preludeCRCBytes) + + message.Write(payloadBytes) + + messageCRC := crc32.ChecksumIEEE(message.Bytes()) + messageCRCBytes := make([]byte, 4) + binary.BigEndian.PutUint32(messageCRCBytes, messageCRC) + message.Write(messageCRCBytes) + + return message.Bytes() +}