mirror of
https://github.com/alibaba/higress.git
synced 2026-06-06 11:17:29 +08:00
feat(bedrock): prompt caching params transform (#3563)
This commit is contained in:
@@ -199,6 +199,7 @@ func TestBedrock(t *testing.T) {
|
||||
test.RunBedrockOnHttpRequestBodyTests(t)
|
||||
test.RunBedrockOnHttpResponseHeadersTests(t)
|
||||
test.RunBedrockOnHttpResponseBodyTests(t)
|
||||
test.RunBedrockOnStreamingResponseBodyTests(t)
|
||||
test.RunBedrockToolCallTests(t)
|
||||
}
|
||||
|
||||
|
||||
@@ -35,9 +35,16 @@ const (
|
||||
// converseStream路径 /model/{modelId}/converse-stream
|
||||
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
||||
// invoke_model 路径 /model/{modelId}/invoke
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
bedrockCacheTypeDefault = "default"
|
||||
bedrockCacheTTL5m = "5m"
|
||||
bedrockCacheTTL1h = "1h"
|
||||
|
||||
bedrockCachePointPositionSystemPrompt = "systemPrompt"
|
||||
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
|
||||
bedrockCachePointPositionLastMessage = "lastMessage"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -169,9 +176,10 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
if bedrockEvent.Usage != nil {
|
||||
openAIFormattedChunk.Choices = choices[:0]
|
||||
openAIFormattedChunk.Usage = &usage{
|
||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens),
|
||||
}
|
||||
}
|
||||
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
||||
@@ -831,6 +839,13 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
},
|
||||
}
|
||||
|
||||
if origRequest.PromptCacheKey != "" {
|
||||
log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
|
||||
}
|
||||
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(origRequest.PromptCacheRetention); ok {
|
||||
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
|
||||
}
|
||||
|
||||
if origRequest.ReasoningEffort != "" {
|
||||
thinkingBudget := 1024 // default
|
||||
switch origRequest.ReasoningEffort {
|
||||
@@ -932,9 +947,10 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: &usage{
|
||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -965,6 +981,112 @@ func stopReasonBedrock2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
||||
switch retention {
|
||||
case "":
|
||||
return "", false
|
||||
case "in_memory":
|
||||
return bedrockCacheTTL5m, true
|
||||
case "24h":
|
||||
return bedrockCacheTTL1h, true
|
||||
default:
|
||||
log.Warnf("unsupported prompt_cache_retention for bedrock mapping: %s", retention)
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
|
||||
if b.config.bedrockPromptCachePointPositions == nil {
|
||||
return map[string]bool{
|
||||
bedrockCachePointPositionSystemPrompt: true,
|
||||
bedrockCachePointPositionLastMessage: false,
|
||||
}
|
||||
}
|
||||
positions := map[string]bool{
|
||||
bedrockCachePointPositionSystemPrompt: false,
|
||||
bedrockCachePointPositionLastUserMessage: false,
|
||||
bedrockCachePointPositionLastMessage: false,
|
||||
}
|
||||
for rawKey, enabled := range b.config.bedrockPromptCachePointPositions {
|
||||
key := normalizeBedrockCachePointPosition(rawKey)
|
||||
switch key {
|
||||
case bedrockCachePointPositionSystemPrompt, bedrockCachePointPositionLastUserMessage, bedrockCachePointPositionLastMessage:
|
||||
positions[key] = enabled
|
||||
default:
|
||||
log.Warnf("unsupported bedrockPromptCachePointPositions key: %s", rawKey)
|
||||
}
|
||||
}
|
||||
return positions
|
||||
}
|
||||
|
||||
func normalizeBedrockCachePointPosition(raw string) string {
|
||||
key := strings.ToLower(raw)
|
||||
key = strings.ReplaceAll(key, "_", "")
|
||||
key = strings.ReplaceAll(key, "-", "")
|
||||
switch key {
|
||||
case "systemprompt":
|
||||
return bedrockCachePointPositionSystemPrompt
|
||||
case "lastusermessage":
|
||||
return bedrockCachePointPositionLastUserMessage
|
||||
case "lastmessage":
|
||||
return bedrockCachePointPositionLastMessage
|
||||
default:
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
func addPromptCachePointsToBedrockRequest(request *bedrockTextGenRequest, cacheTTL string, positions map[string]bool) {
|
||||
if positions[bedrockCachePointPositionSystemPrompt] && len(request.System) > 0 {
|
||||
request.System = append(request.System, systemContentBlock{
|
||||
CachePoint: &bedrockCachePoint{
|
||||
Type: bedrockCacheTypeDefault,
|
||||
TTL: cacheTTL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
lastUserMessageIndex := -1
|
||||
if positions[bedrockCachePointPositionLastUserMessage] {
|
||||
lastUserMessageIndex = findLastMessageIndexByRole(request.Messages, roleUser)
|
||||
if lastUserMessageIndex >= 0 {
|
||||
appendCachePointToBedrockMessage(request, lastUserMessageIndex, cacheTTL)
|
||||
}
|
||||
}
|
||||
if positions[bedrockCachePointPositionLastMessage] && len(request.Messages) > 0 {
|
||||
lastMessageIndex := len(request.Messages) - 1
|
||||
if lastMessageIndex != lastUserMessageIndex {
|
||||
appendCachePointToBedrockMessage(request, lastMessageIndex, cacheTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findLastMessageIndexByRole(messages []bedrockMessage, role string) int {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == role {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
|
||||
request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
|
||||
CachePoint: &bedrockCachePoint{
|
||||
Type: bedrockCacheTypeDefault,
|
||||
TTL: cacheTTL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildPromptTokensDetails(cacheReadInputTokens int) *promptTokensDetails {
|
||||
if cacheReadInputTokens <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &promptTokensDetails{
|
||||
CachedTokens: cacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
type bedrockTextGenRequest struct {
|
||||
Messages []bedrockMessage `json:"messages"`
|
||||
System []systemContentBlock `json:"system,omitempty"`
|
||||
@@ -1009,14 +1131,21 @@ type bedrockMessage struct {
|
||||
}
|
||||
|
||||
type bedrockMessageContent struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
|
||||
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
|
||||
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
|
||||
CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"`
|
||||
}
|
||||
|
||||
type systemContentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"`
|
||||
}
|
||||
|
||||
// bedrockCachePoint is the value serialized under the "cachePoint" key of a
// Bedrock system or message content block.
type bedrockCachePoint struct {
	// Type is always written to the JSON output, even when empty.
	Type string `json:"type"`
	// TTL is the cache lifetime string; it is left out of the JSON output
	// when empty.
	TTL string `json:"ttl,omitempty"`
}
|
||||
|
||||
type imageBlock struct {
|
||||
@@ -1098,6 +1227,10 @@ type tokenUsage struct {
|
||||
OutputTokens int `json:"outputTokens,omitempty"`
|
||||
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
|
||||
CacheReadInputTokens int `json:"cacheReadInputTokens,omitempty"`
|
||||
|
||||
CacheWriteInputTokens int `json:"cacheWriteInputTokens,omitempty"`
|
||||
}
|
||||
|
||||
func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent {
|
||||
|
||||
@@ -42,34 +42,36 @@ type thinkingParam struct {
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
NonOpenAIStyleOptions
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Prediction map[string]interface{} `json:"prediction,omitempty"`
|
||||
Audio map[string]interface{} `json:"audio,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Prediction map[string]interface{} `json:"prediction,omitempty"`
|
||||
Audio map[string]interface{} `json:"audio,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
PromptCacheRetention string `json:"prompt_cache_retention,omitempty"`
|
||||
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getMaxTokens() int {
|
||||
|
||||
@@ -354,6 +354,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Amazon Bedrock 额外模型请求参数
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务,用于设置模型特定的推理参数
|
||||
bedrockAdditionalFields map[string]interface{} `required:"false" yaml:"bedrockAdditionalFields" json:"bedrockAdditionalFields"`
|
||||
// @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。
|
||||
bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"`
|
||||
// @Title zh-CN minimax API type
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||
@@ -552,6 +555,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
||||
c.bedrockAdditionalFields[k] = v.Value()
|
||||
}
|
||||
if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() {
|
||||
c.bedrockPromptCachePointPositions = make(map[string]bool)
|
||||
for k, v := range rawPositions.Map() {
|
||||
c.bedrockPromptCachePointPositions[k] = v.Bool()
|
||||
}
|
||||
}
|
||||
}
|
||||
c.minimaxApiType = json.Get("minimaxApiType").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -112,6 +116,23 @@ var bedrockApiTokenConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// bedrockApiTokenConfigWithCachePointPositions returns a JSON plugin config
// for a Bedrock provider with a single API token, region us-east-1, a
// catch-all model mapping, and the given bedrockPromptCachePointPositions
// value. A nil positions map is serialized as JSON null.
func bedrockApiTokenConfigWithCachePointPositions(positions map[string]bool) json.RawMessage {
	provider := map[string]interface{}{
		"type":      "bedrock",
		"apiTokens": []string{"test-token-for-unit-test"},
		"awsRegion": "us-east-1",
		"modelMapping": map[string]string{
			"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
		},
		"bedrockPromptCachePointPositions": positions,
	}
	// The error is discarded on purpose: the value holds only strings, string
	// slices and string-keyed maps, which json.Marshal always accepts.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}
|
||||
|
||||
// Test config: Bedrock config with multiple Bearer Tokens
|
||||
var bedrockMultiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
@@ -369,6 +390,372 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||
})
|
||||
|
||||
t.Run("bedrock request body prompt cache in_memory should inject system cache point only by default", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "in_memory",
|
||||
"prompt_cache_key": "session-001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, hasPromptCacheRetention := bodyMap["prompt_cache_retention"]
|
||||
require.False(t, hasPromptCacheRetention, "prompt_cache_retention should not be forwarded to Bedrock")
|
||||
_, hasPromptCacheKey := bodyMap["prompt_cache_key"]
|
||||
require.False(t, hasPromptCacheKey, "prompt_cache_key should not be forwarded to Bedrock")
|
||||
|
||||
systemBlocks, ok := bodyMap["system"].([]interface{})
|
||||
require.True(t, ok, "system should be an array")
|
||||
require.Len(t, systemBlocks, 2, "system should contain text block and cachePoint block")
|
||||
systemCachePointBlock := systemBlocks[len(systemBlocks)-1].(map[string]interface{})
|
||||
systemCachePoint, ok := systemCachePointBlock["cachePoint"].(map[string]interface{})
|
||||
require.True(t, ok, "system tail block should contain cachePoint")
|
||||
require.Equal(t, "default", systemCachePoint["type"])
|
||||
require.Equal(t, "5m", systemCachePoint["ttl"])
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
require.NotEmpty(t, messages, "messages should not be empty")
|
||||
lastMessage := messages[len(messages)-1].(map[string]interface{})
|
||||
lastMessageContent := lastMessage["content"].([]interface{})
|
||||
require.Len(t, lastMessageContent, 1, "last message should keep original content only by default")
|
||||
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default")
|
||||
})
|
||||
|
||||
t.Run("bedrock request body prompt cache 24h should map to 1h ttl on system cache point by default", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "24h",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemBlocks := bodyMap["system"].([]interface{})
|
||||
systemCachePointBlock := systemBlocks[len(systemBlocks)-1].(map[string]interface{})
|
||||
systemCachePoint := systemCachePointBlock["cachePoint"].(map[string]interface{})
|
||||
require.Equal(t, "1h", systemCachePoint["ttl"])
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
lastMessage := messages[len(messages)-1].(map[string]interface{})
|
||||
lastMessageContent := lastMessage["content"].([]interface{})
|
||||
require.Len(t, lastMessageContent, 1, "last message should keep original content only by default")
|
||||
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default")
|
||||
})
|
||||
|
||||
t.Run("bedrock request body prompt cache should insert cache points based on configured positions", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfigWithCachePointPositions(map[string]bool{
|
||||
"systemPrompt": true,
|
||||
"lastUserMessage": true,
|
||||
"lastMessage": false,
|
||||
}))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "in_memory",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Question from user"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Previous assistant answer"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemBlocks := bodyMap["system"].([]interface{})
|
||||
require.Len(t, systemBlocks, 2, "system should include cachePoint due to systemPrompt=true")
|
||||
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||
require.Equal(t, "5m", systemCachePoint["ttl"])
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
require.Len(t, messages, 2, "system message should not be in messages array")
|
||||
|
||||
lastUserMessageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
||||
require.Len(t, lastUserMessageContent, 2, "last user message should include one cachePoint")
|
||||
lastUserMessageCachePoint := lastUserMessageContent[len(lastUserMessageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||
require.Equal(t, "5m", lastUserMessageCachePoint["ttl"])
|
||||
|
||||
lastMessageContent := messages[1].(map[string]interface{})["content"].([]interface{})
|
||||
require.Len(t, lastMessageContent, 1, "last message should not include cachePoint when lastMessage=false")
|
||||
})
|
||||
|
||||
t.Run("bedrock request body prompt cache should avoid duplicate insertion when lastUserMessage and lastMessage overlap", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfigWithCachePointPositions(map[string]bool{
|
||||
"systemPrompt": false,
|
||||
"lastUserMessage": true,
|
||||
"lastMessage": true,
|
||||
}))
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "in_memory",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, hasSystem := bodyMap["system"]
|
||||
require.False(t, hasSystem, "system should not include cachePoint when systemPrompt=false and no system messages")
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
require.Len(t, messages, 1, "only one message should exist")
|
||||
messageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
||||
require.Len(t, messageContent, 2, "overlap positions should still insert only one cachePoint")
|
||||
cachePoint := messageContent[len(messageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||
require.Equal(t, "5m", cachePoint["ttl"])
|
||||
})
|
||||
|
||||
t.Run("bedrock request body with empty prompt cache retention should not inject cache points", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemBlocks := bodyMap["system"].([]interface{})
|
||||
require.Len(t, systemBlocks, 1, "system should only contain the original text block")
|
||||
_, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasSystemCachePoint, "system block should not include cachePoint when retention is empty")
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
lastMessage := messages[len(messages)-1].(map[string]interface{})
|
||||
lastMessageContent := lastMessage["content"].([]interface{})
|
||||
require.Len(t, lastMessageContent, 1, "message should only contain original text block")
|
||||
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is empty")
|
||||
})
|
||||
|
||||
t.Run("bedrock request body with unsupported prompt cache retention should not inject cache points", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "2h",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
systemBlocks := bodyMap["system"].([]interface{})
|
||||
require.Len(t, systemBlocks, 1, "system should only contain the original text block")
|
||||
_, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasSystemCachePoint, "system block should not include cachePoint when retention is unsupported")
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
lastMessage := messages[len(messages)-1].(map[string]interface{})
|
||||
lastMessageContent := lastMessage["content"].([]interface{})
|
||||
require.Len(t, lastMessageContent, 1, "message should only contain original text block")
|
||||
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is unsupported")
|
||||
})
|
||||
|
||||
t.Run("bedrock request body without system should not inject cache point by default", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"prompt_cache_retention": "in_memory",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(processedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, hasSystem := bodyMap["system"]
|
||||
require.False(t, hasSystem, "system should be omitted when original request has no system prompts")
|
||||
|
||||
messages := bodyMap["messages"].([]interface{})
|
||||
require.Len(t, messages, 1, "messages should keep original one user message")
|
||||
lastMessage := messages[0].(map[string]interface{})
|
||||
lastMessageContent := lastMessage["content"].([]interface{})
|
||||
require.Len(t, lastMessageContent, 1, "message should keep original text block only by default")
|
||||
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
|
||||
require.False(t, hasMessageCachePoint, "message should not include cachePoint by default")
|
||||
})
|
||||
|
||||
// Test Bedrock request body processing with AWS Signature V4 authentication
|
||||
t.Run("bedrock chat completion request body with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
@@ -911,7 +1298,9 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
||||
"usage": {
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 15,
|
||||
"totalTokens": 25
|
||||
"totalTokens": 25,
|
||||
"cacheReadInputTokens": 6,
|
||||
"cacheWriteInputTokens": 12
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -935,6 +1324,176 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
||||
usage, exists := responseMap["usage"]
|
||||
require.True(t, exists, "Usage should exist in response body")
|
||||
require.NotNil(t, usage, "Usage should not be nil")
|
||||
usageMap := usage.(map[string]interface{})
|
||||
promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
||||
require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheReadInputTokens is present")
|
||||
require.Equal(t, float64(6), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens")
|
||||
})
|
||||
|
||||
t.Run("bedrock response body with zero cache read tokens should omit prompt_tokens_details", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
responseBody := `{
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Hello! How can I help you today?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 15,
|
||||
"totalTokens": 25,
|
||||
"cacheReadInputTokens": 0
|
||||
}
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedResponseBody, &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
usageMap := responseMap["usage"].(map[string]interface{})
|
||||
_, hasPromptTokensDetails := usageMap["prompt_tokens_details"]
|
||||
require.False(t, hasPromptTokensDetails, "prompt_tokens_details should be omitted when cacheReadInputTokens is zero")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/vnd.amazon.eventstream"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
streamingChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||
"usage": map[string]interface{}{
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 2,
|
||||
"totalTokens": 12,
|
||||
"cacheReadInputTokens": 7,
|
||||
"cacheWriteInputTokens": 3,
|
||||
},
|
||||
})
|
||||
action = host.CallOnHttpStreamingResponseBody(streamingChunk, true)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
|
||||
var dataPayload string
|
||||
for _, line := range strings.Split(string(transformedResponseBody), "\n") {
|
||||
if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" {
|
||||
dataPayload = strings.TrimPrefix(line, "data: ")
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, dataPayload, "should have at least one SSE data payload")
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(dataPayload), &responseMap)
|
||||
require.NoError(t, err)
|
||||
usageMap := responseMap["usage"].(map[string]interface{})
|
||||
promptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
||||
require.Equal(t, float64(7), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens in streaming usage event")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// buildBedrockEventStreamMessage JSON-encodes payload and wraps it in one
// binary event-stream frame with an empty header section. The frame layout
// produced is, in order:
//
//	4 bytes  total message length (big-endian)
//	4 bytes  headers length (big-endian, always 0 here)
//	4 bytes  CRC-32 (IEEE) of the 8 prelude bytes above
//	N bytes  JSON payload
//	4 bytes  CRC-32 (IEEE) of every preceding byte of the message
//
// It fails the calling test immediately if payload cannot be marshalled.
func buildBedrockEventStreamMessage(t *testing.T, payload map[string]interface{}) []byte {
	// Attribute failures to the test that called this helper.
	t.Helper()

	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal event stream payload: %v", err)
	}

	// Fixed framing cost: 8-byte prelude + 4-byte prelude CRC + 4-byte message CRC.
	const framingOverhead = 16
	totalLength := uint32(framingOverhead + len(payloadBytes))
	headersLength := uint32(0)

	var message bytes.Buffer
	message.Grow(int(totalLength))

	prelude := make([]byte, 8)
	binary.BigEndian.PutUint32(prelude[0:4], totalLength)
	binary.BigEndian.PutUint32(prelude[4:8], headersLength)
	message.Write(prelude)

	preludeCRCBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(preludeCRCBytes, crc32.ChecksumIEEE(prelude))
	message.Write(preludeCRCBytes)

	message.Write(payloadBytes)

	// The trailing checksum covers prelude, prelude CRC and payload.
	messageCRCBytes := make([]byte, 4)
	binary.BigEndian.PutUint32(messageCRCBytes, crc32.ChecksumIEEE(message.Bytes()))
	message.Write(messageCRCBytes)

	return message.Bytes()
}
|
||||
|
||||
Reference in New Issue
Block a user