diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go index b5c3a07fc..80c6c346d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock.go @@ -35,12 +35,14 @@ const ( // converseStream路径 /model/{modelId}/converse-stream bedrockStreamChatCompletionPath = "/model/%s/converse-stream" // invoke_model 路径 /model/{modelId}/invoke - bedrockInvokeModelPath = "/model/%s/invoke" - bedrockSignedHeaders = "host;x-amz-date" - requestIdHeader = "X-Amzn-Requestid" - bedrockCacheTypeDefault = "default" - bedrockCacheTTL5m = "5m" - bedrockCacheTTL1h = "1h" + bedrockInvokeModelPath = "/model/%s/invoke" + bedrockSignedHeaders = "host;x-amz-date" + requestIdHeader = "X-Amzn-Requestid" + bedrockCacheTypeDefault = "default" + bedrockCacheTTL5m = "5m" + bedrockCacheTTL1h = "1h" + bedrockPromptCacheNova = "amazon.nova" + bedrockPromptCacheClaude = "anthropic.claude" bedrockCachePointPositionSystemPrompt = "systemPrompt" bedrockCachePointPositionLastUserMessage = "lastUserMessage" @@ -179,7 +181,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex CompletionTokens: bedrockEvent.Usage.OutputTokens, PromptTokens: bedrockEvent.Usage.InputTokens, TotalTokens: bedrockEvent.Usage.TotalTokens, - PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens), + PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens, bedrockEvent.Usage.CacheWriteInputTokens), } } openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk) @@ -839,11 +841,17 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom }, } + effectivePromptCacheRetention := b.resolvePromptCacheRetention(origRequest.PromptCacheRetention) + if origRequest.PromptCacheKey != "" { log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field") } - if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(origRequest.PromptCacheRetention); ok { - addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions()) + if isPromptCacheSupportedModel(origRequest.Model) { + if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(effectivePromptCacheRetention); ok { + addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions()) + } + } else if effectivePromptCacheRetention != "" { + log.Warnf("skip prompt cache injection for unsupported model: %s", origRequest.Model) } if origRequest.ReasoningEffort != "" { @@ -950,7 +958,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b PromptTokens: bedrockResponse.Usage.InputTokens, CompletionTokens: bedrockResponse.Usage.OutputTokens, TotalTokens: bedrockResponse.Usage.TotalTokens, - PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens), + PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens, bedrockResponse.Usage.CacheWriteInputTokens), }, } } @@ -982,11 +990,14 @@ func stopReasonBedrock2OpenAI(reason string) string { } func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) { - switch retention { + normalizedRetention := normalizePromptCacheRetention(retention) + switch normalizedRetention { case "": return "", false case "in_memory": - return bedrockCacheTTL5m, true + // For the default 5-minute cache, omit ttl and let Bedrock apply its default. + // This is more robust for models that are strict about explicit ttl fields. + return "", true case "24h": return bedrockCacheTTL1h, true default: @@ -995,6 +1006,32 @@ func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) { } } +func normalizePromptCacheRetention(retention string) string { + normalized := strings.ToLower(strings.TrimSpace(retention)) + normalized = strings.ReplaceAll(normalized, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + if normalized == "inmemory" { + return "in_memory" + } + return normalized +} + +func isPromptCacheSupportedModel(model string) bool { + normalizedModel := strings.ToLower(strings.TrimSpace(model)) + return strings.Contains(normalizedModel, bedrockPromptCacheNova) || + strings.Contains(normalizedModel, bedrockPromptCacheClaude) +} + +func (b *bedrockProvider) resolvePromptCacheRetention(requestPromptCacheRetention string) string { + if requestPromptCacheRetention != "" { + return requestPromptCacheRetention + } + if b.config.promptCacheRetention != "" { + return b.config.promptCacheRetention + } + return "" +} + func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool { if b.config.bedrockPromptCachePointPositions == nil { return map[string]bool{ @@ -1070,6 +1107,9 @@ func findLastMessageIndexByRole(messages []bedrockMessage, role string) int { } func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) { + if messageIndex < 0 || messageIndex >= len(request.Messages) { + return + } request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{ CachePoint: &bedrockCachePoint{ Type: bedrockCacheTypeDefault, @@ -1078,12 +1118,13 @@ func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageInd }) } -func buildPromptTokensDetails(cacheReadInputTokens int) *promptTokensDetails { - if cacheReadInputTokens <= 0 { +func buildPromptTokensDetails(cacheReadInputTokens int, cacheWriteInputTokens int) *promptTokensDetails { + totalCachedTokens := cacheReadInputTokens + cacheWriteInputTokens + if totalCachedTokens <= 0 { return nil } return &promptTokensDetails{ - CachedTokens: cacheReadInputTokens, + CachedTokens: totalCachedTokens, } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock_sigv4_path_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock_sigv4_path_test.go index b782ca15b..8728b1acf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/bedrock_sigv4_path_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/bedrock_sigv4_path_test.go @@ -102,3 +102,90 @@ func TestGenerateSignatureDiffersForRawAndPreEncodedModelPath(t *testing.T) { preEncodedSignature := p.generateSignature(preEncodedPath, "20260312T142942Z", "20260312", body) assert.NotEqual(t, rawSignature, preEncodedSignature) } + +func TestNormalizePromptCacheRetention(t *testing.T) { + tests := []struct { + name string + retention string + want string + }{ + { + name: "inmemory alias maps to in_memory", + retention: "inmemory", + want: "in_memory", + }, + { + name: "dash style maps to in_memory", + retention: "in-memory", + want: "in_memory", + }, + { + name: "space style with trim maps to in_memory", + retention: " in memory ", + want: "in_memory", + }, + { + name: "already normalized remains unchanged", + retention: "in_memory", + want: "in_memory", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, normalizePromptCacheRetention(tt.retention)) + }) + } +} + +func TestAppendCachePointToBedrockMessageInvalidIndexNoop(t *testing.T) { + request := &bedrockTextGenRequest{ + Messages: []bedrockMessage{ + { + Role: roleUser, + Content: []bedrockMessageContent{ + {Text: "hello"}, + }, + }, + }, + } + + appendCachePointToBedrockMessage(request, -1, bedrockCacheTTL5m) + appendCachePointToBedrockMessage(request, len(request.Messages), bedrockCacheTTL5m) + + assert.Len(t, request.Messages[0].Content, 1) + + appendCachePointToBedrockMessage(request, 0, bedrockCacheTTL5m) + assert.Len(t, request.Messages[0].Content, 2) + assert.NotNil(t, request.Messages[0].Content[1].CachePoint) +} + +func TestIsPromptCacheSupportedModel(t *testing.T) { + tests := []struct { + name string + model string + want bool + }{ + { + name: "anthropic claude model is supported", + model: "anthropic.claude-3-5-haiku-20241022-v1:0", + want: true, + }, + { + name: "amazon nova inference profile is supported", + model: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0", + want: true, + }, + { + name: "other model is not supported", + model: "meta.llama3-70b-instruct-v1:0", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isPromptCacheSupportedModel(tt.model)) + }) + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index d827b2f8a..80878c883 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -357,6 +357,9 @@ type ProviderConfig struct { // @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置 // @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。 bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"` + // @Title zh-CN Amazon Bedrock Prompt Cache 保留策略(默认值) + // @Description zh-CN 仅适用于Amazon Bedrock服务。作为请求中 prompt_cache_retention 缺省时的默认值,支持 in_memory 和 24h。 + promptCacheRetention string `required:"false" yaml:"promptCacheRetention" json:"promptCacheRetention"` // @Title zh-CN minimax API type // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2 minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"` @@ -558,6 +561,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { for k, v := range json.Get("bedrockAdditionalFields").Map() { c.bedrockAdditionalFields[k] = v.Value() } + c.promptCacheRetention = json.Get("promptCacheRetention").String() if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() { c.bedrockPromptCachePointPositions = make(map[string]bool) for k, v := range rawPositions.Map() { diff --git a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go index 49c2c9467..a690f39ec 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/bedrock.go @@ -133,6 +133,41 @@ func bedrockApiTokenConfigWithCachePointPositions(positions map[string]bool) jso return data } +func bedrockApiTokenConfigWithPromptCacheRetention(promptCacheRetention string) json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "apiTokens": []string{ + "test-token-for-unit-test", + }, + "awsRegion": "us-east-1", + "modelMapping": map[string]string{ + "*": "anthropic.claude-3-5-haiku-20241022-v1:0", + }, + "promptCacheRetention": promptCacheRetention, + }, + }) + return data +} + +func bedrockApiTokenConfigWithModelAndPromptCache(mappedModel, promptCacheRetention string, positions map[string]bool) json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "bedrock", + "apiTokens": []string{ + "test-token-for-unit-test", + }, + "awsRegion": "us-east-1", + "modelMapping": map[string]string{ + "*": mappedModel, + }, + "promptCacheRetention": promptCacheRetention, + "bedrockPromptCachePointPositions": positions, + }, + }) + return data +} + // Test config: Bedrock config with multiple Bearer Tokens var bedrockMultiTokenConfig = func() json.RawMessage { data, _ := json.Marshal(map[string]interface{}{ @@ -390,7 +425,7 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint") }) - t.Run("bedrock request body prompt cache in_memory should inject system cache point only by default", func(t *testing.T) { + t.Run("bedrock request body prompt cache in-memory should inject system cache point only by default", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig) defer host.Reset() require.Equal(t, types.OnPluginStartStatusOK, status) @@ -405,7 +440,7 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { requestBody := `{ "model": "gpt-4", - "prompt_cache_retention": "in_memory", + "prompt_cache_retention": "in-memory", "prompt_cache_key": "session-001", "messages": [ { @@ -440,7 +475,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { systemCachePoint, ok := systemCachePointBlock["cachePoint"].(map[string]interface{}) require.True(t, ok, "system tail block should contain cachePoint") require.Equal(t, "default", systemCachePoint["type"]) - require.Equal(t, "5m", systemCachePoint["ttl"]) + _, hasTTL := systemCachePoint["ttl"] + require.False(t, hasTTL, "ttl should be omitted for in_memory to use Bedrock default 5m") messages := bodyMap["messages"].([]interface{}) require.NotEmpty(t, messages, "messages should not be empty") @@ -451,6 +487,91 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default") }) + t.Run("bedrock request body should use provider promptCacheRetention in-memory when request omits prompt_cache_retention", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfigWithPromptCacheRetention("in-memory")) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + require.Len(t, systemBlocks, 2, "provider promptCacheRetention should trigger cachePoint injection") + systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) + _, hasTTL := systemCachePoint["ttl"] + require.False(t, hasTTL, "provider promptCacheRetention=in-memory should omit ttl and use Bedrock default 5m") + }) + + t.Run("bedrock request body prompt_cache_retention should override provider promptCacheRetention", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfigWithPromptCacheRetention("in_memory")) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "24h", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) + require.Equal(t, "1h", systemCachePoint["ttl"], "request prompt_cache_retention should override provider promptCacheRetention") + }) + t.Run("bedrock request body prompt cache 24h should map to 1h ttl on system cache point by default", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig) defer host.Reset() @@ -549,7 +670,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { systemBlocks := bodyMap["system"].([]interface{}) require.Len(t, systemBlocks, 2, "system should include cachePoint due to systemPrompt=true") systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) - require.Equal(t, "5m", systemCachePoint["ttl"]) + _, hasSystemTTL := systemCachePoint["ttl"] + require.False(t, hasSystemTTL, "ttl should be omitted for in_memory cachePoint") messages := bodyMap["messages"].([]interface{}) require.Len(t, messages, 2, "system message should not be in messages array") @@ -557,7 +679,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { lastUserMessageContent := messages[0].(map[string]interface{})["content"].([]interface{}) require.Len(t, lastUserMessageContent, 2, "last user message should include one cachePoint") lastUserMessageCachePoint := lastUserMessageContent[len(lastUserMessageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) - require.Equal(t, "5m", lastUserMessageCachePoint["ttl"]) + _, hasLastUserTTL := lastUserMessageCachePoint["ttl"] + require.False(t, hasLastUserTTL, "ttl should be omitted for in_memory cachePoint") lastMessageContent := messages[1].(map[string]interface{})["content"].([]interface{}) require.Len(t, lastMessageContent, 1, "last message should not include cachePoint when lastMessage=false") @@ -608,7 +731,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { messageContent := messages[0].(map[string]interface{})["content"].([]interface{}) require.Len(t, messageContent, 2, "overlap positions should still insert only one cachePoint") cachePoint := messageContent[len(messageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{}) - require.Equal(t, "5m", cachePoint["ttl"]) + _, hasTTL := cachePoint["ttl"] + require.False(t, hasTTL, "ttl should be omitted for in_memory cachePoint") }) t.Run("bedrock request body with empty prompt cache retention should not inject cache points", func(t *testing.T) { @@ -711,6 +835,63 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) { require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is unsupported") }) + t.Run("bedrock request body should skip prompt cache for unsupported model even when enabled", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfigWithModelAndPromptCache( + "meta.llama3-70b-instruct-v1:0", + "in_memory", + map[string]bool{ + "systemPrompt": true, + "lastMessage": true, + }, + )) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "prompt_cache_retention": "24h", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + var bodyMap map[string]interface{} + err := json.Unmarshal(processedBody, &bodyMap) + require.NoError(t, err) + + systemBlocks := bodyMap["system"].([]interface{}) + require.Len(t, systemBlocks, 1, "unsupported model should skip system cachePoint injection") + _, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"] + require.False(t, hasSystemCachePoint, "unsupported model should not contain system cachePoint") + + messages := bodyMap["messages"].([]interface{}) + require.Len(t, messages, 1, "system message should not be in messages array") + lastMessageContent := messages[0].(map[string]interface{})["content"].([]interface{}) + require.Len(t, lastMessageContent, 1, "unsupported model should skip message cachePoint injection") + _, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"] + require.False(t, hasMessageCachePoint, "unsupported model should not contain message cachePoint") + }) + t.Run("bedrock request body without system should not inject cache point by default", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig) defer host.Reset() @@ -1327,7 +1508,9 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) { usageMap := usage.(map[string]interface{}) promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{}) require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheReadInputTokens is present") - require.Equal(t, float64(6), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens") + require.Equal(t, float64(18), promptTokensDetails["cached_tokens"], "cached_tokens should sum cacheReadInputTokens and cacheWriteInputTokens") + _, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"] + require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible usage") }) t.Run("bedrock response body with zero cache read tokens should omit prompt_tokens_details", func(t *testing.T) { @@ -1397,11 +1580,95 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) { _, hasPromptTokensDetails := usageMap["prompt_tokens_details"] require.False(t, hasPromptTokensDetails, "prompt_tokens_details should be omitted when cacheReadInputTokens is zero") }) + + t.Run("bedrock response body with only cache write tokens should map to cached_tokens", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ] + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + responseBody := `{ + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Hello! How can I help you today?" + } + ] + } + }, + "stopReason": "end_turn", + "usage": { + "inputTokens": 10, + "outputTokens": 15, + "totalTokens": 25, + "cacheReadInputTokens": 0, + "cacheWriteInputTokens": 9 + } + }` + action = host.CallOnHttpResponseBody([]byte(responseBody)) + require.Equal(t, types.ActionContinue, action) + + transformedResponseBody := host.GetResponseBody() + require.NotNil(t, transformedResponseBody) + + var responseMap map[string]interface{} + err := json.Unmarshal(transformedResponseBody, &responseMap) + require.NoError(t, err) + + usageMap := responseMap["usage"].(map[string]interface{}) + promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{}) + require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheWriteInputTokens is present") + require.Equal(t, float64(9), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheWriteInputTokens when cacheReadInputTokens is zero") + _, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"] + require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible usage") + }) }) } func RunBedrockOnStreamingResponseBodyTests(t *testing.T) { test.RunTest(t, func(t *testing.T) { + extractFirstDataPayload := func(body []byte) string { + for _, line := range strings.Split(string(body), "\n") { + if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" { + return strings.TrimPrefix(line, "data: ") + } + } + return "" + } + + t.Run("extract first data payload should return empty when no data line", func(t *testing.T) { + payload := extractFirstDataPayload([]byte("event: ping\n\n")) + require.Equal(t, "", payload) + }) + t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) { host, status := test.NewTestHost(bedrockApiTokenConfig) defer host.Reset() @@ -1465,7 +1732,147 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) { require.NoError(t, err) usageMap := responseMap["usage"].(map[string]interface{}) promptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{}) - require.Equal(t, float64(7), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens in streaming usage event") + require.Equal(t, float64(10), promptTokensDetails["cached_tokens"], "cached_tokens should sum cacheReadInputTokens and cacheWriteInputTokens in streaming usage event") + _, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"] + require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible streaming usage") + }) + + t.Run("bedrock streaming text chunk then usage chunk format is stable", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "stream": true + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/vnd.amazon.eventstream"}, + }) + require.Equal(t, types.ActionContinue, action) + + textChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{ + "delta": map[string]interface{}{ + "text": "Hello from Bedrock", + }, + }) + action = host.CallOnHttpStreamingResponseBody(textChunk, false) + require.Equal(t, types.ActionContinue, action) + + firstResponseBody := host.GetResponseBody() + require.NotNil(t, firstResponseBody) + firstDataPayload := extractFirstDataPayload(firstResponseBody) + require.NotEmpty(t, firstDataPayload, "first chunk should contain one SSE data payload") + + var firstResponseMap map[string]interface{} + err := json.Unmarshal([]byte(firstDataPayload), &firstResponseMap) + require.NoError(t, err) + firstChoices := firstResponseMap["choices"].([]interface{}) + require.Len(t, firstChoices, 1, "text chunk should contain one choice") + + usageChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{ + "usage": map[string]interface{}{ + "inputTokens": 10, + "outputTokens": 2, + "totalTokens": 12, + }, + }) + action = host.CallOnHttpStreamingResponseBody(usageChunk, true) + require.Equal(t, types.ActionContinue, action) + + secondResponseBody := host.GetResponseBody() + require.NotNil(t, secondResponseBody) + require.Contains(t, string(secondResponseBody), "data: [DONE]", "last chunk should append [DONE]") + secondDataPayload := extractFirstDataPayload(secondResponseBody) + require.NotEmpty(t, secondDataPayload, "usage chunk should contain one SSE data payload") + + var secondResponseMap map[string]interface{} + err = json.Unmarshal([]byte(secondDataPayload), &secondResponseMap) + require.NoError(t, err) + secondChoices := secondResponseMap["choices"].([]interface{}) + require.Len(t, secondChoices, 0, "usage chunk should contain empty choices by design") + _, hasUsage := secondResponseMap["usage"] + require.True(t, hasUsage, "usage chunk should include usage field") + }) + + t.Run("bedrock empty intermediate callback should not affect next usage event", func(t *testing.T) { + host, status := test.NewTestHost(bedrockApiTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestBody := `{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "stream": true + }` + action = host.CallOnHttpRequestBody([]byte(requestBody)) + require.Equal(t, types.ActionContinue, action) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/vnd.amazon.eventstream"}, + }) + require.Equal(t, types.ActionContinue, action) + + action = host.CallOnHttpStreamingResponseBody([]byte{}, false) + require.Equal(t, types.ActionContinue, action) + emptyResponseBody := host.GetResponseBody() + require.Equal(t, 0, len(emptyResponseBody), "empty intermediate callback should output empty payload") + + usageChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{ + "usage": map[string]interface{}{ + "inputTokens": 10, + "outputTokens": 2, + "totalTokens": 12, + }, + }) + action = host.CallOnHttpStreamingResponseBody(usageChunk, true) + require.Equal(t, types.ActionContinue, action) + + finalResponseBody := host.GetResponseBody() + require.NotNil(t, finalResponseBody) + require.Contains(t, string(finalResponseBody), "data: [DONE]", "last chunk should append [DONE]") + finalDataPayload := extractFirstDataPayload(finalResponseBody) + require.NotEmpty(t, finalDataPayload, "final usage event should still be parsed") + + var finalResponseMap map[string]interface{} + err := json.Unmarshal([]byte(finalDataPayload), &finalResponseMap) + require.NoError(t, err) + finalChoices := finalResponseMap["choices"].([]interface{}) + require.Len(t, finalChoices, 0, "usage chunk should still keep empty choices") }) }) }