mirror of
https://github.com/alibaba/higress.git
synced 2026-03-21 19:47:52 +08:00
Compare commits
17 Commits
add-releas
...
feat/ai-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
179a233ad6 | ||
|
|
bdfe9950ce | ||
|
|
045238944d | ||
|
|
62df71aadf | ||
|
|
8961db2e90 | ||
|
|
94f0d7179f | ||
|
|
f1e305844e | ||
|
|
68d6090e36 | ||
|
|
65aba909d7 | ||
|
|
528e6c9908 | ||
|
|
13b808c1e4 | ||
|
|
aa502e7e62 | ||
|
|
2e3f6868df | ||
|
|
6c9747d778 | ||
|
|
c12183cae5 | ||
|
|
e2a22d1171 | ||
|
|
e9aecb6e1f |
7
.github/workflows/wasm-plugin-unit-test.yml
vendored
7
.github/workflows/wasm-plugin-unit-test.yml
vendored
@@ -199,15 +199,14 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go 1.24
|
||||
- name: Set up Go 1.25
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.24
|
||||
go-version: 1.25
|
||||
cache: true
|
||||
|
||||
|
||||
- name: Install required tools
|
||||
run: |
|
||||
go install github.com/wadey/gocovmerge@latest
|
||||
sudo apt-get update && sudo apt-get install -y bc
|
||||
|
||||
- name: Download all test results
|
||||
|
||||
@@ -4,6 +4,6 @@ dependencies:
|
||||
version: 2.2.0
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 2.2.0
|
||||
digest: sha256:2cb148fa6d52856344e1905d3fea018466c2feb52013e08997c2d5c7d50f2e5d
|
||||
generated: "2026-02-11T17:45:59.187965929+08:00"
|
||||
version: 2.2.1
|
||||
digest: sha256:23fe7b0f84965c13ac7ceabe6334212fc3d323b7b781277a6d2b6fd38e935dda
|
||||
generated: "2026-03-07T12:45:44.267732+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.2.0
|
||||
appVersion: 2.2.1
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -15,6 +15,6 @@ dependencies:
|
||||
version: 2.2.0
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
type: application
|
||||
version: 2.2.0
|
||||
version: 2.2.1
|
||||
|
||||
@@ -17,7 +17,6 @@ package translation
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"istio.io/istio/pilot/pkg/model"
|
||||
istiomodel "istio.io/istio/pilot/pkg/model"
|
||||
"istio.io/istio/pkg/config"
|
||||
"istio.io/istio/pkg/config/schema/collection"
|
||||
@@ -40,8 +39,8 @@ type IngressTranslation struct {
|
||||
ingressConfig *ingressconfig.IngressConfig
|
||||
kingressConfig *ingressconfig.KIngressConfig
|
||||
mutex sync.RWMutex
|
||||
higressRouteCache model.IngressRouteCollection
|
||||
higressDomainCache model.IngressDomainCollection
|
||||
higressRouteCache istiomodel.IngressRouteCollection
|
||||
higressDomainCache istiomodel.IngressDomainCollection
|
||||
}
|
||||
|
||||
func NewIngressTranslation(localKubeClient kube.Client, xdsUpdater istiomodel.XDSUpdater, namespace string, options common.Options) *IngressTranslation {
|
||||
@@ -109,11 +108,11 @@ func (m *IngressTranslation) SetWatchErrorHandler(f func(r *cache.Reflector, err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *IngressTranslation) GetIngressRoutes() model.IngressRouteCollection {
|
||||
func (m *IngressTranslation) GetIngressRoutes() istiomodel.IngressRouteCollection {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
ingressRouteCache := m.ingressConfig.GetIngressRoutes()
|
||||
m.higressRouteCache = model.IngressRouteCollection{}
|
||||
m.higressRouteCache = istiomodel.IngressRouteCollection{}
|
||||
m.higressRouteCache.Invalid = append(m.higressRouteCache.Invalid, ingressRouteCache.Invalid...)
|
||||
m.higressRouteCache.Valid = append(m.higressRouteCache.Valid, ingressRouteCache.Valid...)
|
||||
if m.kingressConfig != nil {
|
||||
@@ -125,12 +124,12 @@ func (m *IngressTranslation) GetIngressRoutes() model.IngressRouteCollection {
|
||||
return m.higressRouteCache
|
||||
}
|
||||
|
||||
func (m *IngressTranslation) GetIngressDomains() model.IngressDomainCollection {
|
||||
func (m *IngressTranslation) GetIngressDomains() istiomodel.IngressDomainCollection {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
ingressDomainCache := m.ingressConfig.GetIngressDomains()
|
||||
|
||||
m.higressDomainCache = model.IngressDomainCollection{}
|
||||
m.higressDomainCache = istiomodel.IngressDomainCollection{}
|
||||
m.higressDomainCache.Invalid = append(m.higressDomainCache.Invalid, ingressDomainCache.Invalid...)
|
||||
m.higressDomainCache.Valid = append(m.higressDomainCache.Valid, ingressDomainCache.Valid...)
|
||||
if m.kingressConfig != nil {
|
||||
|
||||
@@ -140,10 +140,16 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
|
||||
// Send the initial endpoint event
|
||||
initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint)
|
||||
err = s.redisClient.Publish(channel, initialEvent)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to send initial event: %v", err)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
api.LogErrorf("Failed to send initial event: %v", r)
|
||||
}
|
||||
}()
|
||||
defer cb.EncoderFilterCallbacks().RecoverPanic()
|
||||
api.LogDebugf("SSE Send message: %s", initialEvent)
|
||||
cb.EncoderFilterCallbacks().InjectData([]byte(initialEvent))
|
||||
}()
|
||||
|
||||
// Start health check handler
|
||||
go func() {
|
||||
|
||||
@@ -52,6 +52,9 @@ var (
|
||||
{provider.PathOpenAICompletions, provider.ApiNameCompletion},
|
||||
{provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings},
|
||||
{provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech},
|
||||
{provider.PathOpenAIAudioTranscriptions, provider.ApiNameAudioTranscription},
|
||||
{provider.PathOpenAIAudioTranslations, provider.ApiNameAudioTranslation},
|
||||
{provider.PathOpenAIRealtime, provider.ApiNameRealtime},
|
||||
{provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration},
|
||||
{provider.PathOpenAIImageVariation, provider.ApiNameImageVariation},
|
||||
{provider.PathOpenAIImageEdit, provider.ApiNameImageEdit},
|
||||
@@ -225,9 +228,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
}
|
||||
}
|
||||
|
||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
|
||||
ctx.DontReadRequestBody()
|
||||
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
|
||||
log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
|
||||
}
|
||||
|
||||
if apiName == "" {
|
||||
@@ -306,6 +309,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
if err == nil {
|
||||
return action
|
||||
}
|
||||
log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
|
||||
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue
|
||||
@@ -381,6 +385,8 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
return chunk
|
||||
}
|
||||
|
||||
promoteThinking := pluginConfig.GetProviderConfig().GetPromoteThinkingOnEmpty()
|
||||
|
||||
log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
|
||||
log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
||||
|
||||
@@ -388,6 +394,9 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
|
||||
if err == nil && modifiedChunk != nil {
|
||||
if promoteThinking {
|
||||
modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk, isLastChunk)
|
||||
}
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk)
|
||||
if convertErr != nil {
|
||||
@@ -431,6 +440,10 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
|
||||
result := []byte(responseBuilder.String())
|
||||
|
||||
if promoteThinking {
|
||||
result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk)
|
||||
}
|
||||
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
|
||||
if convertErr != nil {
|
||||
@@ -439,11 +452,12 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
return claudeChunk
|
||||
}
|
||||
|
||||
if !needsClaudeResponseConversion(ctx) {
|
||||
if !needsClaudeResponseConversion(ctx) && !promoteThinking {
|
||||
return chunk
|
||||
}
|
||||
|
||||
// If provider doesn't implement any streaming handlers but we need Claude conversion
|
||||
// or thinking promotion
|
||||
// First extract complete events from the chunk
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk)
|
||||
log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events))
|
||||
@@ -460,6 +474,10 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
|
||||
result := []byte(responseBuilder.String())
|
||||
|
||||
if promoteThinking {
|
||||
result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk)
|
||||
}
|
||||
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
|
||||
if convertErr != nil {
|
||||
@@ -492,6 +510,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
finalBody = body
|
||||
}
|
||||
|
||||
// Promote thinking/reasoning to content when content is empty
|
||||
if pluginConfig.GetProviderConfig().GetPromoteThinkingOnEmpty() {
|
||||
promoted, err := provider.PromoteThinkingOnEmptyResponse(finalBody)
|
||||
if err != nil {
|
||||
log.Warnf("[promoteThinkingOnEmpty] failed: %v", err)
|
||||
} else {
|
||||
finalBody = promoted
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to Claude format if needed (applies to both branches)
|
||||
convertedBody, err := convertResponseBodyToClaude(ctx, finalBody)
|
||||
if err != nil {
|
||||
@@ -540,6 +568,49 @@ func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]b
|
||||
return claudeChunk, nil
|
||||
}
|
||||
|
||||
// promoteThinkingInStreamingChunk processes SSE-formatted streaming data, buffering
|
||||
// reasoning deltas and stripping them from chunks. On the last chunk, if no content
|
||||
// was ever seen, it appends a flush chunk that emits buffered reasoning as content.
|
||||
func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte, isLastChunk bool) []byte {
|
||||
// SSE data contains lines like "data: {...}\n\n"
|
||||
// We need to find and process each data line
|
||||
lines := strings.Split(string(data), "\n")
|
||||
modified := false
|
||||
for i, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimPrefix(line, "data: ")
|
||||
if payload == "[DONE]" || payload == "" {
|
||||
continue
|
||||
}
|
||||
stripped, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
newLine := "data: " + string(stripped)
|
||||
if newLine != line {
|
||||
lines[i] = newLine
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
result := data
|
||||
if modified {
|
||||
result = []byte(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
// On last chunk, flush buffered reasoning as content if no content was seen
|
||||
if isLastChunk {
|
||||
flushChunk := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
if flushChunk != nil {
|
||||
result = append(flushChunk, result...)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper function to convert OpenAI response body to Claude format
|
||||
func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
if !needsClaudeResponseConversion(ctx) {
|
||||
@@ -594,3 +665,14 @@ func getApiName(path string) provider.ApiName {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
|
||||
if strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||
return true
|
||||
}
|
||||
contentType = strings.ToLower(contentType)
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -18,6 +18,12 @@ func Test_getApiName(t *testing.T) {
|
||||
{"openai completions", "/v1/completions", provider.ApiNameCompletion},
|
||||
{"openai embeddings", "/v1/embeddings", provider.ApiNameEmbeddings},
|
||||
{"openai audio speech", "/v1/audio/speech", provider.ApiNameAudioSpeech},
|
||||
{"openai audio transcriptions", "/v1/audio/transcriptions", provider.ApiNameAudioTranscription},
|
||||
{"openai audio transcriptions with prefix", "/proxy/v1/audio/transcriptions", provider.ApiNameAudioTranscription},
|
||||
{"openai audio translations", "/v1/audio/translations", provider.ApiNameAudioTranslation},
|
||||
{"openai realtime", "/v1/realtime", provider.ApiNameRealtime},
|
||||
{"openai realtime with prefix", "/proxy/v1/realtime", provider.ApiNameRealtime},
|
||||
{"openai realtime with trailing slash", "/v1/realtime/", ""},
|
||||
{"openai image generation", "/v1/images/generations", provider.ApiNameImageGeneration},
|
||||
{"openai image variation", "/v1/images/variations", provider.ApiNameImageVariation},
|
||||
{"openai image edit", "/v1/images/edits", provider.ApiNameImageEdit},
|
||||
@@ -63,6 +69,54 @@ func Test_getApiName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSupportedRequestContentType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiName provider.ApiName
|
||||
contentType string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "json chat completion",
|
||||
apiName: provider.ApiNameChatCompletion,
|
||||
contentType: "application/json",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multipart image edit",
|
||||
apiName: provider.ApiNameImageEdit,
|
||||
contentType: "multipart/form-data; boundary=----boundary",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multipart image variation",
|
||||
apiName: provider.ApiNameImageVariation,
|
||||
contentType: "multipart/form-data; boundary=----boundary",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multipart chat completion",
|
||||
apiName: provider.ApiNameChatCompletion,
|
||||
contentType: "multipart/form-data; boundary=----boundary",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "text plain image edit",
|
||||
apiName: provider.ApiNameImageEdit,
|
||||
contentType: "text/plain",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSupportedRequestContentType(tt.apiName, tt.contentType)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAi360(t *testing.T) {
|
||||
test.RunAi360ParseConfigTests(t)
|
||||
test.RunAi360OnHttpRequestHeadersTests(t)
|
||||
@@ -79,6 +133,8 @@ func TestOpenAI(t *testing.T) {
|
||||
test.RunOpenAIOnHttpResponseHeadersTests(t)
|
||||
test.RunOpenAIOnHttpResponseBodyTests(t)
|
||||
test.RunOpenAIOnStreamingResponseBodyTests(t)
|
||||
test.RunOpenAIPromoteThinkingOnEmptyTests(t)
|
||||
test.RunOpenAIPromoteThinkingOnEmptyStreamingTests(t)
|
||||
}
|
||||
|
||||
func TestQwen(t *testing.T) {
|
||||
@@ -123,6 +179,10 @@ func TestUtil(t *testing.T) {
|
||||
test.RunMapRequestPathByCapabilityTests(t)
|
||||
}
|
||||
|
||||
func TestApiPathRegression(t *testing.T) {
|
||||
test.RunApiPathRegressionTests(t)
|
||||
}
|
||||
|
||||
func TestGeneric(t *testing.T) {
|
||||
test.RunGenericParseConfigTests(t)
|
||||
test.RunGenericOnHttpRequestHeadersTests(t)
|
||||
@@ -137,6 +197,8 @@ func TestVertex(t *testing.T) {
|
||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
||||
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
||||
test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
|
||||
test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
|
||||
// Vertex Raw 模式测试
|
||||
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
||||
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
||||
@@ -149,6 +211,7 @@ func TestBedrock(t *testing.T) {
|
||||
test.RunBedrockOnHttpRequestBodyTests(t)
|
||||
test.RunBedrockOnHttpResponseHeadersTests(t)
|
||||
test.RunBedrockOnHttpResponseBodyTests(t)
|
||||
test.RunBedrockOnStreamingResponseBodyTests(t)
|
||||
test.RunBedrockToolCallTests(t)
|
||||
}
|
||||
|
||||
|
||||
@@ -35,9 +35,23 @@ const (
|
||||
// converseStream路径 /model/{modelId}/converse-stream
|
||||
bedrockStreamChatCompletionPath = "/model/%s/converse-stream"
|
||||
// invoke_model 路径 /model/{modelId}/invoke
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
bedrockInvokeModelPath = "/model/%s/invoke"
|
||||
bedrockSignedHeaders = "host;x-amz-date"
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
bedrockCacheTypeDefault = "default"
|
||||
bedrockCacheTTL5m = "5m"
|
||||
bedrockCacheTTL1h = "1h"
|
||||
bedrockPromptCacheNova = "amazon.nova"
|
||||
bedrockPromptCacheClaude = "anthropic.claude"
|
||||
|
||||
bedrockCachePointPositionSystemPrompt = "systemPrompt"
|
||||
bedrockCachePointPositionLastUserMessage = "lastUserMessage"
|
||||
bedrockCachePointPositionLastMessage = "lastMessage"
|
||||
)
|
||||
|
||||
var (
|
||||
bedrockConversePathPattern = regexp.MustCompile(`/model/[^/]+/converse(-stream)?$`)
|
||||
bedrockInvokePathPattern = regexp.MustCompile(`/model/[^/]+/invoke(-with-response-stream)?$`)
|
||||
)
|
||||
|
||||
type bedrockProviderInitializer struct{}
|
||||
@@ -164,9 +178,10 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
if bedrockEvent.Usage != nil {
|
||||
openAIFormattedChunk.Choices = choices[:0]
|
||||
openAIFormattedChunk.Usage = &usage{
|
||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockEvent.Usage.CacheReadInputTokens, bedrockEvent.Usage.CacheWriteInputTokens),
|
||||
}
|
||||
}
|
||||
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
||||
@@ -630,13 +645,24 @@ func (b *bedrockProvider) GetProviderType() string {
|
||||
return providerTypeBedrock
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) GetApiName(path string) ApiName {
|
||||
switch {
|
||||
case bedrockConversePathPattern.MatchString(path):
|
||||
return ApiNameChatCompletion
|
||||
case bedrockInvokePathPattern.MatchString(path):
|
||||
return ApiNameImageGeneration
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
b.config.handleRequestHeaders(b, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
|
||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, strings.TrimSpace(b.config.awsRegion)))
|
||||
|
||||
// If apiTokens is configured, set Bearer token authentication here
|
||||
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
|
||||
@@ -647,6 +673,15 @@ func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
// In original protocol mode (e.g. /model/{modelId}/converse-stream), keep the body/path untouched
|
||||
// and only apply auth headers.
|
||||
if b.config.IsOriginal() {
|
||||
headers := util.GetRequestHeaders()
|
||||
b.setAuthHeaders(body, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
|
||||
if !b.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
@@ -654,14 +689,25 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
var transformedBody []byte
|
||||
var err error
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
transformedBody, err = b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return b.onImageGenerationRequestBody(ctx, body, headers)
|
||||
transformedBody, err = b.onImageGenerationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
transformedBody, err = b.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Always apply auth after request body/path are finalized.
|
||||
// For Bearer token mode this is a no-op; for AK/SK mode this generates SigV4 headers.
|
||||
b.setAuthHeaders(transformedBody, headers)
|
||||
return transformedBody, nil
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
@@ -715,9 +761,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
|
||||
Quality: origRequest.Quality,
|
||||
},
|
||||
}
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
@@ -797,6 +841,19 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
},
|
||||
}
|
||||
|
||||
effectivePromptCacheRetention := b.resolvePromptCacheRetention(origRequest.PromptCacheRetention)
|
||||
|
||||
if origRequest.PromptCacheKey != "" {
|
||||
log.Warnf("bedrock provider ignores prompt_cache_key because Converse API has no equivalent field")
|
||||
}
|
||||
if isPromptCacheSupportedModel(origRequest.Model) {
|
||||
if cacheTTL, ok := mapPromptCacheRetentionToBedrockTTL(effectivePromptCacheRetention); ok {
|
||||
addPromptCachePointsToBedrockRequest(request, cacheTTL, b.getPromptCachePointPositions())
|
||||
}
|
||||
} else if effectivePromptCacheRetention != "" {
|
||||
log.Warnf("skip prompt cache injection for unsupported model: %s", origRequest.Model)
|
||||
}
|
||||
|
||||
if origRequest.ReasoningEffort != "" {
|
||||
thinkingBudget := 1024 // default
|
||||
switch origRequest.ReasoningEffort {
|
||||
@@ -847,9 +904,7 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
request.AdditionalModelRequestFields[key] = value
|
||||
}
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockConverseResponse) *chatCompletionResponse {
|
||||
@@ -900,9 +955,10 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: &usage{
|
||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||
PromptTokensDetails: buildPromptTokensDetails(bedrockResponse.Usage.CacheReadInputTokens, bedrockResponse.Usage.CacheWriteInputTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -933,6 +989,145 @@ func stopReasonBedrock2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func mapPromptCacheRetentionToBedrockTTL(retention string) (string, bool) {
|
||||
normalizedRetention := normalizePromptCacheRetention(retention)
|
||||
switch normalizedRetention {
|
||||
case "":
|
||||
return "", false
|
||||
case "in_memory":
|
||||
// For the default 5-minute cache, omit ttl and let Bedrock apply its default.
|
||||
// This is more robust for models that are strict about explicit ttl fields.
|
||||
return "", true
|
||||
case "24h":
|
||||
return bedrockCacheTTL1h, true
|
||||
default:
|
||||
log.Warnf("unsupported prompt_cache_retention for bedrock mapping: %s", retention)
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePromptCacheRetention(retention string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(retention))
|
||||
normalized = strings.ReplaceAll(normalized, "-", "_")
|
||||
normalized = strings.ReplaceAll(normalized, " ", "_")
|
||||
if normalized == "inmemory" {
|
||||
return "in_memory"
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isPromptCacheSupportedModel(model string) bool {
|
||||
normalizedModel := strings.ToLower(strings.TrimSpace(model))
|
||||
return strings.Contains(normalizedModel, bedrockPromptCacheNova) ||
|
||||
strings.Contains(normalizedModel, bedrockPromptCacheClaude)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) resolvePromptCacheRetention(requestPromptCacheRetention string) string {
|
||||
if requestPromptCacheRetention != "" {
|
||||
return requestPromptCacheRetention
|
||||
}
|
||||
if b.config.promptCacheRetention != "" {
|
||||
return b.config.promptCacheRetention
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) getPromptCachePointPositions() map[string]bool {
|
||||
if b.config.bedrockPromptCachePointPositions == nil {
|
||||
return map[string]bool{
|
||||
bedrockCachePointPositionSystemPrompt: true,
|
||||
bedrockCachePointPositionLastMessage: false,
|
||||
}
|
||||
}
|
||||
positions := map[string]bool{
|
||||
bedrockCachePointPositionSystemPrompt: false,
|
||||
bedrockCachePointPositionLastUserMessage: false,
|
||||
bedrockCachePointPositionLastMessage: false,
|
||||
}
|
||||
for rawKey, enabled := range b.config.bedrockPromptCachePointPositions {
|
||||
key := normalizeBedrockCachePointPosition(rawKey)
|
||||
switch key {
|
||||
case bedrockCachePointPositionSystemPrompt, bedrockCachePointPositionLastUserMessage, bedrockCachePointPositionLastMessage:
|
||||
positions[key] = enabled
|
||||
default:
|
||||
log.Warnf("unsupported bedrockPromptCachePointPositions key: %s", rawKey)
|
||||
}
|
||||
}
|
||||
return positions
|
||||
}
|
||||
|
||||
func normalizeBedrockCachePointPosition(raw string) string {
|
||||
key := strings.ToLower(raw)
|
||||
key = strings.ReplaceAll(key, "_", "")
|
||||
key = strings.ReplaceAll(key, "-", "")
|
||||
switch key {
|
||||
case "systemprompt":
|
||||
return bedrockCachePointPositionSystemPrompt
|
||||
case "lastusermessage":
|
||||
return bedrockCachePointPositionLastUserMessage
|
||||
case "lastmessage":
|
||||
return bedrockCachePointPositionLastMessage
|
||||
default:
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
func addPromptCachePointsToBedrockRequest(request *bedrockTextGenRequest, cacheTTL string, positions map[string]bool) {
|
||||
if positions[bedrockCachePointPositionSystemPrompt] && len(request.System) > 0 {
|
||||
request.System = append(request.System, systemContentBlock{
|
||||
CachePoint: &bedrockCachePoint{
|
||||
Type: bedrockCacheTypeDefault,
|
||||
TTL: cacheTTL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
lastUserMessageIndex := -1
|
||||
if positions[bedrockCachePointPositionLastUserMessage] {
|
||||
lastUserMessageIndex = findLastMessageIndexByRole(request.Messages, roleUser)
|
||||
if lastUserMessageIndex >= 0 {
|
||||
appendCachePointToBedrockMessage(request, lastUserMessageIndex, cacheTTL)
|
||||
}
|
||||
}
|
||||
if positions[bedrockCachePointPositionLastMessage] && len(request.Messages) > 0 {
|
||||
lastMessageIndex := len(request.Messages) - 1
|
||||
if lastMessageIndex != lastUserMessageIndex {
|
||||
appendCachePointToBedrockMessage(request, lastMessageIndex, cacheTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findLastMessageIndexByRole(messages []bedrockMessage, role string) int {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == role {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func appendCachePointToBedrockMessage(request *bedrockTextGenRequest, messageIndex int, cacheTTL string) {
|
||||
if messageIndex < 0 || messageIndex >= len(request.Messages) {
|
||||
return
|
||||
}
|
||||
request.Messages[messageIndex].Content = append(request.Messages[messageIndex].Content, bedrockMessageContent{
|
||||
CachePoint: &bedrockCachePoint{
|
||||
Type: bedrockCacheTypeDefault,
|
||||
TTL: cacheTTL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildPromptTokensDetails(cacheReadInputTokens int, cacheWriteInputTokens int) *promptTokensDetails {
|
||||
totalCachedTokens := cacheReadInputTokens + cacheWriteInputTokens
|
||||
if totalCachedTokens <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &promptTokensDetails{
|
||||
CachedTokens: totalCachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
type bedrockTextGenRequest struct {
|
||||
Messages []bedrockMessage `json:"messages"`
|
||||
System []systemContentBlock `json:"system,omitempty"`
|
||||
@@ -977,14 +1172,21 @@ type bedrockMessage struct {
|
||||
}
|
||||
|
||||
type bedrockMessageContent struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
|
||||
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
|
||||
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
|
||||
CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"`
|
||||
}
|
||||
|
||||
type systemContentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
CachePoint *bedrockCachePoint `json:"cachePoint,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockCachePoint struct {
|
||||
Type string `json:"type"`
|
||||
TTL string `json:"ttl,omitempty"`
|
||||
}
|
||||
|
||||
type imageBlock struct {
|
||||
@@ -1066,6 +1268,10 @@ type tokenUsage struct {
|
||||
OutputTokens int `json:"outputTokens,omitempty"`
|
||||
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
|
||||
CacheReadInputTokens int `json:"cacheReadInputTokens,omitempty"`
|
||||
|
||||
CacheWriteInputTokens int `json:"cacheWriteInputTokens,omitempty"`
|
||||
}
|
||||
|
||||
func chatToolMessage2BedrockToolResultContent(chatMessage chatMessage) bedrockMessageContent {
|
||||
@@ -1163,35 +1369,45 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
}
|
||||
|
||||
// Use AWS Signature V4 authentication
|
||||
accessKey := strings.TrimSpace(b.config.awsAccessKey)
|
||||
region := strings.TrimSpace(b.config.awsRegion)
|
||||
t := time.Now().UTC()
|
||||
amzDate := t.Format("20060102T150405Z")
|
||||
dateStamp := t.Format("20060102")
|
||||
path := headers.Get(":path")
|
||||
signature := b.generateSignature(path, amzDate, dateStamp, body)
|
||||
headers.Set("X-Amz-Date", amzDate)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
|
||||
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", accessKey, dateStamp, region, awsService, bedrockSignedHeaders, signature))
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
|
||||
path = encodeSigV4Path(path)
|
||||
canonicalURI := encodeSigV4Path(path)
|
||||
hashedPayload := sha256Hex(body)
|
||||
region := strings.TrimSpace(b.config.awsRegion)
|
||||
secretKey := strings.TrimSpace(b.config.awsSecretKey)
|
||||
|
||||
endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
|
||||
endpoint := fmt.Sprintf(bedrockDefaultDomain, region)
|
||||
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
|
||||
httpPostMethod, path, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
|
||||
httpPostMethod, canonicalURI, canonicalHeaders, bedrockSignedHeaders, hashedPayload)
|
||||
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, b.config.awsRegion, awsService)
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, awsService)
|
||||
hashedCanonReq := sha256Hex([]byte(canonicalRequest))
|
||||
stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
|
||||
amzDate, credentialScope, hashedCanonReq)
|
||||
|
||||
signingKey := getSignatureKey(b.config.awsSecretKey, dateStamp, b.config.awsRegion, awsService)
|
||||
signingKey := getSignatureKey(secretKey, dateStamp, region, awsService)
|
||||
signature := hmacHex(signingKey, stringToSign)
|
||||
return signature
|
||||
}
|
||||
|
||||
func encodeSigV4Path(path string) string {
|
||||
// Keep only the URI path for canonical URI. Query string is handled separately in SigV4,
|
||||
// and this implementation uses an empty canonical query string.
|
||||
if queryIndex := strings.Index(path, "?"); queryIndex >= 0 {
|
||||
path = path[:queryIndex]
|
||||
}
|
||||
|
||||
segments := strings.Split(path, "/")
|
||||
for i, seg := range segments {
|
||||
if seg == "" {
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestEncodeSigV4Path verifies the canonical-URI encoding used during SigV4
// signing: raw model ids/ARNs keep their ':' and '/' delimiters, already
// percent-encoded paths have their '%' re-escaped (so the signed URI matches
// what the upstream actually receives), and any query string is stripped.
func TestEncodeSigV4Path(t *testing.T) {
	tests := []struct {
		name string
		path string
		want string
	}{
		{
			name: "raw model id keeps colon",
			path: "/model/global.amazon.nova-2-lite-v1:0/converse-stream",
			want: "/model/global.amazon.nova-2-lite-v1:0/converse-stream",
		},
		{
			name: "pre-encoded model id escapes percent to avoid mismatch",
			path: "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream",
			want: "/model/global.amazon.nova-2-lite-v1%253A0/converse-stream",
		},
		{
			name: "raw inference profile arn keeps colon and slash delimiters",
			path: "/model/arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.anthropic.claude-sonnet-4-20250514-v1:0/converse",
			want: "/model/arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.anthropic.claude-sonnet-4-20250514-v1:0/converse",
		},
		{
			name: "encoded inference profile arn preserves escaped slash as double-escaped percent",
			path: "/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Ainference-profile%2Fglobal.anthropic.claude-sonnet-4-20250514-v1%3A0/converse",
			want: "/model/arn%253Aaws%253Abedrock%253Aus-east-1%253A123456789012%253Ainference-profile%252Fglobal.anthropic.claude-sonnet-4-20250514-v1%253A0/converse",
		},
		{
			name: "query string is stripped before canonical encoding",
			path: "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream?trace=1&foo=bar",
			want: "/model/global.amazon.nova-2-lite-v1%253A0/converse-stream",
		},
		{
			name: "invalid percent sequence falls back to escaped percent",
			path: "/model/abc%ZZxyz/converse",
			want: "/model/abc%25ZZxyz/converse",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, encodeSigV4Path(tt.path))
		})
	}
}
|
||||
|
||||
// TestOverwriteRequestPathHeaderPreservesSingleEncodedRequestPath checks that
// the :path header always ends up exactly single-encoded, whether the model id
// arrives raw or already URL-escaped (no double encoding).
func TestOverwriteRequestPathHeaderPreservesSingleEncodedRequestPath(t *testing.T) {
	p := &bedrockProvider{}
	plainModel := "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0"
	preEncodedModel := url.QueryEscape(plainModel)

	t.Run("plain model is encoded once", func(t *testing.T) {
		headers := http.Header{}
		p.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, plainModel)
		assert.Equal(t, "/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Ainference-profile%2Fglobal.amazon.nova-2-lite-v1%3A0/converse", headers.Get(":path"))
	})

	t.Run("pre-encoded model is not double encoded", func(t *testing.T) {
		headers := http.Header{}
		p.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, preEncodedModel)
		assert.Equal(t, "/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Ainference-profile%2Fglobal.amazon.nova-2-lite-v1%3A0/converse", headers.Get(":path"))
	})
}
|
||||
|
||||
// TestGenerateSignatureIgnoresQueryStringInCanonicalURI asserts that the
// query string does not participate in the canonical URI: the same path with
// and without a query must yield an identical signature.
func TestGenerateSignatureIgnoresQueryStringInCanonicalURI(t *testing.T) {
	p := &bedrockProvider{
		config: ProviderConfig{
			awsRegion:    "ap-northeast-3",
			awsSecretKey: "test-secret",
		},
	}
	body := []byte(`{"messages":[{"role":"user","content":[{"text":"hello"}]}]}`)
	pathWithoutQuery := "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream"
	pathWithQuery := pathWithoutQuery + "?trace=1&foo=bar"

	sigWithoutQuery := p.generateSignature(pathWithoutQuery, "20260312T142942Z", "20260312", body)
	sigWithQuery := p.generateSignature(pathWithQuery, "20260312T142942Z", "20260312", body)
	assert.Equal(t, sigWithoutQuery, sigWithQuery)
}
|
||||
|
||||
// TestGenerateSignatureDiffersForRawAndPreEncodedModelPath asserts that a raw
// path and its percent-encoded form sign differently — encodeSigV4Path must
// not collapse them into the same canonical URI.
func TestGenerateSignatureDiffersForRawAndPreEncodedModelPath(t *testing.T) {
	p := &bedrockProvider{
		config: ProviderConfig{
			awsRegion:    "ap-northeast-3",
			awsSecretKey: "test-secret",
		},
	}
	body := []byte(`{"messages":[{"role":"user","content":[{"text":"hello"}]}]}`)
	rawPath := "/model/global.amazon.nova-2-lite-v1:0/converse-stream"
	preEncodedPath := "/model/global.amazon.nova-2-lite-v1%3A0/converse-stream"

	rawSignature := p.generateSignature(rawPath, "20260312T142942Z", "20260312", body)
	preEncodedSignature := p.generateSignature(preEncodedPath, "20260312T142942Z", "20260312", body)
	assert.NotEqual(t, rawSignature, preEncodedSignature)
}
|
||||
|
||||
// TestNormalizePromptCacheRetention verifies that the various user spellings
// of the in-memory retention policy ("inmemory", "in-memory", " in memory ")
// all normalize to the canonical "in_memory" value.
func TestNormalizePromptCacheRetention(t *testing.T) {
	tests := []struct {
		name      string
		retention string
		want      string
	}{
		{
			name:      "inmemory alias maps to in_memory",
			retention: "inmemory",
			want:      "in_memory",
		},
		{
			name:      "dash style maps to in_memory",
			retention: "in-memory",
			want:      "in_memory",
		},
		{
			name:      "space style with trim maps to in_memory",
			retention: " in memory ",
			want:      "in_memory",
		},
		{
			name:      "already normalized remains unchanged",
			retention: "in_memory",
			want:      "in_memory",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, normalizePromptCacheRetention(tt.retention))
		})
	}
}
|
||||
|
||||
// TestAppendCachePointToBedrockMessageInvalidIndexNoop checks that
// out-of-range indexes (negative, or equal to len(Messages)) leave the
// request untouched, while a valid index appends a cachePoint content block.
func TestAppendCachePointToBedrockMessageInvalidIndexNoop(t *testing.T) {
	request := &bedrockTextGenRequest{
		Messages: []bedrockMessage{
			{
				Role: roleUser,
				Content: []bedrockMessageContent{
					{Text: "hello"},
				},
			},
		},
	}

	// Both calls are out of range and must be no-ops.
	appendCachePointToBedrockMessage(request, -1, bedrockCacheTTL5m)
	appendCachePointToBedrockMessage(request, len(request.Messages), bedrockCacheTTL5m)

	assert.Len(t, request.Messages[0].Content, 1)

	// Valid index: a cachePoint block is appended after the text block.
	appendCachePointToBedrockMessage(request, 0, bedrockCacheTTL5m)
	assert.Len(t, request.Messages[0].Content, 2)
	assert.NotNil(t, request.Messages[0].Content[1].CachePoint)
}
|
||||
|
||||
// TestIsPromptCacheSupportedModel covers the model-family gate for prompt
// caching: Anthropic Claude ids and Amazon Nova inference profiles qualify;
// other families (e.g. Llama) do not.
func TestIsPromptCacheSupportedModel(t *testing.T) {
	tests := []struct {
		name  string
		model string
		want  bool
	}{
		{
			name:  "anthropic claude model is supported",
			model: "anthropic.claude-3-5-haiku-20241022-v1:0",
			want:  true,
		},
		{
			name:  "amazon nova inference profile is supported",
			model: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/global.amazon.nova-2-lite-v1:0",
			want:  true,
		},
		{
			name:  "other model is not supported",
			model: "meta.llama3-70b-instruct-v1:0",
			want:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, isPromptCacheSupportedModel(tt.model))
		})
	}
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -41,34 +42,36 @@ type thinkingParam struct {
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
NonOpenAIStyleOptions
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Prediction map[string]interface{} `json:"prediction,omitempty"`
|
||||
Audio map[string]interface{} `json:"audio,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Prediction map[string]interface{} `json:"prediction,omitempty"`
|
||||
Audio map[string]interface{} `json:"audio,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
PromptCacheRetention string `json:"prompt_cache_retention,omitempty"`
|
||||
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getMaxTokens() int {
|
||||
@@ -252,6 +255,70 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r
|
||||
}
|
||||
}
|
||||
|
||||
// promoteThinkingOnEmpty promotes reasoning_content to content when content is empty.
|
||||
// This handles models that put user-facing replies into thinking blocks instead of text blocks.
|
||||
func (r *chatCompletionResponse) promoteThinkingOnEmpty() {
|
||||
for i := range r.Choices {
|
||||
msg := r.Choices[i].Message
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if !isContentEmpty(msg.Content) {
|
||||
continue
|
||||
}
|
||||
if msg.ReasoningContent != "" {
|
||||
msg.Content = msg.ReasoningContent
|
||||
msg.ReasoningContent = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// promoteStreamingThinkingOnEmpty accumulates reasoning content during streaming.
|
||||
// It strips reasoning from chunks and buffers it. When content is seen, it marks
|
||||
// the stream as having content so no promotion will happen.
|
||||
// Call PromoteStreamingThinkingFlush at the end of the stream to emit buffered
|
||||
// reasoning as content if no content was ever seen.
|
||||
// Returns true if the chunk was modified (reasoning stripped).
|
||||
func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool)
|
||||
if hasContentDelta {
|
||||
return false
|
||||
}
|
||||
|
||||
if !isContentEmpty(msg.Content) {
|
||||
ctx.SetContext(ctxKeyHasContentDelta, true)
|
||||
return false
|
||||
}
|
||||
|
||||
// Buffer reasoning content and strip it from the chunk
|
||||
reasoning := msg.ReasoningContent
|
||||
if reasoning == "" {
|
||||
reasoning = msg.Reasoning
|
||||
}
|
||||
if reasoning != "" {
|
||||
buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string)
|
||||
ctx.SetContext(ctxKeyBufferedReasoning, buffered+reasoning)
|
||||
msg.ReasoningContent = ""
|
||||
msg.Reasoning = ""
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isContentEmpty reports whether a chat message content value carries no
// text: nil, or a string that is blank after trimming whitespace. Structured
// content (arrays of parts, etc.) is never considered empty.
func isContentEmpty(content any) bool {
	if content == nil {
		return true
	}
	if s, ok := content.(string); ok {
		return strings.TrimSpace(s) == ""
	}
	return false
}
|
||||
|
||||
type chatMessageContent struct {
|
||||
CacheControl map[string]interface{} `json:"cache_control,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
@@ -461,6 +528,122 @@ type imageGenerationRequest struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
// imageInputURL accepts an image reference in either of the shapes clients
// send: a bare URL / data-URI string, or an object carrying a nested
// image_url. See UnmarshalJSON for the dual decoding and GetURL for the
// precedence between the two fields.
type imageInputURL struct {
	URL      string                      `json:"url,omitempty"`
	ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
}
|
||||
|
||||
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
|
||||
// Support a plain string payload, e.g. "data:image/png;base64,..."
|
||||
var rawURL string
|
||||
if err := json.Unmarshal(data, &rawURL); err == nil {
|
||||
i.URL = rawURL
|
||||
i.ImageURL = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type alias imageInputURL
|
||||
var value alias
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = imageInputURL(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imageInputURL) GetURL() string {
|
||||
if i == nil {
|
||||
return ""
|
||||
}
|
||||
if i.ImageURL != nil && i.ImageURL.Url != "" {
|
||||
return i.ImageURL.Url
|
||||
}
|
||||
return i.URL
|
||||
}
|
||||
|
||||
// imageEditRequest models an OpenAI-style /v1/images/edits JSON request.
// Image references may arrive via image, images, or image_url; masks via
// mask or mask_url — see GetImageURLs and HasMask for the accessors that
// reconcile the alternatives.
type imageEditRequest struct {
	Model             string          `json:"model"`
	Prompt            string          `json:"prompt"`
	Image             *imageInputURL  `json:"image,omitempty"`
	Images            []imageInputURL `json:"images,omitempty"`
	ImageURL          *imageInputURL  `json:"image_url,omitempty"`
	Mask              *imageInputURL  `json:"mask,omitempty"`
	MaskURL           *imageInputURL  `json:"mask_url,omitempty"`
	Background        string          `json:"background,omitempty"`
	Moderation        string          `json:"moderation,omitempty"`
	OutputCompression int             `json:"output_compression,omitempty"`
	OutputFormat      string          `json:"output_format,omitempty"`
	Quality           string          `json:"quality,omitempty"`
	ResponseFormat    string          `json:"response_format,omitempty"`
	Style             string          `json:"style,omitempty"`
	N                 int             `json:"n,omitempty"`
	Size              string          `json:"size,omitempty"`
}
|
||||
|
||||
func (r *imageEditRequest) GetImageURLs() []string {
|
||||
urls := make([]string, 0, len(r.Images)+2)
|
||||
for _, image := range r.Images {
|
||||
if url := image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.Image != nil {
|
||||
if url := r.Image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.ImageURL != nil {
|
||||
if url := r.ImageURL.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func (r *imageEditRequest) HasMask() bool {
|
||||
if r.Mask != nil && r.Mask.GetURL() != "" {
|
||||
return true
|
||||
}
|
||||
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
|
||||
}
|
||||
|
||||
// imageVariationRequest models an OpenAI-style /v1/images/variations JSON
// request. Like imageEditRequest, image references may arrive via image,
// images, or image_url; GetImageURLs reconciles them.
type imageVariationRequest struct {
	Model             string          `json:"model"`
	Prompt            string          `json:"prompt,omitempty"`
	Image             *imageInputURL  `json:"image,omitempty"`
	Images            []imageInputURL `json:"images,omitempty"`
	ImageURL          *imageInputURL  `json:"image_url,omitempty"`
	Background        string          `json:"background,omitempty"`
	Moderation        string          `json:"moderation,omitempty"`
	OutputCompression int             `json:"output_compression,omitempty"`
	OutputFormat      string          `json:"output_format,omitempty"`
	Quality           string          `json:"quality,omitempty"`
	ResponseFormat    string          `json:"response_format,omitempty"`
	Style             string          `json:"style,omitempty"`
	N                 int             `json:"n,omitempty"`
	Size              string          `json:"size,omitempty"`
}
|
||||
|
||||
func (r *imageVariationRequest) GetImageURLs() []string {
|
||||
urls := make([]string, 0, len(r.Images)+2)
|
||||
for _, image := range r.Images {
|
||||
if url := image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.Image != nil {
|
||||
if url := r.Image.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
if r.ImageURL != nil {
|
||||
if url := r.ImageURL.GetURL(); url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
type imageGenerationData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64 string `json:"b64_json,omitempty"`
|
||||
@@ -529,3 +712,87 @@ func (r embeddingsRequest) ParseInput() []string {
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// PromoteThinkingOnEmptyResponse promotes reasoning_content to content in a non-streaming
|
||||
// response body when content is empty. Returns the original body if no promotion is needed.
|
||||
func PromoteThinkingOnEmptyResponse(body []byte) ([]byte, error) {
|
||||
var resp chatCompletionResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body, fmt.Errorf("unable to unmarshal response for thinking promotion: %v", err)
|
||||
}
|
||||
promoted := false
|
||||
for i := range resp.Choices {
|
||||
msg := resp.Choices[i].Message
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if !isContentEmpty(msg.Content) {
|
||||
continue
|
||||
}
|
||||
if msg.ReasoningContent != "" {
|
||||
msg.Content = msg.ReasoningContent
|
||||
msg.ReasoningContent = ""
|
||||
promoted = true
|
||||
}
|
||||
}
|
||||
if !promoted {
|
||||
return body, nil
|
||||
}
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
// PromoteStreamingThinkingOnEmptyChunk buffers reasoning deltas and strips them from
|
||||
// the chunk during streaming. Call PromoteStreamingThinkingFlush on the last chunk
|
||||
// to emit buffered reasoning as content if no real content was ever seen.
|
||||
func PromoteStreamingThinkingOnEmptyChunk(ctx wrapper.HttpContext, data []byte) ([]byte, error) {
|
||||
var resp chatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return data, nil // not a valid chat completion chunk, skip
|
||||
}
|
||||
modified := false
|
||||
for i := range resp.Choices {
|
||||
msg := resp.Choices[i].Delta
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if promoteStreamingThinkingOnEmpty(ctx, msg) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !modified {
|
||||
return data, nil
|
||||
}
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
// PromoteStreamingThinkingFlush checks if the stream had no content and returns
|
||||
// an SSE chunk that emits the buffered reasoning as content. Returns nil if
|
||||
// content was already seen or no reasoning was buffered.
|
||||
func PromoteStreamingThinkingFlush(ctx wrapper.HttpContext) []byte {
|
||||
hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool)
|
||||
if hasContentDelta {
|
||||
return nil
|
||||
}
|
||||
buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string)
|
||||
if buffered == "" {
|
||||
return nil
|
||||
}
|
||||
// Build a minimal chat.completion.chunk with the buffered reasoning as content
|
||||
resp := chatCompletionResponse{
|
||||
Object: objectChatCompletionChunk,
|
||||
Choices: []chatCompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: &chatMessage{
|
||||
Content: buffered,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// Format as SSE
|
||||
return []byte("data: " + string(data) + "\n\n")
|
||||
}
|
||||
|
||||
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
156
plugins/wasm-go/extensions/ai-proxy/provider/multipart_helper.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// multipartImageRequest is the normalized view of a multipart/form-data image
// edit/variation request: the scalar form fields plus every image reference
// (normalized to URL / data-URI form) and whether a mask part was present.
type multipartImageRequest struct {
	Model        string
	Prompt       string
	Size         string
	OutputFormat string
	N            int
	ImageURLs    []string // image parts, normalized to URL or base64 data-URI form
	HasMask      bool     // true when any non-empty mask part was seen
}
|
||||
|
||||
// isMultipartFormData reports whether the Content-Type header denotes a
// multipart/form-data payload, ignoring parameters such as the boundary.
// Unparseable headers are treated as non-multipart.
func isMultipartFormData(contentType string) bool {
	mediaType, _, err := mime.ParseMediaType(contentType)
	return err == nil && strings.EqualFold(mediaType, "multipart/form-data")
}
|
||||
|
||||
// parseMultipartImageRequest decodes a multipart/form-data image request body
// into a multipartImageRequest. Scalar fields (model, prompt, size,
// output_format, n) are taken from their trimmed form values; image parts are
// collected into ImageURLs — verbatim when the value already is a URL or data
// URI, otherwise re-encoded as a base64 data URL. Mask parts only set
// HasMask; their content is not retained here. Unknown fields are ignored.
// Returns an error when the content type or any part cannot be read.
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return nil, fmt.Errorf("unable to parse content-type: %v", err)
	}
	boundary := params["boundary"]
	if boundary == "" {
		return nil, fmt.Errorf("missing multipart boundary")
	}

	req := &multipartImageRequest{
		ImageURLs: make([]string, 0),
	}
	reader := multipart.NewReader(bytes.NewReader(body), boundary)
	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("unable to read multipart part: %v", err)
		}
		fieldName := part.FormName()
		if fieldName == "" {
			// A part without a form name maps to nothing; skip it.
			_ = part.Close()
			continue
		}
		partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))

		partData, err := io.ReadAll(part)
		_ = part.Close()
		if err != nil {
			return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
		}

		value := strings.TrimSpace(string(partData))
		switch fieldName {
		case "model":
			req.Model = value
			continue
		case "prompt":
			req.Prompt = value
			continue
		case "size":
			req.Size = value
			continue
		case "output_format":
			req.OutputFormat = value
			continue
		case "n":
			// Unparseable numbers are ignored, leaving N at its zero value.
			if value != "" {
				if parsed, err := strconv.Atoi(value); err == nil {
					req.N = parsed
				}
			}
			continue
		}

		if isMultipartImageField(fieldName) {
			if isMultipartImageURLValue(value) {
				// The client sent a URL / data URI instead of raw bytes.
				req.ImageURLs = append(req.ImageURLs, value)
				continue
			}
			if len(partData) == 0 {
				continue
			}
			// Raw upload: wrap the bytes into a base64 data URL.
			imageURL := buildMultipartDataURL(partContentType, partData)
			req.ImageURLs = append(req.ImageURLs, imageURL)
			continue
		}
		if isMultipartMaskField(fieldName) {
			if len(partData) > 0 || value != "" {
				req.HasMask = true
			}
			continue
		}
	}

	return req, nil
}
|
||||
|
||||
// isMultipartImageField matches the image form-field names clients use:
// "image", "image[]", or indexed forms such as "image[0]".
func isMultipartImageField(fieldName string) bool {
	if fieldName == "image" {
		return true
	}
	return strings.HasPrefix(fieldName, "image[")
}
|
||||
|
||||
// isMultipartMaskField matches the mask form-field names clients use:
// "mask", "mask[]", or indexed forms such as "mask[0]".
func isMultipartMaskField(fieldName string) bool {
	if fieldName == "mask" {
		return true
	}
	return strings.HasPrefix(fieldName, "mask[")
}
|
||||
|
||||
// isMultipartImageURLValue reports whether a form value is already an image
// reference (data URI or http(s) URL) rather than raw binary content. The
// check is case-insensitive on the scheme prefix.
func isMultipartImageURLValue(value string) bool {
	lowered := strings.ToLower(value)
	for _, prefix := range []string{"data:", "http://", "https://"} {
		if strings.HasPrefix(lowered, prefix) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func buildMultipartDataURL(contentType string, data []byte) string {
|
||||
mimeType := strings.TrimSpace(contentType)
|
||||
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
|
||||
mimeType = http.DetectContentType(data)
|
||||
}
|
||||
mimeType = normalizeMultipartMimeType(mimeType)
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
|
||||
}
|
||||
|
||||
// normalizeMultipartMimeType reduces a Content-Type header to its bare media
// type. It prefers mime.ParseMediaType; when parsing fails it falls back to
// manually stripping parameters after the first ';' (only when the media type
// part is non-empty), and otherwise returns the trimmed input unchanged.
func normalizeMultipartMimeType(contentType string) string {
	contentType = strings.TrimSpace(contentType)
	if contentType == "" {
		return ""
	}
	if mediaType, _, err := mime.ParseMediaType(contentType); err == nil && mediaType != "" {
		return strings.TrimSpace(mediaType)
	}
	if before, _, found := strings.Cut(contentType, ";"); found && before != "" {
		return strings.TrimSpace(before)
	}
	return contentType
}
|
||||
@@ -34,6 +34,9 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameImageEdit): PathOpenAIImageEdit,
|
||||
string(ApiNameImageVariation): PathOpenAIImageVariation,
|
||||
string(ApiNameAudioSpeech): PathOpenAIAudioSpeech,
|
||||
string(ApiNameAudioTranscription): PathOpenAIAudioTranscriptions,
|
||||
string(ApiNameAudioTranslation): PathOpenAIAudioTranslations,
|
||||
string(ApiNameRealtime): PathOpenAIRealtime,
|
||||
string(ApiNameModels): PathOpenAIModels,
|
||||
string(ApiNameFiles): PathOpenAIFiles,
|
||||
string(ApiNameRetrieveFile): PathOpenAIRetrieveFile,
|
||||
@@ -63,6 +66,8 @@ func isDirectPath(path string) bool {
|
||||
return strings.HasSuffix(path, "/completions") ||
|
||||
strings.HasSuffix(path, "/embeddings") ||
|
||||
strings.HasSuffix(path, "/audio/speech") ||
|
||||
strings.HasSuffix(path, "/audio/transcriptions") ||
|
||||
strings.HasSuffix(path, "/audio/translations") ||
|
||||
strings.HasSuffix(path, "/images/generations") ||
|
||||
strings.HasSuffix(path, "/images/variations") ||
|
||||
strings.HasSuffix(path, "/images/edits") ||
|
||||
@@ -70,6 +75,7 @@ func isDirectPath(path string) bool {
|
||||
strings.HasSuffix(path, "/responses") ||
|
||||
strings.HasSuffix(path, "/fine_tuning/jobs") ||
|
||||
strings.HasSuffix(path, "/fine_tuning/checkpoints") ||
|
||||
strings.HasSuffix(path, "/realtime") ||
|
||||
strings.HasSuffix(path, "/videos")
|
||||
}
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ const (
|
||||
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
||||
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription"
|
||||
ApiNameAudioTranslation ApiName = "openai/v1/audiotranslation"
|
||||
ApiNameRealtime ApiName = "openai/v1/realtime"
|
||||
ApiNameFiles ApiName = "openai/v1/files"
|
||||
ApiNameRetrieveFile ApiName = "openai/v1/retrievefile"
|
||||
ApiNameRetrieveFileContent ApiName = "openai/v1/retrievefilecontent"
|
||||
@@ -90,6 +93,9 @@ const (
|
||||
PathOpenAIImageEdit = "/v1/images/edits"
|
||||
PathOpenAIImageVariation = "/v1/images/variations"
|
||||
PathOpenAIAudioSpeech = "/v1/audio/speech"
|
||||
PathOpenAIAudioTranscriptions = "/v1/audio/transcriptions"
|
||||
PathOpenAIAudioTranslations = "/v1/audio/translations"
|
||||
PathOpenAIRealtime = "/v1/realtime"
|
||||
PathOpenAIResponses = "/v1/responses"
|
||||
PathOpenAIFineTuningJobs = "/v1/fine_tuning/jobs"
|
||||
PathOpenAIRetrieveFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}"
|
||||
@@ -172,6 +178,8 @@ const (
|
||||
ctxKeyPushedMessage = "pushedMessage"
|
||||
ctxKeyContentPushed = "contentPushed"
|
||||
ctxKeyReasoningContentPushed = "reasoningContentPushed"
|
||||
ctxKeyHasContentDelta = "hasContentDelta"
|
||||
ctxKeyBufferedReasoning = "bufferedReasoning"
|
||||
|
||||
objectChatCompletion = "chat.completion"
|
||||
objectChatCompletionChunk = "chat.completion.chunk"
|
||||
@@ -198,8 +206,7 @@ var (
|
||||
|
||||
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
|
||||
developerRoleSupportedProviders = map[string]bool{
|
||||
providerTypeOpenAI: true,
|
||||
providerTypeAzure: true,
|
||||
providerTypeAzure: true,
|
||||
}
|
||||
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
@@ -355,6 +362,12 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Amazon Bedrock 额外模型请求参数
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务,用于设置模型特定的推理参数
|
||||
bedrockAdditionalFields map[string]interface{} `required:"false" yaml:"bedrockAdditionalFields" json:"bedrockAdditionalFields"`
|
||||
// @Title zh-CN Amazon Bedrock Prompt CachePoint 插入位置
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。用于配置 cachePoint 插入位置,支持多选:systemPrompt、lastUserMessage、lastMessage。值为 true 表示启用该位置。
|
||||
bedrockPromptCachePointPositions map[string]bool `required:"false" yaml:"bedrockPromptCachePointPositions" json:"bedrockPromptCachePointPositions"`
|
||||
// @Title zh-CN Amazon Bedrock Prompt Cache 保留策略(默认值)
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务。作为请求中 prompt_cache_retention 缺省时的默认值,支持 in_memory 和 24h。
|
||||
promptCacheRetention string `required:"false" yaml:"promptCacheRetention" json:"promptCacheRetention"`
|
||||
// @Title zh-CN minimax API type
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||
@@ -460,6 +473,15 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 智谱AI Code Plan 模式
|
||||
// @Description zh-CN 仅适用于智谱AI服务。启用后将使用 /api/coding/paas/v4/chat/completions 接口
|
||||
zhipuCodePlanMode bool `required:"false" yaml:"zhipuCodePlanMode" json:"zhipuCodePlanMode"`
|
||||
// @Title zh-CN 合并连续同角色消息
|
||||
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
||||
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
||||
// @Title zh-CN 空内容时提升思考为正文
|
||||
// @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。
|
||||
promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"`
|
||||
// @Title zh-CN HiClaw 模式
|
||||
// @Description zh-CN 开启后同时启用 mergeConsecutiveMessages 和 promoteThinkingOnEmpty,适用于 HiClaw 多 Agent 协作场景。
|
||||
hiclawMode bool `required:"false" yaml:"hiclawMode" json:"hiclawMode"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -553,6 +575,13 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
||||
c.bedrockAdditionalFields[k] = v.Value()
|
||||
}
|
||||
c.promptCacheRetention = json.Get("promptCacheRetention").String()
|
||||
if rawPositions := json.Get("bedrockPromptCachePointPositions"); rawPositions.Exists() {
|
||||
c.bedrockPromptCachePointPositions = make(map[string]bool)
|
||||
for k, v := range rawPositions.Map() {
|
||||
c.bedrockPromptCachePointPositions[k] = v.Bool()
|
||||
}
|
||||
}
|
||||
}
|
||||
c.minimaxApiType = json.Get("minimaxApiType").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
@@ -647,6 +676,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
string(ApiNameImageVariation),
|
||||
string(ApiNameImageEdit),
|
||||
string(ApiNameAudioSpeech),
|
||||
string(ApiNameAudioTranscription),
|
||||
string(ApiNameAudioTranslation),
|
||||
string(ApiNameRealtime),
|
||||
string(ApiNameResponses),
|
||||
string(ApiNameCohereV1Rerank),
|
||||
string(ApiNameVideos),
|
||||
string(ApiNameRetrieveVideo),
|
||||
@@ -673,6 +706,13 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
|
||||
}
|
||||
}
|
||||
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
||||
c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool()
|
||||
c.hiclawMode = json.Get("hiclawMode").Bool()
|
||||
if c.hiclawMode {
|
||||
c.mergeConsecutiveMessages = true
|
||||
c.promoteThinkingOnEmpty = true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -763,19 +803,19 @@ func (c *ProviderConfig) GetRandomToken() string {
|
||||
func isStatefulAPI(apiName string) bool {
|
||||
// These APIs maintain session state and should be routed to the same provider consistently
|
||||
statefulAPIs := map[string]bool{
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
}
|
||||
return statefulAPIs[apiName]
|
||||
}
|
||||
@@ -807,6 +847,10 @@ func (c *ProviderConfig) IsOriginal() bool {
|
||||
return c.protocol == protocolOriginal
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetPromoteThinkingOnEmpty() bool {
|
||||
return c.promoteThinkingOnEmpty
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
|
||||
return ReplaceByCustomSettings(body, c.customSettings)
|
||||
}
|
||||
@@ -845,6 +889,16 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageEditRequest:
|
||||
if err := decodeImageEditRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *imageVariationRequest:
|
||||
if err := decodeImageVariationRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
@@ -860,6 +914,10 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
||||
model = &req.Model
|
||||
case *imageGenerationRequest:
|
||||
model = &req.Model
|
||||
case *imageEditRequest:
|
||||
model = &req.Model
|
||||
case *imageVariationRequest:
|
||||
model = &req.Model
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
@@ -1020,7 +1078,7 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent
|
||||
if lineStartIndex != -1 {
|
||||
value := string(body[valueStartIndex:i])
|
||||
currentEvent.SetValue(currentKey, value)
|
||||
} else {
|
||||
} else if eventStartIndex != -1 {
|
||||
currentEvent.RawEvent = string(body[eventStartIndex : i+1])
|
||||
// Extra new line. The current event is complete.
|
||||
events = append(events, *currentEvent)
|
||||
@@ -1098,6 +1156,17 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
// merge consecutive same-role messages for providers that require strict role alternation
|
||||
if apiName == ApiNameChatCompletion && c.mergeConsecutiveMessages {
|
||||
body, err = mergeConsecutiveMessages(body)
|
||||
if err != nil {
|
||||
log.Warnf("[mergeConsecutiveMessages] failed to merge messages: %v", err)
|
||||
err = nil
|
||||
} else {
|
||||
log.Debugf("[mergeConsecutiveMessages] merged consecutive messages for provider: %s", c.typ)
|
||||
}
|
||||
}
|
||||
|
||||
// convert developer role to system role for providers that don't support it
|
||||
if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
|
||||
body, err = convertDeveloperRoleToSystem(body)
|
||||
|
||||
@@ -30,6 +30,7 @@ const (
|
||||
qwenCompatibleChatCompletionPath = "/compatible-mode/v1/chat/completions"
|
||||
qwenCompatibleCompletionsPath = "/compatible-mode/v1/completions"
|
||||
qwenCompatibleTextEmbeddingPath = "/compatible-mode/v1/embeddings"
|
||||
qwenCompatibleResponsesPath = "/api/v2/apps/protocols/compatible-mode/v1/responses"
|
||||
qwenCompatibleFilesPath = "/compatible-mode/v1/files"
|
||||
qwenCompatibleRetrieveFilePath = "/compatible-mode/v1/files/{file_id}"
|
||||
qwenCompatibleRetrieveFileContentPath = "/compatible-mode/v1/files/{file_id}/content"
|
||||
@@ -37,7 +38,7 @@ const (
|
||||
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
||||
qwenBailianPath = "/api/v1/apps"
|
||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||
qwenAnthropicMessagesPath = "/api/v2/apps/claude-code-proxy/v1/messages"
|
||||
qwenAnthropicMessagesPath = "/apps/anthropic/v1/messages"
|
||||
|
||||
qwenAsyncAIGCPath = "/api/v1/services/aigc/"
|
||||
qwenAsyncTaskPath = "/api/v1/tasks/"
|
||||
@@ -69,6 +70,7 @@ func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool)
|
||||
string(ApiNameChatCompletion): qwenCompatibleChatCompletionPath,
|
||||
string(ApiNameEmbeddings): qwenCompatibleTextEmbeddingPath,
|
||||
string(ApiNameCompletion): qwenCompatibleCompletionsPath,
|
||||
string(ApiNameResponses): qwenCompatibleResponsesPath,
|
||||
string(ApiNameFiles): qwenCompatibleFilesPath,
|
||||
string(ApiNameRetrieveFile): qwenCompatibleRetrieveFilePath,
|
||||
string(ApiNameRetrieveFileContent): qwenCompatibleRetrieveFileContentPath,
|
||||
@@ -707,6 +709,8 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
|
||||
case strings.Contains(path, qwenTextEmbeddingPath),
|
||||
strings.Contains(path, qwenCompatibleTextEmbeddingPath):
|
||||
return ApiNameEmbeddings
|
||||
case strings.Contains(path, qwenCompatibleResponsesPath):
|
||||
return ApiNameResponses
|
||||
case strings.Contains(path, qwenAsyncAIGCPath):
|
||||
return ApiNameQwenAsyncAIGC
|
||||
case strings.Contains(path, qwenAsyncTaskPath):
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
)
|
||||
|
||||
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
|
||||
@@ -32,6 +32,20 @@ func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceJsonRequestBody(request interface{}) error {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
@@ -140,6 +154,54 @@ func cleanupContextMessages(body []byte, cleanupCommands []string) ([]byte, erro
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
// mergeConsecutiveMessages merges consecutive messages of the same role (user or assistant).
|
||||
// Many LLM providers require strict user↔assistant alternation and reject requests where
|
||||
// two messages of the same role appear consecutively. When enabled, consecutive same-role
|
||||
// messages have their content concatenated into a single message.
|
||||
func mergeConsecutiveMessages(body []byte) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return body, fmt.Errorf("unable to unmarshal request for message merging: %v", err)
|
||||
}
|
||||
if len(request.Messages) <= 1 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
merged := false
|
||||
result := make([]chatMessage, 0, len(request.Messages))
|
||||
for _, msg := range request.Messages {
|
||||
if len(result) > 0 &&
|
||||
result[len(result)-1].Role == msg.Role &&
|
||||
(msg.Role == roleUser || msg.Role == roleAssistant) {
|
||||
last := &result[len(result)-1]
|
||||
last.Content = mergeMessageContent(last.Content, msg.Content)
|
||||
merged = true
|
||||
continue
|
||||
}
|
||||
result = append(result, msg)
|
||||
}
|
||||
|
||||
if !merged {
|
||||
return body, nil
|
||||
}
|
||||
request.Messages = result
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
// mergeMessageContent concatenates two message content values.
|
||||
// If both are plain strings they are joined with a blank line.
|
||||
// Otherwise both are converted to content-block arrays and concatenated.
|
||||
func mergeMessageContent(prev, curr any) any {
|
||||
prevStr, prevIsStr := prev.(string)
|
||||
currStr, currIsStr := curr.(string)
|
||||
if prevIsStr && currIsStr {
|
||||
return prevStr + "\n\n" + currStr
|
||||
}
|
||||
prevParts := (&chatMessage{Content: prev}).ParseContent()
|
||||
currParts := (&chatMessage{Content: curr}).ParseContent()
|
||||
return append(prevParts, currParts...)
|
||||
}
|
||||
|
||||
func ReplaceResponseBody(body []byte) error {
|
||||
log.Debugf("response body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||
|
||||
@@ -8,6 +8,131 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeConsecutiveMessages(t *testing.T) {
|
||||
t.Run("no_consecutive_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "再见"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
// No merging needed, returned body should be identical
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("merges_consecutive_user_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "第一条"},
|
||||
{Role: "user", Content: "第二条"},
|
||||
{Role: "assistant", Content: "回复"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
require.NoError(t, json.Unmarshal(result, &output))
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "user", output.Messages[0].Role)
|
||||
assert.Equal(t, "第一条\n\n第二条", output.Messages[0].Content)
|
||||
assert.Equal(t, "assistant", output.Messages[1].Role)
|
||||
})
|
||||
|
||||
t.Run("merges_consecutive_assistant_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "问题"},
|
||||
{Role: "assistant", Content: "第一段"},
|
||||
{Role: "assistant", Content: "第二段"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
require.NoError(t, json.Unmarshal(result, &output))
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "user", output.Messages[0].Role)
|
||||
assert.Equal(t, "assistant", output.Messages[1].Role)
|
||||
assert.Equal(t, "第一段\n\n第二段", output.Messages[1].Content)
|
||||
})
|
||||
|
||||
t.Run("merges_multiple_consecutive_same_role", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "A"},
|
||||
{Role: "user", Content: "B"},
|
||||
{Role: "user", Content: "C"},
|
||||
{Role: "assistant", Content: "回复"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
require.NoError(t, json.Unmarshal(result, &output))
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "A\n\nB\n\nC", output.Messages[0].Content)
|
||||
})
|
||||
|
||||
t.Run("system_messages_not_merged", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "系统提示1"},
|
||||
{Role: "system", Content: "系统提示2"},
|
||||
{Role: "user", Content: "问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
// system messages are not merged, body unchanged
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("single_message_unchanged", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "只有一条"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("invalid_json_body", func(t *testing.T) {
|
||||
body := []byte(`invalid json`)
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCleanupContextMessages(t *testing.T) {
|
||||
t.Run("empty_cleanup_commands", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
@@ -45,7 +45,9 @@ const (
|
||||
contextClaudeMarker = "isClaudeRequest"
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
contextVertexRawMarker = "isVertexRawRequest"
|
||||
contextVertexStreamDoneMarker = "vertexStreamDoneSent"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||
)
|
||||
|
||||
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||
@@ -98,6 +100,8 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||
string(ApiNameImageEdit): vertexPathTemplate,
|
||||
string(ApiNameImageVariation): vertexPathTemplate,
|
||||
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||
}
|
||||
}
|
||||
@@ -307,6 +311,10 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||
case ApiNameImageEdit:
|
||||
return v.onImageEditRequestBody(ctx, body, headers)
|
||||
case ApiNameImageVariation:
|
||||
return v.onImageVariationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -387,11 +395,108 @@ func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, b
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
||||
vertexRequest, err := v.buildVertexImageGenerationRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
||||
func (v *vertexProvider) onImageEditRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageEditRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedRequest.HasMask {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.HasMask() {
|
||||
return nil, fmt.Errorf("mask is not supported for vertex image edits yet")
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
if request.Prompt == "" {
|
||||
return nil, fmt.Errorf("missing prompt in request")
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageEdit, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageVariationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageVariationRequest{}
|
||||
imageURLs := make([]string, 0)
|
||||
contentType := headers.Get("Content-Type")
|
||||
if isMultipartFormData(contentType) {
|
||||
parsedRequest, err := parseMultipartImageRequest(body, contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Model = parsedRequest.Model
|
||||
request.Prompt = parsedRequest.Prompt
|
||||
request.Size = parsedRequest.Size
|
||||
request.OutputFormat = parsedRequest.OutputFormat
|
||||
request.N = parsedRequest.N
|
||||
imageURLs = parsedRequest.ImageURLs
|
||||
if err := v.config.mapModel(ctx, &request.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imageURLs = request.GetImageURLs()
|
||||
}
|
||||
if len(imageURLs) == 0 {
|
||||
return nil, fmt.Errorf("missing image_url in request")
|
||||
}
|
||||
|
||||
prompt := request.Prompt
|
||||
if prompt == "" {
|
||||
prompt = vertexImageVariationDefaultPrompt
|
||||
}
|
||||
|
||||
path := v.getRequestPath(ApiNameImageVariation, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
headers.Set("Content-Type", util.MimeTypeApplicationJson)
|
||||
vertexRequest, err := v.buildVertexImageRequest(prompt, request.Size, request.OutputFormat, imageURLs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) (*vertexChatRequest, error) {
|
||||
return v.buildVertexImageRequest(request.Prompt, request.Size, request.OutputFormat, nil)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageRequest(prompt string, size string, outputFormat string, imageURLs []string) (*vertexChatRequest, error) {
|
||||
// 构建安全设置
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
@@ -402,12 +507,12 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
aspectRatio, imageSize := v.parseImageSize(size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
if outputFormat != "" {
|
||||
switch outputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
@@ -417,12 +522,27 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
}
|
||||
}
|
||||
|
||||
parts := make([]vertexPart, 0, len(imageURLs)+1)
|
||||
for _, imageURL := range imageURLs {
|
||||
part, err := convertMediaContent(imageURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
if prompt != "" {
|
||||
parts = append(parts, vertexPart{
|
||||
Text: prompt,
|
||||
})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("missing prompt and image_url in request")
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
Role: roleUser,
|
||||
Parts: parts,
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
@@ -440,7 +560,7 @@ func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerat
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
return vertexRequest, nil
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
@@ -502,23 +622,46 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
|
||||
}
|
||||
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
||||
if isLastChunk {
|
||||
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
||||
}
|
||||
if len(chunk) == 0 {
|
||||
if len(chunk) == 0 && !isLastChunk {
|
||||
return nil, nil
|
||||
}
|
||||
if name != ApiNameChatCompletion {
|
||||
if isLastChunk {
|
||||
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
||||
}
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
if len(data) < 6 {
|
||||
// ignore blank line or wrong format
|
||||
// Flush a trailing event when upstream closes stream without a final blank line.
|
||||
chunkForParsing := chunk
|
||||
if isLastChunk {
|
||||
trailingNewLineCount := 0
|
||||
for i := len(chunkForParsing) - 1; i >= 0 && chunkForParsing[i] == '\n'; i-- {
|
||||
trailingNewLineCount++
|
||||
}
|
||||
if trailingNewLineCount < 2 {
|
||||
chunkForParsing = append([]byte(nil), chunk...)
|
||||
for i := 0; i < 2-trailingNewLineCount; i++ {
|
||||
chunkForParsing = append(chunkForParsing, '\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
streamEvents := ExtractStreamingEvents(ctx, chunkForParsing)
|
||||
doneSent, _ := ctx.GetContext(contextVertexStreamDoneMarker).(bool)
|
||||
appendDone := isLastChunk && !doneSent
|
||||
for _, event := range streamEvents {
|
||||
data := event.Data
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if data == streamEndDataValue {
|
||||
if !doneSent {
|
||||
appendDone = true
|
||||
doneSent = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
var vertexResp vertexChatResponse
|
||||
if err := json.Unmarshal([]byte(data), &vertexResp); err != nil {
|
||||
log.Errorf("unable to unmarshal vertex response: %v", err)
|
||||
@@ -532,7 +675,17 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
}
|
||||
v.appendResponse(responseBuilder, string(responseBody))
|
||||
}
|
||||
if appendDone {
|
||||
responseBuilder.WriteString(ssePrefix + "[DONE]\n\n")
|
||||
doneSent = true
|
||||
}
|
||||
ctx.SetContext(contextVertexStreamDoneMarker, doneSent)
|
||||
modifiedResponseChunk := responseBuilder.String()
|
||||
if modifiedResponseChunk == "" {
|
||||
// Returning an empty payload prevents main.go from falling back to
|
||||
// forwarding the original raw chunk, which may contain partial JSON.
|
||||
return []byte(""), nil
|
||||
}
|
||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
@@ -553,7 +706,7 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
return v.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
@@ -784,7 +937,7 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
action = vertexEmbeddingAction
|
||||
case ApiNameImageGeneration:
|
||||
case ApiNameImageGeneration, ApiNameImageEdit, ApiNameImageVariation:
|
||||
// 图片生成使用非流式端点,需要完整响应
|
||||
action = vertexChatCompletionAction
|
||||
default:
|
||||
|
||||
116
plugins/wasm-go/extensions/ai-proxy/test/api_paths.go
Normal file
116
plugins/wasm-go/extensions/ai-proxy/test/api_paths.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
wasmtest "github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func openAICustomEndpointConfig(customURL string) json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-test-custom-endpoint"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4o-mini",
|
||||
},
|
||||
"openaiCustomUrl": customURL,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
var openAICustomAudioTranscriptionsEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/audio/transcriptions")
|
||||
var openAICustomAudioTranslationsEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/audio/translations")
|
||||
var openAICustomRealtimeEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/realtime")
|
||||
var openAICustomRealtimeSessionsEndpointConfig = openAICustomEndpointConfig("https://custom.openai.com/v1/realtime/sessions")
|
||||
|
||||
func RunApiPathRegressionTests(t *testing.T) {
|
||||
wasmtest.RunTest(t, func(t *testing.T) {
|
||||
t.Run("openai direct custom endpoint audio transcriptions", func(t *testing.T) {
|
||||
host, status := wasmtest.NewTestHost(openAICustomAudioTranscriptionsEndpointConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/audio/transcriptions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Equal(t, "/v1/audio/transcriptions", pathValue)
|
||||
})
|
||||
|
||||
t.Run("openai direct custom endpoint audio translations", func(t *testing.T) {
|
||||
host, status := wasmtest.NewTestHost(openAICustomAudioTranslationsEndpointConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/audio/translations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Equal(t, "/v1/audio/translations", pathValue)
|
||||
})
|
||||
|
||||
t.Run("openai direct custom endpoint realtime", func(t *testing.T) {
|
||||
host, status := wasmtest.NewTestHost(openAICustomRealtimeEndpointConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/realtime"},
|
||||
{":method", "GET"},
|
||||
{"Connection", "Upgrade"},
|
||||
{"Upgrade", "websocket"},
|
||||
{"Sec-WebSocket-Version", "13"},
|
||||
{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
})
|
||||
require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Equal(t, "/v1/realtime", pathValue)
|
||||
})
|
||||
|
||||
t.Run("openai non-direct endpoint appends mapped realtime suffix", func(t *testing.T) {
|
||||
host, status := wasmtest.NewTestHost(openAICustomRealtimeSessionsEndpointConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/realtime"},
|
||||
{":method", "GET"},
|
||||
{"Connection", "Upgrade"},
|
||||
{"Upgrade", "websocket"},
|
||||
{"Sec-WebSocket-Version", "13"},
|
||||
{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
})
|
||||
require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := wasmtest.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Equal(t, "/v1/realtime/sessions/realtime", pathValue)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
50
plugins/wasm-go/extensions/ai-proxy/test/mock_context.go
Normal file
50
plugins/wasm-go/extensions/ai-proxy/test/mock_context.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package test
|
||||
|
||||
import "github.com/higress-group/wasm-go/pkg/iface"
|
||||
|
||||
// MockHttpContext is a minimal mock for wrapper.HttpContext used in unit tests
|
||||
// that call provider functions directly (e.g. streaming thinking promotion).
|
||||
type MockHttpContext struct {
|
||||
contextMap map[string]interface{}
|
||||
}
|
||||
|
||||
func NewMockHttpContext() *MockHttpContext {
|
||||
return &MockHttpContext{contextMap: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
func (m *MockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value }
|
||||
func (m *MockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] }
|
||||
func (m *MockHttpContext) GetBoolContext(key string, def bool) bool { return def }
|
||||
func (m *MockHttpContext) GetStringContext(key, def string) string { return def }
|
||||
func (m *MockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
|
||||
func (m *MockHttpContext) Scheme() string { return "" }
|
||||
func (m *MockHttpContext) Host() string { return "" }
|
||||
func (m *MockHttpContext) Path() string { return "" }
|
||||
func (m *MockHttpContext) Method() string { return "" }
|
||||
func (m *MockHttpContext) GetUserAttribute(key string) interface{} { return nil }
|
||||
func (m *MockHttpContext) SetUserAttribute(key string, value interface{}) {}
|
||||
func (m *MockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {}
|
||||
func (m *MockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil }
|
||||
func (m *MockHttpContext) WriteUserAttributeToLog() error { return nil }
|
||||
func (m *MockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil }
|
||||
func (m *MockHttpContext) WriteUserAttributeToTrace() error { return nil }
|
||||
func (m *MockHttpContext) DontReadRequestBody() {}
|
||||
func (m *MockHttpContext) DontReadResponseBody() {}
|
||||
func (m *MockHttpContext) BufferRequestBody() {}
|
||||
func (m *MockHttpContext) BufferResponseBody() {}
|
||||
func (m *MockHttpContext) NeedPauseStreamingResponse() {}
|
||||
func (m *MockHttpContext) PushBuffer(buffer []byte) {}
|
||||
func (m *MockHttpContext) PopBuffer() []byte { return nil }
|
||||
func (m *MockHttpContext) BufferQueueSize() int { return 0 }
|
||||
func (m *MockHttpContext) DisableReroute() {}
|
||||
func (m *MockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *MockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *MockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
|
||||
return nil
|
||||
}
|
||||
func (m *MockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
|
||||
func (m *MockHttpContext) HasRequestBody() bool { return false }
|
||||
func (m *MockHttpContext) HasResponseBody() bool { return false }
|
||||
func (m *MockHttpContext) IsWebsocket() bool { return false }
|
||||
func (m *MockHttpContext) IsBinaryRequestBody() bool { return false }
|
||||
func (m *MockHttpContext) IsBinaryResponseBody() bool { return false }
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -243,6 +244,84 @@ func RunOpenAIOnHttpRequestHeadersTests(t *testing.T) {
|
||||
require.Contains(t, authValue, "sk-openai-test123456789", "Authorization should contain OpenAI API token")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求头处理(语音转写接口)
|
||||
t.Run("openai audio transcriptions request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/audio/transcriptions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "api.openai.com", hostValue)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/audio/transcriptions", "Path should contain audio transcriptions endpoint")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求头处理(语音翻译接口)
|
||||
t.Run("openai audio translations request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/audio/translations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/audio/translations", "Path should contain audio translations endpoint")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求头处理(实时接口,WebSocket握手)
|
||||
t.Run("openai realtime websocket handshake request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/realtime?model=gpt-4o-realtime-preview"},
|
||||
{":method", "GET"},
|
||||
{"Connection", "Upgrade"},
|
||||
{"Upgrade", "websocket"},
|
||||
{"Sec-WebSocket-Version", "13"},
|
||||
{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
})
|
||||
|
||||
// WebSocket 握手本身不应依赖请求体。受测试框架限制,某些场景可能仍返回 HeaderStopIteration。
|
||||
require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/realtime", "Path should contain realtime endpoint")
|
||||
require.Contains(t, pathValue, "model=gpt-4o-realtime-preview", "Query parameters should be preserved")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求头处理(图像生成接口)
|
||||
t.Run("openai image generation request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
@@ -305,6 +384,61 @@ func RunOpenAIOnHttpRequestHeadersTests(t *testing.T) {
|
||||
// 对于直接路径,应该保持原有路径
|
||||
require.Contains(t, pathValue, "/v1/chat/completions", "Path should be preserved for direct custom path")
|
||||
})
|
||||
|
||||
// 测试OpenAI自定义域名请求头处理(间接路径语音转写)
|
||||
t.Run("openai custom domain indirect path audio transcriptions request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAICustomDomainIndirectPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/audio/transcriptions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "custom.openai.com", hostValue, "Host should be changed to custom domain")
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/api/audio/transcriptions", "Path should be rewritten with indirect custom prefix")
|
||||
})
|
||||
|
||||
// 测试OpenAI自定义域名请求头处理(间接路径 realtime,WebSocket握手)
|
||||
t.Run("openai custom domain indirect path realtime websocket handshake request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAICustomDomainIndirectPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/realtime?model=gpt-4o-realtime-preview"},
|
||||
{":method", "GET"},
|
||||
{"Connection", "Upgrade"},
|
||||
{"Upgrade", "websocket"},
|
||||
{"Sec-WebSocket-Version", "13"},
|
||||
{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
})
|
||||
|
||||
// WebSocket 握手本身不应依赖请求体。受测试框架限制,某些场景可能仍返回 HeaderStopIteration。
|
||||
require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/api/realtime", "Path should be rewritten with indirect custom prefix")
|
||||
require.Contains(t, pathValue, "model=gpt-4o-realtime-preview", "Query parameters should be preserved")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -864,3 +998,158 @@ func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 测试配置:OpenAI配置 + promoteThinkingOnEmpty
|
||||
var openAIPromoteThinkingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-test123456789"},
|
||||
"promoteThinkingOnEmpty": true,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:OpenAI配置 + hiclawMode
|
||||
var openAIHiclawModeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-test123456789"},
|
||||
"hiclawMode": true,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunOpenAIPromoteThinkingOnEmptyTests(t *testing.T) {
|
||||
// Config parsing tests via host framework
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("promoteThinkingOnEmpty config parses", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAIPromoteThinkingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
t.Run("hiclawMode config parses", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAIHiclawModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
|
||||
// Non-streaming promote logic tests via provider functions directly
|
||||
t.Run("promotes reasoning_content when content is empty string", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(result), `"content":"这是思考内容"`)
|
||||
require.NotContains(t, string(result), `"reasoning_content":"这是思考内容"`)
|
||||
})
|
||||
|
||||
t.Run("promotes reasoning_content when content is nil", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(result), `"content":"思考结果"`)
|
||||
})
|
||||
|
||||
t.Run("no promotion when content is present", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
|
||||
t.Run("no promotion when no reasoning", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
|
||||
t.Run("no promotion when both empty", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
|
||||
t.Run("invalid json returns error", func(t *testing.T) {
|
||||
body := []byte(`not json`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIPromoteThinkingOnEmptyStreamingTests(t *testing.T) {
|
||||
// Streaming tests use provider functions directly since the test framework
|
||||
// does not expose GetStreamingResponseBody.
|
||||
t.Run("streaming: buffers reasoning and flushes on end when no content", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
// Chunk with only reasoning_content
|
||||
data := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"流式思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
require.NoError(t, err)
|
||||
// Reasoning should be stripped (not promoted inline)
|
||||
require.NotContains(t, string(result), `"content":"流式思考"`)
|
||||
|
||||
// Flush should emit buffered reasoning as content
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.NotNil(t, flush)
|
||||
require.Contains(t, string(flush), `"content":"流式思考"`)
|
||||
})
|
||||
|
||||
t.Run("streaming: no flush when content was seen", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
// First chunk: content delta
|
||||
data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)
|
||||
|
||||
// Second chunk: reasoning only
|
||||
data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)
|
||||
require.NoError(t, err)
|
||||
// Should be unchanged since content was already seen
|
||||
require.Equal(t, string(data2), string(result))
|
||||
|
||||
// Flush should return nil since content was seen
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.Nil(t, flush)
|
||||
})
|
||||
|
||||
t.Run("streaming: accumulates multiple reasoning chunks", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
data1 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第一段"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)
|
||||
|
||||
data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第二段"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)
|
||||
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.NotNil(t, flush)
|
||||
require.Contains(t, string(flush), `"content":"第一段第二段"`)
|
||||
})
|
||||
|
||||
t.Run("streaming: no flush when no reasoning buffered", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.Nil(t, flush)
|
||||
})
|
||||
|
||||
t.Run("streaming: invalid json returns original", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
data := []byte(`not json`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(data), string(result))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -100,6 +100,22 @@ var qwenEnableCompatibleConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:qwen original + 兼容模式(用于覆盖 provider.GetApiName 分支)
|
||||
var qwenOriginalCompatibleConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "qwen",
|
||||
"apiTokens": []string{"sk-qwen-original-compatible"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "qwen-turbo",
|
||||
},
|
||||
"qwenEnableCompatible": true,
|
||||
"protocol": "original",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:qwen文件ID配置
|
||||
var qwenFileIdsConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
@@ -159,6 +175,15 @@ var qwenConflictConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
func hasUnsupportedAPINameError(errorLogs []string) bool {
|
||||
for _, log := range errorLogs {
|
||||
if strings.Contains(log, "unsupported API name") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func RunQwenParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本qwen配置解析
|
||||
@@ -403,6 +428,29 @@ func RunQwenOnHttpRequestHeadersTests(t *testing.T) {
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/compatible-mode/v1/chat/completions", "Path should use compatible mode path")
|
||||
})
|
||||
|
||||
// 测试qwen兼容模式请求头处理(responses接口)
|
||||
t.Run("qwen compatible mode responses request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(qwenEnableCompatibleConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/api/v2/apps/protocols/compatible-mode/v1/responses", "Path should use compatible mode responses path")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -651,6 +699,112 @@ func RunQwenOnHttpRequestBodyTests(t *testing.T) {
|
||||
}
|
||||
require.True(t, hasCompatibleLogs, "Should have compatible mode processing logs")
|
||||
})
|
||||
|
||||
// 测试qwen请求体处理(兼容模式 responses接口)
|
||||
t.Run("qwen compatible mode responses request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(qwenEnableCompatibleConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"qwen-turbo","input":"test"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
require.Contains(t, string(processedBody), "qwen-turbo", "Model name should be preserved in responses request")
|
||||
})
|
||||
|
||||
// 测试qwen请求体处理(非兼容模式 responses接口应报不支持)
|
||||
t.Run("qwen non-compatible mode responses request body unsupported", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicQwenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/responses", "Path should remain unchanged when responses is unsupported")
|
||||
|
||||
requestBody := `{"model":"qwen-turbo","input":"test"}`
|
||||
bodyAction := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, bodyAction)
|
||||
|
||||
hasUnsupportedErr := hasUnsupportedAPINameError(host.GetErrorLogs())
|
||||
require.True(t, hasUnsupportedErr, "Should log unsupported API name for non-compatible responses")
|
||||
})
|
||||
|
||||
// 覆盖 qwen.GetApiName 中以下分支:
|
||||
// - qwenCompatibleTextEmbeddingPath => ApiNameEmbeddings
|
||||
// - qwenCompatibleResponsesPath => ApiNameResponses
|
||||
// - qwenAsyncAIGCPath => ApiNameQwenAsyncAIGC
|
||||
// - qwenAsyncTaskPath => ApiNameQwenAsyncTask
|
||||
t.Run("qwen original protocol get api name coverage for compatible embeddings responses and async paths", func(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
path string
|
||||
}{
|
||||
{
|
||||
name: "compatible embeddings path",
|
||||
path: "/compatible-mode/v1/embeddings",
|
||||
},
|
||||
{
|
||||
name: "compatible responses path",
|
||||
path: "/api/v2/apps/protocols/compatible-mode/v1/responses",
|
||||
},
|
||||
{
|
||||
name: "async aigc path",
|
||||
path: "/api/v1/services/aigc/custom-async-endpoint",
|
||||
},
|
||||
{
|
||||
name: "async task path",
|
||||
path: "/api/v1/tasks/task-123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
host, status := test.NewTestHost(qwenOriginalCompatibleConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", tc.path},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
// 测试框架中 action 可能表现为 Continue 或 HeaderStopIteration,
|
||||
// 这里关注的是后续 body 阶段不出现 unsupported API name。
|
||||
require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration)
|
||||
|
||||
requestBody := `{"model":"qwen-turbo","input":"test"}`
|
||||
bodyAction := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, bodyAction)
|
||||
|
||||
hasUnsupportedErr := hasUnsupportedAPINameError(host.GetErrorLogs())
|
||||
require.False(t, hasUnsupportedErr, "Path should be recognized by qwen.GetApiName in original protocol")
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -986,6 +1140,51 @@ func RunQwenOnHttpResponseBodyTests(t *testing.T) {
|
||||
require.Contains(t, responseStr, "chat.completion", "Response should contain chat completion object")
|
||||
require.Contains(t, responseStr, "qwen-turbo", "Response should contain model name")
|
||||
})
|
||||
|
||||
// 测试qwen响应体处理(兼容模式 responses 接口透传)
|
||||
t.Run("qwen compatible mode responses response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(qwenEnableCompatibleConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"qwen-turbo","input":"test"}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
responseBody := `{
|
||||
"id": "resp-123",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "hello"
|
||||
}]
|
||||
}]
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "\"object\": \"response\"", "Responses API payload should be passthrough in compatible mode")
|
||||
require.Contains(t, responseStr, "\"text\": \"hello\"", "Assistant content should be preserved")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -689,7 +691,7 @@ func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
|
||||
func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex Express Mode 流式响应处理
|
||||
// 测试 Vertex Express Mode 流式响应处理:最后一个 chunk 不应丢失
|
||||
t.Run("vertex express mode streaming response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
@@ -707,6 +709,9 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应属性,确保IsResponseFromUpstream()返回true
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
// 设置流式响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
@@ -715,8 +720,8 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟流式响应体
|
||||
chunk1 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":5,"totalTokenCount":14}}`
|
||||
chunk2 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":12,"totalTokenCount":21}}`
|
||||
chunk1 := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hello\"}],\"role\":\"model\"},\"finishReason\":\"\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":9,\"candidatesTokenCount\":5,\"totalTokenCount\":14}}\n\n"
|
||||
chunk2 := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hello! How can I help you today?\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":9,\"candidatesTokenCount\":12,\"totalTokenCount\":21}}\n\n"
|
||||
|
||||
// 处理流式响应体
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
|
||||
@@ -725,16 +730,194 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
|
||||
// 验证流式响应处理
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasStreamingLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "vertex") {
|
||||
hasStreamingLogs = true
|
||||
// 验证最后一个 chunk 的内容不会被 [DONE] 覆盖
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
responseStr := string(transformedResponseBody)
|
||||
require.Contains(t, responseStr, "Hello! How can I help you today?", "last chunk content should be preserved")
|
||||
require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 流式响应处理:单个 SSE 事件被拆包时可正确重组
|
||||
t.Run("vertex express mode streaming response body with split sse event", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应属性,确保IsResponseFromUpstream()返回true
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
fullEvent := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"split chunk\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":1,\"candidatesTokenCount\":2,\"totalTokenCount\":3}}\n\n"
|
||||
splitIdx := strings.Index(fullEvent, "chunk")
|
||||
require.Greater(t, splitIdx, 0, "split marker should exist in test payload")
|
||||
chunkPart1 := fullEvent[:splitIdx]
|
||||
chunkPart2 := fullEvent[splitIdx:]
|
||||
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart1), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart2), true)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
responseStr := string(transformedResponseBody)
|
||||
require.Contains(t, responseStr, "split chunk", "split SSE event should be reassembled and parsed")
|
||||
require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
|
||||
})
|
||||
|
||||
// 测试:thoughtSignature 很大时,单个 SSE 事件被拆成多段也能重组并成功解析
|
||||
t.Run("vertex express mode streaming response body with huge thought signature split across chunks", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
})
|
||||
|
||||
hugeThoughtSignature := strings.Repeat("CmMBjz1rX4j+TQjtDy2rZxSdYOE1jUqDbRhWetraLlQNrkyaRNQZ/", 180)
|
||||
fullEvent := "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"thought-signature-merge-ok\",\"thoughtSignature\":\"" +
|
||||
hugeThoughtSignature +
|
||||
"\"}]},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":28,\"candidatesTokenCount\":3589,\"totalTokenCount\":5240,\"thoughtsTokenCount\":1623}}\n\n"
|
||||
|
||||
signatureStart := strings.Index(fullEvent, "\"thoughtSignature\":\"")
|
||||
require.Greater(t, signatureStart, 0, "thoughtSignature field should exist in test payload")
|
||||
splitAt1 := signatureStart + len("\"thoughtSignature\":\"") + 700
|
||||
splitAt2 := splitAt1 + 1600
|
||||
require.Less(t, splitAt2, len(fullEvent)-1, "split indexes should keep payload in three chunks")
|
||||
|
||||
chunkPart1 := fullEvent[:splitAt1]
|
||||
chunkPart2 := fullEvent[splitAt1:splitAt2]
|
||||
chunkPart3 := fullEvent[splitAt2:]
|
||||
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart1), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
firstBody := host.GetResponseBody()
|
||||
require.Equal(t, 0, len(firstBody), "partial chunk should not be forwarded to client")
|
||||
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart2), false)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
secondBody := host.GetResponseBody()
|
||||
require.Equal(t, 0, len(secondBody), "partial chunk should not be forwarded to client")
|
||||
|
||||
action3 := host.CallOnHttpStreamingResponseBody([]byte(chunkPart3), true)
|
||||
require.Equal(t, types.ActionContinue, action3)
|
||||
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
responseStr := string(transformedResponseBody)
|
||||
require.Contains(t, responseStr, "thought-signature-merge-ok", "split huge thoughtSignature event should be reassembled and parsed")
|
||||
require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
|
||||
|
||||
errorLogs := host.GetErrorLogs()
|
||||
hasUnmarshalError := false
|
||||
for _, log := range errorLogs {
|
||||
if strings.Contains(log, "unable to unmarshal vertex response") {
|
||||
hasUnmarshalError = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
|
||||
require.False(t, hasUnmarshalError, "should not have vertex unmarshal errors for split huge thoughtSignature event")
|
||||
})
|
||||
|
||||
// 测试:上游已发送 [DONE],框架再触发空的最后回调时不应重复输出 [DONE]
|
||||
t.Run("vertex express mode streaming response body with upstream done and empty final callback", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
})
|
||||
|
||||
doneChunk := "data: [DONE]\n\n"
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(doneChunk), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
firstBody := host.GetResponseBody()
|
||||
require.NotNil(t, firstBody)
|
||||
require.Contains(t, string(firstBody), "data: [DONE]", "first callback should output [DONE]")
|
||||
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte{}, true)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
|
||||
debugLogs := host.GetDebugLogs()
|
||||
doneChunkLogCount := 0
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "=== modified response chunk: data: [DONE]") {
|
||||
doneChunkLogCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, doneChunkLogCount, "[DONE] should only be emitted once when upstream already sent it")
|
||||
})
|
||||
|
||||
// 测试:最后一个 chunk 缺少 SSE 结束空行时,isLastChunk=true 也应正确解析并输出
|
||||
t.Run("vertex express mode streaming response body last chunk without terminator", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
})
|
||||
|
||||
lastChunkWithoutTerminator := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"no terminator\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":2,\"candidatesTokenCount\":3,\"totalTokenCount\":5}}"
|
||||
action := host.CallOnHttpStreamingResponseBody([]byte(lastChunkWithoutTerminator), true)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
responseStr := string(transformedResponseBody)
|
||||
require.Contains(t, responseStr, "no terminator", "last chunk without terminator should still be parsed")
|
||||
require.Contains(t, responseStr, "data: [DONE]", "stream should end with [DONE]")
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1273,6 +1456,324 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func buildMultipartRequestBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
|
||||
var buffer bytes.Buffer
|
||||
writer := multipart.NewWriter(&buffer)
|
||||
|
||||
for key, value := range fields {
|
||||
require.NoError(t, writer.WriteField(key, value))
|
||||
}
|
||||
|
||||
for fieldName, data := range files {
|
||||
part, err := writer.CreateFormFile(fieldName, "upload-image.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write(data)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, writer.Close())
|
||||
return buffer.Bytes(), writer.FormDataContentType()
|
||||
}
|
||||
|
||||
func RunVertexExpressModeImageEditVariationRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
|
||||
t.Run("vertex express mode image edit request body with image_url", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":{"image_url":{"url":"` + testDataURL + `"}},"size":"1024x1024"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||
require.NotContains(t, bodyStr, "image_url", "OpenAI image_url field should be converted to Vertex format")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "generateContent", "Image edit should use generateContent action")
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image edit request body with image string", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add sunglasses to the cat","image":"` + testDataURL + `","size":"1024x1024"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image string")
|
||||
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image edit multipart request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"prompt": "Add sunglasses to the cat",
|
||||
"size": "1024x1024",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||
require.Contains(t, bodyStr, "Add sunglasses to the cat", "Prompt text should be preserved")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image variation multipart request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
body, contentType := buildMultipartRequestBody(t, map[string]string{
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"size": "1024x1024",
|
||||
}, map[string][]byte{
|
||||
"image": []byte("fake-image-content"),
|
||||
})
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", contentType},
|
||||
})
|
||||
|
||||
action := host.CallOnHttpRequestBody(body)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Multipart image should be converted to inlineData")
|
||||
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "Content-Type", "application/json"), "Content-Type should be rewritten to application/json")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image edit with model mapping", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gpt-4","prompt":"Turn it into watercolor","image_url":{"url":"` + testDataURL + `"}}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image variation request body with image_url", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
bodyStr := string(processedBody)
|
||||
require.Contains(t, bodyStr, "inlineData", "Request should contain inlineData converted from image_url")
|
||||
require.Contains(t, bodyStr, "Create variations of the provided image.", "Variation request should inject a default prompt")
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "generateContent", "Image variation should use generateContent action")
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexExpressModeImageEditVariationResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
const testDataURL = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
|
||||
t.Run("vertex express mode image edit response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/edits"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"Add glasses","image_url":{"url":"` + testDataURL + `"}}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
responseBody := `{
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
}
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 12,
|
||||
"candidatesTokenCount": 1024,
|
||||
"totalTokenCount": 1036
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||
})
|
||||
|
||||
t.Run("vertex express mode image variation response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/variations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestBody := `{"model":"gemini-2.0-flash-exp","image_url":{"url":"` + testDataURL + `"}}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
responseBody := `{
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
}
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 8,
|
||||
"candidatesTokenCount": 768,
|
||||
"totalTokenCount": 776
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage field")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== Vertex Raw 模式测试 ====================
|
||||
|
||||
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user