From 179a233ad6da996985f0b8d02fa35ff5afa8ac30 Mon Sep 17 00:00:00 2001 From: johnlanni Date: Fri, 20 Mar 2026 00:05:17 +0800 Subject: [PATCH] refactor(ai-proxy): redesign streaming thinking promotion to buffer-and-flush Instead of promoting reasoning to content inline per-chunk (which would emit reasoning as content prematurely if real content arrives later), the streaming path now buffers reasoning content and strips it from chunks. On the last chunk, if no content was ever seen, the buffered reasoning is flushed as a single content chunk. Also moves tests into test/openai.go TestOpenAI suite and adds MockHttpContext for provider-level streaming tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- plugins/wasm-go/extensions/ai-proxy/main.go | 34 ++- .../wasm-go/extensions/ai-proxy/main_test.go | 229 +----------------- .../extensions/ai-proxy/provider/model.go | 56 ++++- .../extensions/ai-proxy/provider/provider.go | 1 + .../extensions/ai-proxy/test/mock_context.go | 50 ++++ .../extensions/ai-proxy/test/openai.go | 156 ++++++++++++ 6 files changed, 279 insertions(+), 247 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/mock_context.go diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 49a8ea50a..dc34346c9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -395,7 +395,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk) if err == nil && modifiedChunk != nil { if promoteThinking { - modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk) + modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk, isLastChunk) } // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk) @@ -441,7 +441,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin result := []byte(responseBuilder.String()) if promoteThinking { - result = promoteThinkingInStreamingChunk(ctx, result) + result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk) } // Convert to Claude format if needed @@ -475,7 +475,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin result := []byte(responseBuilder.String()) if promoteThinking { - result = promoteThinkingInStreamingChunk(ctx, result) + result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk) } // Convert to Claude format if needed @@ -568,9 +568,10 @@ func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]b return claudeChunk, nil } -// promoteThinkingInStreamingChunk processes SSE-formatted streaming data and promotes -// reasoning/thinking deltas to content deltas when no content has been seen. -func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte) []byte { +// promoteThinkingInStreamingChunk processes SSE-formatted streaming data, buffering +// reasoning deltas and stripping them from chunks. On the last chunk, if no content +// was ever seen, it appends a flush chunk that emits buffered reasoning as content. +func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte, isLastChunk bool) []byte { // SSE data contains lines like "data: {...}\n\n" // We need to find and process each data line lines := strings.Split(string(data), "\n") @@ -583,20 +584,31 @@ func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte) []byt if payload == "[DONE]" || payload == "" { continue } - promoted, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload)) + stripped, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload)) if err != nil { continue } - newLine := "data: " + string(promoted) + newLine := "data: " + string(stripped) if newLine != line { lines[i] = newLine modified = true } } - if !modified { - return data + + result := data + if modified { + result = []byte(strings.Join(lines, "\n")) } - return []byte(strings.Join(lines, "\n")) + + // On last chunk, flush buffered reasoning as content if no content was seen + if isLastChunk { + flushChunk := provider.PromoteStreamingThinkingFlush(ctx) + if flushChunk != nil { + result = append(flushChunk, result...) + } + } + + return result } // Helper function to convert OpenAI response body to Claude format diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 693f5d1c5..bd7a421f7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -1,12 +1,10 @@ package main import ( - "strings" "testing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test" - "github.com/higress-group/wasm-go/pkg/iface" ) func Test_getApiName(t *testing.T) { @@ -135,6 +133,8 @@ func TestOpenAI(t *testing.T) { test.RunOpenAIOnHttpResponseHeadersTests(t) test.RunOpenAIOnHttpResponseBodyTests(t) test.RunOpenAIOnStreamingResponseBodyTests(t) + test.RunOpenAIPromoteThinkingOnEmptyTests(t) + test.RunOpenAIPromoteThinkingOnEmptyStreamingTests(t) } func TestQwen(t *testing.T) { @@ -225,228 +225,3 @@ func TestConsumerAffinity(t *testing.T) { test.RunConsumerAffinityParseConfigTests(t) test.RunConsumerAffinityOnHttpRequestHeadersTests(t) } - -// mockHttpContext is a minimal mock for wrapper.HttpContext used in streaming tests. -type mockHttpContext struct { - contextMap map[string]interface{} -} - -func newMockHttpContext() *mockHttpContext { - return &mockHttpContext{contextMap: make(map[string]interface{})} -} - -func (m *mockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value } -func (m *mockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] } -func (m *mockHttpContext) GetBoolContext(key string, def bool) bool { return def } -func (m *mockHttpContext) GetStringContext(key, def string) string { return def } -func (m *mockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def } -func (m *mockHttpContext) Scheme() string { return "" } -func (m *mockHttpContext) Host() string { return "" } -func (m *mockHttpContext) Path() string { return "" } -func (m *mockHttpContext) Method() string { return "" } -func (m *mockHttpContext) GetUserAttribute(key string) interface{} { return nil } -func (m *mockHttpContext) SetUserAttribute(key string, value interface{}) {} -func (m *mockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {} -func (m *mockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil } -func (m *mockHttpContext) WriteUserAttributeToLog() error { return nil } -func (m *mockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil } -func (m *mockHttpContext) WriteUserAttributeToTrace() error { return nil } -func (m *mockHttpContext) DontReadRequestBody() {} -func (m *mockHttpContext) DontReadResponseBody() {} -func (m *mockHttpContext) BufferRequestBody() {} -func (m *mockHttpContext) BufferResponseBody() {} -func (m *mockHttpContext) NeedPauseStreamingResponse() {} -func (m *mockHttpContext) PushBuffer(buffer []byte) {} -func (m *mockHttpContext) PopBuffer() []byte { return nil } -func (m *mockHttpContext) BufferQueueSize() int { return 0 } -func (m *mockHttpContext) DisableReroute() {} -func (m *mockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {} -func (m *mockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {} -func (m *mockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error { - return nil -} -func (m *mockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 } -func (m *mockHttpContext) HasRequestBody() bool { return false } -func (m *mockHttpContext) HasResponseBody() bool { return false } -func (m *mockHttpContext) IsWebsocket() bool { return false } -func (m *mockHttpContext) IsBinaryRequestBody() bool { return false } -func (m *mockHttpContext) IsBinaryResponseBody() bool { return false } - -func TestPromoteThinkingOnEmptyResponse(t *testing.T) { - t.Run("promotes_reasoning_when_content_empty", func(t *testing.T) { - body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`) - result, err := provider.PromoteThinkingOnEmptyResponse(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - // content should now contain the reasoning text - if !contains(result, `"content":"这是思考内容"`) { - t.Errorf("expected reasoning promoted to content, got: %s", result) - } - // reasoning_content should be cleared - if contains(result, `"reasoning_content":"这是思考内容"`) { - t.Errorf("expected reasoning_content to be cleared, got: %s", result) - } - }) - - t.Run("promotes_reasoning_when_content_nil", func(t *testing.T) { - body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`) - result, err := provider.PromoteThinkingOnEmptyResponse(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !contains(result, `"content":"思考结果"`) { - t.Errorf("expected reasoning promoted to content, got: %s", result) - } - }) - - t.Run("no_change_when_content_present", func(t *testing.T) { - body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`) - result, err := provider.PromoteThinkingOnEmptyResponse(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - // Should return original body unchanged - if string(result) != string(body) { - t.Errorf("expected body unchanged, got: %s", result) - } - }) - - t.Run("no_change_when_no_reasoning", func(t *testing.T) { - body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`) - result, err := provider.PromoteThinkingOnEmptyResponse(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if string(result) != string(body) { - t.Errorf("expected body unchanged, got: %s", result) - } - }) - - t.Run("no_change_when_both_empty", func(t *testing.T) { - body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`) - result, err := provider.PromoteThinkingOnEmptyResponse(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if string(result) != string(body) { - t.Errorf("expected body unchanged, got: %s", result) - } - }) - - t.Run("invalid_json_returns_original", func(t *testing.T) { - body := []byte(`not json`) - result, err := provider.PromoteThinkingOnEmptyResponse(body) - if err == nil { - t.Fatal("expected error for invalid json") - } - if string(result) != string(body) { - t.Errorf("expected original body returned on error") - } - }) -} - -func TestPromoteStreamingThinkingOnEmptyChunk(t *testing.T) { - t.Run("promotes_reasoning_delta_when_no_content", func(t *testing.T) { - ctx := newMockHttpContext() - data := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":"思考中"}}]}`) - result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !contains(result, `"content":"思考中"`) { - t.Errorf("expected reasoning promoted to content delta, got: %s", result) - } - }) - - t.Run("no_promote_after_content_seen", func(t *testing.T) { - ctx := newMockHttpContext() - // First chunk: content delta - data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`) - _, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1) - - // Second chunk: reasoning only — should NOT be promoted - data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`) - result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - // Should return unchanged since content was already seen - if string(result) != string(data2) { - t.Errorf("expected no promotion after content seen, got: %s", result) - } - }) - - t.Run("promotes_reasoning_field_when_no_content", func(t *testing.T) { - ctx := newMockHttpContext() - data := []byte(`{"choices":[{"index":0,"delta":{"reasoning":"流式思考"}}]}`) - result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !contains(result, `"content":"流式思考"`) { - t.Errorf("expected reasoning promoted to content delta, got: %s", result) - } - }) - - t.Run("no_change_when_content_present_in_delta", func(t *testing.T) { - ctx := newMockHttpContext() - data := []byte(`{"choices":[{"index":0,"delta":{"content":"有内容","reasoning_content":"也有思考"}}]}`) - result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if string(result) != string(data) { - t.Errorf("expected no change when content present, got: %s", result) - } - }) - - t.Run("invalid_json_returns_original", func(t *testing.T) { - ctx := newMockHttpContext() - data := []byte(`not json`) - result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) - if err != nil { - t.Fatalf("unexpected error for invalid json: %v", err) - } - if string(result) != string(data) { - t.Errorf("expected original data returned") - } - }) -} - -func TestPromoteThinkingInStreamingChunk(t *testing.T) { - t.Run("promotes_in_sse_format", func(t *testing.T) { - ctx := newMockHttpContext() - chunk := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"思考\"}}]}\n\n") - result := promoteThinkingInStreamingChunk(ctx, chunk) - if !contains(result, `"content":"思考"`) { - t.Errorf("expected reasoning promoted in SSE chunk, got: %s", result) - } - }) - - t.Run("skips_done_marker", func(t *testing.T) { - ctx := newMockHttpContext() - chunk := []byte("data: [DONE]\n\n") - result := promoteThinkingInStreamingChunk(ctx, chunk) - if string(result) != string(chunk) { - t.Errorf("expected [DONE] unchanged, got: %s", result) - } - }) - - t.Run("handles_multiple_events", func(t *testing.T) { - ctx := newMockHttpContext() - chunk := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"第一段\"}}]}\n\ndata: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"第二段\"}}]}\n\n") - result := promoteThinkingInStreamingChunk(ctx, chunk) - if !contains(result, `"content":"第一段"`) { - t.Errorf("expected first reasoning promoted, got: %s", result) - } - if !contains(result, `"content":"第二段"`) { - t.Errorf("expected second reasoning promoted, got: %s", result) - } - }) -} - -// contains checks if s contains substr -func contains(b []byte, substr string) bool { - return strings.Contains(string(b), substr) -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index f687c4e65..8f951f543 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -273,9 +273,12 @@ func (r *chatCompletionResponse) promoteThinkingOnEmpty() { } } -// promoteStreamingThinkingOnEmpty promotes reasoning delta to content delta when no content -// has been seen in the stream so far. Uses context to track state across chunks. -// Returns true if a promotion was performed. +// promoteStreamingThinkingOnEmpty accumulates reasoning content during streaming. +// It strips reasoning from chunks and buffers it. When content is seen, it marks +// the stream as having content so no promotion will happen. +// Call PromoteStreamingThinkingFlush at the end of the stream to emit buffered +// reasoning as content if no content was ever seen. +// Returns true if the chunk was modified (reasoning stripped). func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage) bool { if msg == nil { return false @@ -290,12 +293,14 @@ func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage) return false } + // Buffer reasoning content and strip it from the chunk reasoning := msg.ReasoningContent if reasoning == "" { reasoning = msg.Reasoning } if reasoning != "" { - msg.Content = reasoning + buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string) + ctx.SetContext(ctxKeyBufferedReasoning, buffered+reasoning) msg.ReasoningContent = "" msg.Reasoning = "" return true @@ -736,25 +741,58 @@ func PromoteThinkingOnEmptyResponse(body []byte) ([]byte, error) { return json.Marshal(resp) } -// PromoteStreamingThinkingOnEmptyChunk promotes reasoning delta to content delta in a -// streaming SSE data payload when no content has been seen in the stream so far. +// PromoteStreamingThinkingOnEmptyChunk buffers reasoning deltas and strips them from +// the chunk during streaming. Call PromoteStreamingThinkingFlush on the last chunk +// to emit buffered reasoning as content if no real content was ever seen. func PromoteStreamingThinkingOnEmptyChunk(ctx wrapper.HttpContext, data []byte) ([]byte, error) { var resp chatCompletionResponse if err := json.Unmarshal(data, &resp); err != nil { return data, nil // not a valid chat completion chunk, skip } - promoted := false + modified := false for i := range resp.Choices { msg := resp.Choices[i].Delta if msg == nil { continue } if promoteStreamingThinkingOnEmpty(ctx, msg) { - promoted = true + modified = true } } - if !promoted { + if !modified { return data, nil } return json.Marshal(resp) } + +// PromoteStreamingThinkingFlush checks if the stream had no content and returns +// an SSE chunk that emits the buffered reasoning as content. Returns nil if +// content was already seen or no reasoning was buffered. +func PromoteStreamingThinkingFlush(ctx wrapper.HttpContext) []byte { + hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool) + if hasContentDelta { + return nil + } + buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string) + if buffered == "" { + return nil + } + // Build a minimal chat.completion.chunk with the buffered reasoning as content + resp := chatCompletionResponse{ + Object: objectChatCompletionChunk, + Choices: []chatCompletionChoice{ + { + Index: 0, + Delta: &chatMessage{ + Content: buffered, + }, + }, + }, + } + data, err := json.Marshal(resp) + if err != nil { + return nil + } + // Format as SSE + return []byte("data: " + string(data) + "\n\n") +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 5c8f63ffa..54e9203e4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -179,6 +179,7 @@ const ( ctxKeyContentPushed = "contentPushed" ctxKeyReasoningContentPushed = "reasoningContentPushed" ctxKeyHasContentDelta = "hasContentDelta" + ctxKeyBufferedReasoning = "bufferedReasoning" objectChatCompletion = "chat.completion" objectChatCompletionChunk = "chat.completion.chunk" diff --git a/plugins/wasm-go/extensions/ai-proxy/test/mock_context.go b/plugins/wasm-go/extensions/ai-proxy/test/mock_context.go new file mode 100644 index 000000000..8d48af7e3 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/mock_context.go @@ -0,0 +1,50 @@ +package test + +import "github.com/higress-group/wasm-go/pkg/iface" + +// MockHttpContext is a minimal mock for wrapper.HttpContext used in unit tests +// that call provider functions directly (e.g. streaming thinking promotion). +type MockHttpContext struct { + contextMap map[string]interface{} +} + +func NewMockHttpContext() *MockHttpContext { + return &MockHttpContext{contextMap: make(map[string]interface{})} +} + +func (m *MockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value } +func (m *MockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] } +func (m *MockHttpContext) GetBoolContext(key string, def bool) bool { return def } +func (m *MockHttpContext) GetStringContext(key, def string) string { return def } +func (m *MockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def } +func (m *MockHttpContext) Scheme() string { return "" } +func (m *MockHttpContext) Host() string { return "" } +func (m *MockHttpContext) Path() string { return "" } +func (m *MockHttpContext) Method() string { return "" } +func (m *MockHttpContext) GetUserAttribute(key string) interface{} { return nil } +func (m *MockHttpContext) SetUserAttribute(key string, value interface{}) {} +func (m *MockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {} +func (m *MockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil } +func (m *MockHttpContext) WriteUserAttributeToLog() error { return nil } +func (m *MockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil } +func (m *MockHttpContext) WriteUserAttributeToTrace() error { return nil } +func (m *MockHttpContext) DontReadRequestBody() {} +func (m *MockHttpContext) DontReadResponseBody() {} +func (m *MockHttpContext) BufferRequestBody() {} +func (m *MockHttpContext) BufferResponseBody() {} +func (m *MockHttpContext) NeedPauseStreamingResponse() {} +func (m *MockHttpContext) PushBuffer(buffer []byte) {} +func (m *MockHttpContext) PopBuffer() []byte { return nil } +func (m *MockHttpContext) BufferQueueSize() int { return 0 } +func (m *MockHttpContext) DisableReroute() {} +func (m *MockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {} +func (m *MockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {} +func (m *MockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error { + return nil +} +func (m *MockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 } +func (m *MockHttpContext) HasRequestBody() bool { return false } +func (m *MockHttpContext) HasResponseBody() bool { return false } +func (m *MockHttpContext) IsWebsocket() bool { return false } +func (m *MockHttpContext) IsBinaryRequestBody() bool { return false } +func (m *MockHttpContext) IsBinaryResponseBody() bool { return false } diff --git a/plugins/wasm-go/extensions/ai-proxy/test/openai.go b/plugins/wasm-go/extensions/ai-proxy/test/openai.go index 2f72fabb0..9c5d0562f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/openai.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/test" "github.com/stretchr/testify/require" @@ -997,3 +998,158 @@ func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) { }) }) } + +// 测试配置:OpenAI配置 + promoteThinkingOnEmpty +var openAIPromoteThinkingConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-openai-test123456789"}, + "promoteThinkingOnEmpty": true, + }, + }) + return data +}() + +// 测试配置:OpenAI配置 + hiclawMode +var openAIHiclawModeConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-openai-test123456789"}, + "hiclawMode": true, + }, + }) + return data +}() + +func RunOpenAIPromoteThinkingOnEmptyTests(t *testing.T) { + // Config parsing tests via host framework + test.RunGoTest(t, func(t *testing.T) { + t.Run("promoteThinkingOnEmpty config parses", func(t *testing.T) { + host, status := test.NewTestHost(openAIPromoteThinkingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + t.Run("hiclawMode config parses", func(t *testing.T) { + host, status := test.NewTestHost(openAIHiclawModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) + + // Non-streaming promote logic tests via provider functions directly + t.Run("promotes reasoning_content when content is empty string", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + require.NoError(t, err) + require.Contains(t, string(result), `"content":"这是思考内容"`) + require.NotContains(t, string(result), `"reasoning_content":"这是思考内容"`) + }) + + t.Run("promotes reasoning_content when content is nil", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + require.NoError(t, err) + require.Contains(t, string(result), `"content":"思考结果"`) + }) + + t.Run("no promotion when content is present", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + require.NoError(t, err) + require.Equal(t, string(body), string(result)) + }) + + t.Run("no promotion when no reasoning", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + require.NoError(t, err) + require.Equal(t, string(body), string(result)) + }) + + t.Run("no promotion when both empty", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + require.NoError(t, err) + require.Equal(t, string(body), string(result)) + }) + + t.Run("invalid json returns error", func(t *testing.T) { + body := []byte(`not json`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + require.Error(t, err) + require.Equal(t, string(body), string(result)) + }) +} + +func RunOpenAIPromoteThinkingOnEmptyStreamingTests(t *testing.T) { + // Streaming tests use provider functions directly since the test framework + // does not expose GetStreamingResponseBody. + t.Run("streaming: buffers reasoning and flushes on end when no content", func(t *testing.T) { + ctx := NewMockHttpContext() + // Chunk with only reasoning_content + data := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"流式思考"}}]}`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) + require.NoError(t, err) + // Reasoning should be stripped (not promoted inline) + require.NotContains(t, string(result), `"content":"流式思考"`) + + // Flush should emit buffered reasoning as content + flush := provider.PromoteStreamingThinkingFlush(ctx) + require.NotNil(t, flush) + require.Contains(t, string(flush), `"content":"流式思考"`) + }) + + t.Run("streaming: no flush when content was seen", func(t *testing.T) { + ctx := NewMockHttpContext() + // First chunk: content delta + data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`) + _, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1) + + // Second chunk: reasoning only + data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2) + require.NoError(t, err) + // Should be unchanged since content was already seen + require.Equal(t, string(data2), string(result)) + + // Flush should return nil since content was seen + flush := provider.PromoteStreamingThinkingFlush(ctx) + require.Nil(t, flush) + }) + + t.Run("streaming: accumulates multiple reasoning chunks", func(t *testing.T) { + ctx := NewMockHttpContext() + data1 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第一段"}}]}`) + _, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1) + + data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第二段"}}]}`) + _, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2) + + flush := provider.PromoteStreamingThinkingFlush(ctx) + require.NotNil(t, flush) + require.Contains(t, string(flush), `"content":"第一段第二段"`) + }) + + t.Run("streaming: no flush when no reasoning buffered", func(t *testing.T) { + ctx := NewMockHttpContext() + flush := provider.PromoteStreamingThinkingFlush(ctx) + require.Nil(t, flush) + }) + + t.Run("streaming: invalid json returns original", func(t *testing.T) { + ctx := NewMockHttpContext() + data := []byte(`not json`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) + require.NoError(t, err) + require.Equal(t, string(data), string(result)) + }) +}