mirror of
https://github.com/alibaba/higress.git
synced 2026-06-08 20:27:31 +08:00
feat(ai-proxy): 添加Amazon Bedrock Prompt Cache保留策略配置及优化缓存处理逻辑 (#3609)
This commit is contained in:
@@ -35,12 +35,14 @@ const (
|
|||||||
// converseStream路径 /model/{modelId}/converse-stream
|
// converseStream路径 /model/{modelId}/converse-stream
|
||||||
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
||||||
// invoke_model 路径 /model/{modelId}/invoke
|
// invoke_model 路径 /model/{modelId}/invoke
|
||||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||||
bedrockSignedHeaders = "host;x-amz-date"
|
bedrockSignedHeaders = "host;x-amz-date"
|
||||||
requestIdHeader = "X-Amzn-Requestid"
|
requestIdHeader = "X-Amzn-Requestid"
|
||||||
bedrockCacheTypeDefault = "default"
|
bedrockCacheTypeDefault = "default"
|
||||||
bedrockCacheTTL5m = "5m"
|
bedrockCacheTTL5m = "5m"
|
||||||
bedrockCacheTTL1h = "1h"
|
bedrockCacheTTL1h = "1h"
|
||||||
|
bedrockPromptCacheNova = "amazon.nova"
|
||||||
|
bedrockPromptCacheClaude = "anthropic.claude"
|
||||||
|
|
||||||
bedrockCachePointPositionSystemPrompt = "systemPrompt"
|
bedrockCachePointPositionSystemPrompt = "systemPrompt"
|
||||||
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
|
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
|
||||||
@@ -179,7 +181,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
|||||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||||
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens),
|
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens, bedrockEvent.Usage.CacheWriteInputTokens),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
||||||
@@ -839,11 +841,17 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
effectivePromptCacheRetention := b.resolvePromptCacheRetention(origRequest.PromptCacheRetention)
|
||||||
|
|
||||||
if origRequest.PromptCacheKey != "" {
|
if origRequest.PromptCacheKey != "" {
|
||||||
log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
|
log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
|
||||||
}
|
}
|
||||||
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(origRequest.PromptCacheRetention); ok {
|
if isPromptCacheSupportedModel(origRequest.Model) {
|
||||||
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
|
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(effectivePromptCacheRetention); ok {
|
||||||
|
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
|
||||||
|
}
|
||||||
|
} else if effectivePromptCacheRetention != "" {
|
||||||
|
log.Warnf("skip prompt cache injection for unsupported model: %s", origRequest.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
if origRequest.ReasoningEffort != "" {
|
if origRequest.ReasoningEffort != "" {
|
||||||
@@ -950,7 +958,7 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
|||||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||||
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens),
|
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens, bedrockResponse.Usage.CacheWriteInputTokens),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -982,11 +990,14 @@ func stopReasonBedrock2OpenAI(reason string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
||||||
switch retention {
|
normalizedRetention := normalizePromptCacheRetention(retention)
|
||||||
|
switch normalizedRetention {
|
||||||
case "":
|
case "":
|
||||||
return "", false
|
return "", false
|
||||||
case "in_memory":
|
case "in_memory":
|
||||||
return bedrockCacheTTL5m, true
|
// For the default 5-minute cache, omit ttl and let Bedrock apply its default.
|
||||||
|
// This is more robust for models that are strict about explicit ttl fields.
|
||||||
|
return "", true
|
||||||
case "24h":
|
case "24h":
|
||||||
return bedrockCacheTTL1h, true
|
return bedrockCacheTTL1h, true
|
||||||
default:
|
default:
|
||||||
@@ -995,6 +1006,32 @@ func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizePromptCacheRetention(retention string) string {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(retention))
|
||||||
|
normalized = strings.ReplaceAll(normalized, "-", "_")
|
||||||
|
normalized = strings.ReplaceAll(normalized, " ", "_")
|
||||||
|
if normalized == "inmemory" {
|
||||||
|
return "in_memory"
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPromptCacheSupportedModel(model string) bool {
|
||||||
|
normalizedModel := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
return strings.Contains(normalizedModel, bedrockPromptCacheNova) ||
|
||||||
|
strings.Contains(normalizedModel, bedrockPromptCacheClaude)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *bedrockProvider) resolvePromptCacheRetention(requestPromptCacheRetention string) string {
|
||||||
|
if requestPromptCacheRetention != "" {
|
||||||
|
return requestPromptCacheRetention
|
||||||
|
}
|
||||||
|
if b.config.promptCacheRetention != "" {
|
||||||
|
return b.config.promptCacheRetention
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
|
func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
|
||||||
if b.config.bedrockPromptCachePointPositions == nil {
|
if b.config.bedrockPromptCachePointPositions == nil {
|
||||||
return map[string]bool{
|
return map[string]bool{
|
||||||
@@ -1070,6 +1107,9 @@ func findLastMessageIndexByRole(messages []bedrockMessage, role string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
|
func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
|
||||||
|
if messageIndex < 0 || messageIndex >= len(request.Messages) {
|
||||||
|
return
|
||||||
|
}
|
||||||
request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
|
request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
|
||||||
CachePoint: &bedrockCachePoint{
|
CachePoint: &bedrockCachePoint{
|
||||||
Type: bedrockCacheTypeDefault,
|
Type: bedrockCacheTypeDefault,
|
||||||
@@ -1078,12 +1118,13 @@ func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageInd
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildPromptTokensDetails(cacheReadInputTokens int) *promptTokensDetails {
|
func buildPromptTokensDetails(cacheReadInputTokens int, cacheWriteInputTokens int) *promptTokensDetails {
|
||||||
if cacheReadInputTokens <= 0 {
|
totalCachedTokens := cacheReadInputTokens + cacheWriteInputTokens
|
||||||
|
if totalCachedTokens <= 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &promptTokensDetails{
|
return &promptTokensDetails{
|
||||||
CachedTokens: cacheReadInputTokens,
|
CachedTokens: totalCachedTokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -102,3 +102,90 @@ func TestGenerateSignatureDiffersForRawAndPreEncodedModelPath(t *testing.T) {
|
|||||||
preEncodedSignature := p.generateSignature(preEncodedPath, "20260312T142942Z", "20260312", body)
|
preEncodedSignature := p.generateSignature(preEncodedPath, "20260312T142942Z", "20260312", body)
|
||||||
assert.NotEqual(t, rawSignature, preEncodedSignature)
|
assert.NotEqual(t, rawSignature, preEncodedSignature)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizePromptCacheRetention(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
retention string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "inmemory alias maps to in_memory",
|
||||||
|
retention: "inmemory",
|
||||||
|
want: "in_memory",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dash style maps to in_memory",
|
||||||
|
retention: "in-memory",
|
||||||
|
want: "in_memory",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space style with trim maps to in_memory",
|
||||||
|
retention: " in memory ",
|
||||||
|
want: "in_memory",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "already normalized remains unchanged",
|
||||||
|
retention: "in_memory",
|
||||||
|
want: "in_memory",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, normalizePromptCacheRetention(tt.retention))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendCachePointToBedrockMessageInvalidIndexNoop(t *testing.T) {
|
||||||
|
request := &bedrockTextGenRequest{
|
||||||
|
Messages: []bedrockMessage{
|
||||||
|
{
|
||||||
|
Role: roleUser,
|
||||||
|
Content: []bedrockMessageContent{
|
||||||
|
{Text: "hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
appendCachePointToBedrockMessage(request, -1, bedrockCacheTTL5m)
|
||||||
|
appendCachePointToBedrockMessage(request, len(request.Messages), bedrockCacheTTL5m)
|
||||||
|
|
||||||
|
assert.Len(t, request.Messages[0].Content, 1)
|
||||||
|
|
||||||
|
appendCachePointToBedrockMessage(request, 0, bedrockCacheTTL5m)
|
||||||
|
assert.Len(t, request.Messages[0].Content, 2)
|
||||||
|
assert.NotNil(t, request.Messages[0].Content[1].CachePoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPromptCacheSupportedModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "anthropic claude model is supported",
|
||||||
|
model: "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "amazon nova inference profile is supported",
|
||||||
|
model: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other model is not supported",
|
||||||
|
model: "meta.llama3-70b-instruct-v1:0",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, isPromptCacheSupportedModel(tt.model))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -357,6 +357,9 @@ type ProviderConfig struct {
|
|||||||
// @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置
|
// @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置
|
||||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。
|
// @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。
|
||||||
bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"`
|
bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"`
|
||||||
|
// @Title zh-CN Amazon Bedrock Prompt Cache 保留策略(默认值)
|
||||||
|
// @Description zh-CN 仅适用于Amazon Bedrock服务。作为请求中 prompt_cache_retention 缺省时的默认值,支持 in_memory 和 24h。
|
||||||
|
promptCacheRetention string `required:"false" yaml:"promptCacheRetention" json:"promptCacheRetention"`
|
||||||
// @Title zh-CN minimax API type
|
// @Title zh-CN minimax API type
|
||||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||||
@@ -558,6 +561,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
||||||
c.bedrockAdditionalFields[k] = v.Value()
|
c.bedrockAdditionalFields[k] = v.Value()
|
||||||
}
|
}
|
||||||
|
c.promptCacheRetention = json.Get("promptCacheRetention").String()
|
||||||
if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() {
|
if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() {
|
||||||
c.bedrockPromptCachePointPositions = make(map[string]bool)
|
c.bedrockPromptCachePointPositions = make(map[string]bool)
|
||||||
for k, v := range rawPositions.Map() {
|
for k, v := range rawPositions.Map() {
|
||||||
|
|||||||
@@ -133,6 +133,41 @@ func bedrockApiTokenConfigWithCachePointPositions(positions map[string]bool) jso
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func bedrockApiTokenConfigWithPromptCacheRetention(promptCacheRetention string) json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "bedrock",
|
||||||
|
"apiTokens": []string{
|
||||||
|
"test-token-for-unit-test",
|
||||||
|
},
|
||||||
|
"awsRegion": "us-east-1",
|
||||||
|
"modelMapping": map[string]string{
|
||||||
|
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||||
|
},
|
||||||
|
"promptCacheRetention": promptCacheRetention,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func bedrockApiTokenConfigWithModelAndPromptCache(mappedModel, promptCacheRetention string, positions map[string]bool) json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "bedrock",
|
||||||
|
"apiTokens": []string{
|
||||||
|
"test-token-for-unit-test",
|
||||||
|
},
|
||||||
|
"awsRegion": "us-east-1",
|
||||||
|
"modelMapping": map[string]string{
|
||||||
|
"*": mappedModel,
|
||||||
|
},
|
||||||
|
"promptCacheRetention": promptCacheRetention,
|
||||||
|
"bedrockPromptCachePointPositions": positions,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// Test config: Bedrock config with multiple Bearer Tokens
|
// Test config: Bedrock config with multiple Bearer Tokens
|
||||||
var bedrockMultiTokenConfig = func() json.RawMessage {
|
var bedrockMultiTokenConfig = func() json.RawMessage {
|
||||||
data, _ := json.Marshal(map[string]interface{}{
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
@@ -390,7 +425,7 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("bedrock request body prompt cache in_memory should inject system cache point only by default", func(t *testing.T) {
|
t.Run("bedrock request body prompt cache in-memory should inject system cache point only by default", func(t *testing.T) {
|
||||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
defer host.Reset()
|
defer host.Reset()
|
||||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
@@ -405,7 +440,7 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
|
|
||||||
requestBody := `{
|
requestBody := `{
|
||||||
"model": "gpt-4",
|
"model": "gpt-4",
|
||||||
"prompt_cache_retention": "in_memory",
|
"prompt_cache_retention": "in-memory",
|
||||||
"prompt_cache_key": "session-001",
|
"prompt_cache_key": "session-001",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
@@ -440,7 +475,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
systemCachePoint, ok := systemCachePointBlock["cachePoint"].(map[string]interface{})
|
systemCachePoint, ok := systemCachePointBlock["cachePoint"].(map[string]interface{})
|
||||||
require.True(t, ok, "system tail block should contain cachePoint")
|
require.True(t, ok, "system tail block should contain cachePoint")
|
||||||
require.Equal(t, "default", systemCachePoint["type"])
|
require.Equal(t, "default", systemCachePoint["type"])
|
||||||
require.Equal(t, "5m", systemCachePoint["ttl"])
|
_, hasTTL := systemCachePoint["ttl"]
|
||||||
|
require.False(t, hasTTL, "ttl should be omitted for in_memory to use Bedrock default 5m")
|
||||||
|
|
||||||
messages := bodyMap["messages"].([]interface{})
|
messages := bodyMap["messages"].([]interface{})
|
||||||
require.NotEmpty(t, messages, "messages should not be empty")
|
require.NotEmpty(t, messages, "messages should not be empty")
|
||||||
@@ -451,6 +487,91 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default")
|
require.False(t, hasMessageCachePoint, "last message should not include cachePoint by default")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("bedrock request body should use provider promptCacheRetention in-memory when request omits prompt_cache_retention", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfigWithPromptCacheRetention("in-memory"))
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
systemBlocks := bodyMap["system"].([]interface{})
|
||||||
|
require.Len(t, systemBlocks, 2, "provider promptCacheRetention should trigger cachePoint injection")
|
||||||
|
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||||
|
_, hasTTL := systemCachePoint["ttl"]
|
||||||
|
require.False(t, hasTTL, "provider promptCacheRetention=in-memory should omit ttl and use Bedrock default 5m")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bedrock request body prompt_cache_retention should override provider promptCacheRetention", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfigWithPromptCacheRetention("in_memory"))
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"prompt_cache_retention": "24h",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
systemBlocks := bodyMap["system"].([]interface{})
|
||||||
|
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||||
|
require.Equal(t, "1h", systemCachePoint["ttl"], "request prompt_cache_retention should override provider promptCacheRetention")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("bedrock request body prompt cache 24h should map to 1h ttl on system cache point by default", func(t *testing.T) {
|
t.Run("bedrock request body prompt cache 24h should map to 1h ttl on system cache point by default", func(t *testing.T) {
|
||||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
defer host.Reset()
|
defer host.Reset()
|
||||||
@@ -549,7 +670,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
systemBlocks := bodyMap["system"].([]interface{})
|
systemBlocks := bodyMap["system"].([]interface{})
|
||||||
require.Len(t, systemBlocks, 2, "system should include cachePoint due to systemPrompt=true")
|
require.Len(t, systemBlocks, 2, "system should include cachePoint due to systemPrompt=true")
|
||||||
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
systemCachePoint := systemBlocks[len(systemBlocks)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||||
require.Equal(t, "5m", systemCachePoint["ttl"])
|
_, hasSystemTTL := systemCachePoint["ttl"]
|
||||||
|
require.False(t, hasSystemTTL, "ttl should be omitted for in_memory cachePoint")
|
||||||
|
|
||||||
messages := bodyMap["messages"].([]interface{})
|
messages := bodyMap["messages"].([]interface{})
|
||||||
require.Len(t, messages, 2, "system message should not be in messages array")
|
require.Len(t, messages, 2, "system message should not be in messages array")
|
||||||
@@ -557,7 +679,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
lastUserMessageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
lastUserMessageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
||||||
require.Len(t, lastUserMessageContent, 2, "last user message should include one cachePoint")
|
require.Len(t, lastUserMessageContent, 2, "last user message should include one cachePoint")
|
||||||
lastUserMessageCachePoint := lastUserMessageContent[len(lastUserMessageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
lastUserMessageCachePoint := lastUserMessageContent[len(lastUserMessageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||||
require.Equal(t, "5m", lastUserMessageCachePoint["ttl"])
|
_, hasLastUserTTL := lastUserMessageCachePoint["ttl"]
|
||||||
|
require.False(t, hasLastUserTTL, "ttl should be omitted for in_memory cachePoint")
|
||||||
|
|
||||||
lastMessageContent := messages[1].(map[string]interface{})["content"].([]interface{})
|
lastMessageContent := messages[1].(map[string]interface{})["content"].([]interface{})
|
||||||
require.Len(t, lastMessageContent, 1, "last message should not include cachePoint when lastMessage=false")
|
require.Len(t, lastMessageContent, 1, "last message should not include cachePoint when lastMessage=false")
|
||||||
@@ -608,7 +731,8 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
messageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
messageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
||||||
require.Len(t, messageContent, 2, "overlap positions should still insert only one cachePoint")
|
require.Len(t, messageContent, 2, "overlap positions should still insert only one cachePoint")
|
||||||
cachePoint := messageContent[len(messageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
cachePoint := messageContent[len(messageContent)-1].(map[string]interface{})["cachePoint"].(map[string]interface{})
|
||||||
require.Equal(t, "5m", cachePoint["ttl"])
|
_, hasTTL := cachePoint["ttl"]
|
||||||
|
require.False(t, hasTTL, "ttl should be omitted for in_memory cachePoint")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("bedrock request body with empty prompt cache retention should not inject cache points", func(t *testing.T) {
|
t.Run("bedrock request body with empty prompt cache retention should not inject cache points", func(t *testing.T) {
|
||||||
@@ -711,6 +835,63 @@ func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
|||||||
require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is unsupported")
|
require.False(t, hasMessageCachePoint, "message block should not include cachePoint when retention is unsupported")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("bedrock request body should skip prompt cache for unsupported model even when enabled", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfigWithModelAndPromptCache(
|
||||||
|
"meta.llama3-70b-instruct-v1:0",
|
||||||
|
"in_memory",
|
||||||
|
map[string]bool{
|
||||||
|
"systemPrompt": true,
|
||||||
|
"lastMessage": true,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"prompt_cache_retention": "24h",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(processedBody, &bodyMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
systemBlocks := bodyMap["system"].([]interface{})
|
||||||
|
require.Len(t, systemBlocks, 1, "unsupported model should skip system cachePoint injection")
|
||||||
|
_, hasSystemCachePoint := systemBlocks[0].(map[string]interface{})["cachePoint"]
|
||||||
|
require.False(t, hasSystemCachePoint, "unsupported model should not contain system cachePoint")
|
||||||
|
|
||||||
|
messages := bodyMap["messages"].([]interface{})
|
||||||
|
require.Len(t, messages, 1, "system message should not be in messages array")
|
||||||
|
lastMessageContent := messages[0].(map[string]interface{})["content"].([]interface{})
|
||||||
|
require.Len(t, lastMessageContent, 1, "unsupported model should skip message cachePoint injection")
|
||||||
|
_, hasMessageCachePoint := lastMessageContent[0].(map[string]interface{})["cachePoint"]
|
||||||
|
require.False(t, hasMessageCachePoint, "unsupported model should not contain message cachePoint")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("bedrock request body without system should not inject cache point by default", func(t *testing.T) {
|
t.Run("bedrock request body without system should not inject cache point by default", func(t *testing.T) {
|
||||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
defer host.Reset()
|
defer host.Reset()
|
||||||
@@ -1327,7 +1508,9 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
|||||||
usageMap := usage.(map[string]interface{})
|
usageMap := usage.(map[string]interface{})
|
||||||
promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
||||||
require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheReadInputTokens is present")
|
require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheReadInputTokens is present")
|
||||||
require.Equal(t, float64(6), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens")
|
require.Equal(t, float64(18), promptTokensDetails["cached_tokens"], "cached_tokens should sum cacheReadInputTokens and cacheWriteInputTokens")
|
||||||
|
_, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"]
|
||||||
|
require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible usage")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("bedrock response body with zero cache read tokens should omit prompt_tokens_details", func(t *testing.T) {
|
t.Run("bedrock response body with zero cache read tokens should omit prompt_tokens_details", func(t *testing.T) {
|
||||||
@@ -1397,11 +1580,95 @@ func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
|||||||
_, hasPromptTokensDetails := usageMap["prompt_tokens_details"]
|
_, hasPromptTokensDetails := usageMap["prompt_tokens_details"]
|
||||||
require.False(t, hasPromptTokensDetails, "prompt_tokens_details should be omitted when cacheReadInputTokens is zero")
|
require.False(t, hasPromptTokensDetails, "prompt_tokens_details should be omitted when cacheReadInputTokens is zero")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("bedrock response body with only cache write tokens should map to cached_tokens", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": "Hello! How can I help you today?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stopReason": "end_turn",
|
||||||
|
"usage": {
|
||||||
|
"inputTokens": 10,
|
||||||
|
"outputTokens": 15,
|
||||||
|
"totalTokens": 25,
|
||||||
|
"cacheReadInputTokens": 0,
|
||||||
|
"cacheWriteInputTokens": 9
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpResponseBody([]byte(responseBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
transformedResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, transformedResponseBody)
|
||||||
|
|
||||||
|
var responseMap map[string]interface{}
|
||||||
|
err := json.Unmarshal(transformedResponseBody, &responseMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
usageMap := responseMap["usage"].(map[string]interface{})
|
||||||
|
promptTokensDetails, hasPromptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
||||||
|
require.True(t, hasPromptTokensDetails, "prompt_tokens_details should exist when cacheWriteInputTokens is present")
|
||||||
|
require.Equal(t, float64(9), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheWriteInputTokens when cacheReadInputTokens is zero")
|
||||||
|
_, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"]
|
||||||
|
require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible usage")
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
|
func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
|
||||||
test.RunTest(t, func(t *testing.T) {
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
extractFirstDataPayload := func(body []byte) string {
|
||||||
|
for _, line := range strings.Split(string(body), "\n") {
|
||||||
|
if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" {
|
||||||
|
return strings.TrimPrefix(line, "data: ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("extract first data payload should return empty when no data line", func(t *testing.T) {
|
||||||
|
payload := extractFirstDataPayload([]byte("event: ping\n\n"))
|
||||||
|
require.Equal(t, "", payload)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
|
t.Run("bedrock streaming usage should map cached_tokens", func(t *testing.T) {
|
||||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
defer host.Reset()
|
defer host.Reset()
|
||||||
@@ -1465,7 +1732,147 @@ func RunBedrockOnStreamingResponseBodyTests(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
usageMap := responseMap["usage"].(map[string]interface{})
|
usageMap := responseMap["usage"].(map[string]interface{})
|
||||||
promptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
promptTokensDetails := usageMap["prompt_tokens_details"].(map[string]interface{})
|
||||||
require.Equal(t, float64(7), promptTokensDetails["cached_tokens"], "cached_tokens should map from cacheReadInputTokens in streaming usage event")
|
require.Equal(t, float64(10), promptTokensDetails["cached_tokens"], "cached_tokens should sum cacheReadInputTokens and cacheWriteInputTokens in streaming usage event")
|
||||||
|
_, hasCacheWriteTokens := promptTokensDetails["cache_write_tokens"]
|
||||||
|
require.False(t, hasCacheWriteTokens, "cache_write_tokens should not exist in OpenAI-compatible streaming usage")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bedrock streaming text chunk then usage chunk format is stable", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": true
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/vnd.amazon.eventstream"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
textChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"text": "Hello from Bedrock",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
action = host.CallOnHttpStreamingResponseBody(textChunk, false)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
firstResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, firstResponseBody)
|
||||||
|
firstDataPayload := extractFirstDataPayload(firstResponseBody)
|
||||||
|
require.NotEmpty(t, firstDataPayload, "first chunk should contain one SSE data payload")
|
||||||
|
|
||||||
|
var firstResponseMap map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(firstDataPayload), &firstResponseMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
firstChoices := firstResponseMap["choices"].([]interface{})
|
||||||
|
require.Len(t, firstChoices, 1, "text chunk should contain one choice")
|
||||||
|
|
||||||
|
usageChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"inputTokens": 10,
|
||||||
|
"outputTokens": 2,
|
||||||
|
"totalTokens": 12,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
action = host.CallOnHttpStreamingResponseBody(usageChunk, true)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
secondResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, secondResponseBody)
|
||||||
|
require.Contains(t, string(secondResponseBody), "data: [DONE]", "last chunk should append [DONE]")
|
||||||
|
secondDataPayload := extractFirstDataPayload(secondResponseBody)
|
||||||
|
require.NotEmpty(t, secondDataPayload, "usage chunk should contain one SSE data payload")
|
||||||
|
|
||||||
|
var secondResponseMap map[string]interface{}
|
||||||
|
err = json.Unmarshal([]byte(secondDataPayload), &secondResponseMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
secondChoices := secondResponseMap["choices"].([]interface{})
|
||||||
|
require.Len(t, secondChoices, 0, "usage chunk should contain empty choices by design")
|
||||||
|
_, hasUsage := secondResponseMap["usage"]
|
||||||
|
require.True(t, hasUsage, "usage chunk should include usage field")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bedrock empty intermediate callback should not affect next usage event", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/chat/completions"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
requestBody := `{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": true
|
||||||
|
}`
|
||||||
|
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/vnd.amazon.eventstream"},
|
||||||
|
})
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
action = host.CallOnHttpStreamingResponseBody([]byte{}, false)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
emptyResponseBody := host.GetResponseBody()
|
||||||
|
require.Equal(t, 0, len(emptyResponseBody), "empty intermediate callback should output empty payload")
|
||||||
|
|
||||||
|
usageChunk := buildBedrockEventStreamMessage(t, map[string]interface{}{
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"inputTokens": 10,
|
||||||
|
"outputTokens": 2,
|
||||||
|
"totalTokens": 12,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
action = host.CallOnHttpStreamingResponseBody(usageChunk, true)
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
finalResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, finalResponseBody)
|
||||||
|
require.Contains(t, string(finalResponseBody), "data: [DONE]", "last chunk should append [DONE]")
|
||||||
|
finalDataPayload := extractFirstDataPayload(finalResponseBody)
|
||||||
|
require.NotEmpty(t, finalDataPayload, "final usage event should still be parsed")
|
||||||
|
|
||||||
|
var finalResponseMap map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(finalDataPayload), &finalResponseMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
finalChoices := finalResponseMap["choices"].([]interface{})
|
||||||
|
require.Len(t, finalChoices, 0, "usage chunk should still keep empty choices")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user