mirror of
https://github.com/alibaba/higress.git
synced 2026-03-20 11:08:03 +08:00
refactor(ai-proxy): redesign streaming thinking promotion to buffer-and-flush
Instead of promoting reasoning to content inline per-chunk (which would emit reasoning as content prematurely if real content arrives later), the streaming path now buffers reasoning content and strips it from chunks. On the last chunk, if no content was ever seen, the buffered reasoning is flushed as a single content chunk. Also moves tests into test/openai.go TestOpenAI suite and adds MockHttpContext for provider-level streaming tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -395,7 +395,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
|
||||
if err == nil && modifiedChunk != nil {
|
||||
if promoteThinking {
|
||||
modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk)
|
||||
modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk, isLastChunk)
|
||||
}
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk)
|
||||
@@ -441,7 +441,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
result := []byte(responseBuilder.String())
|
||||
|
||||
if promoteThinking {
|
||||
result = promoteThinkingInStreamingChunk(ctx, result)
|
||||
result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk)
|
||||
}
|
||||
|
||||
// Convert to Claude format if needed
|
||||
@@ -475,7 +475,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
result := []byte(responseBuilder.String())
|
||||
|
||||
if promoteThinking {
|
||||
result = promoteThinkingInStreamingChunk(ctx, result)
|
||||
result = promoteThinkingInStreamingChunk(ctx, result, isLastChunk)
|
||||
}
|
||||
|
||||
// Convert to Claude format if needed
|
||||
@@ -568,9 +568,10 @@ func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]b
|
||||
return claudeChunk, nil
|
||||
}
|
||||
|
||||
// promoteThinkingInStreamingChunk processes SSE-formatted streaming data and promotes
|
||||
// reasoning/thinking deltas to content deltas when no content has been seen.
|
||||
func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte) []byte {
|
||||
// promoteThinkingInStreamingChunk processes SSE-formatted streaming data, buffering
|
||||
// reasoning deltas and stripping them from chunks. On the last chunk, if no content
|
||||
// was ever seen, it appends a flush chunk that emits buffered reasoning as content.
|
||||
func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte, isLastChunk bool) []byte {
|
||||
// SSE data contains lines like "data: {...}\n\n"
|
||||
// We need to find and process each data line
|
||||
lines := strings.Split(string(data), "\n")
|
||||
@@ -583,20 +584,31 @@ func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte) []byt
|
||||
if payload == "[DONE]" || payload == "" {
|
||||
continue
|
||||
}
|
||||
promoted, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload))
|
||||
stripped, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
newLine := "data: " + string(promoted)
|
||||
newLine := "data: " + string(stripped)
|
||||
if newLine != line {
|
||||
lines[i] = newLine
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !modified {
|
||||
return data
|
||||
|
||||
result := data
|
||||
if modified {
|
||||
result = []byte(strings.Join(lines, "\n"))
|
||||
}
|
||||
return []byte(strings.Join(lines, "\n"))
|
||||
|
||||
// On last chunk, flush buffered reasoning as content if no content was seen
|
||||
if isLastChunk {
|
||||
flushChunk := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
if flushChunk != nil {
|
||||
result = append(flushChunk, result...)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper function to convert OpenAI response body to Claude format
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test"
|
||||
"github.com/higress-group/wasm-go/pkg/iface"
|
||||
)
|
||||
|
||||
func Test_getApiName(t *testing.T) {
|
||||
@@ -135,6 +133,8 @@ func TestOpenAI(t *testing.T) {
|
||||
test.RunOpenAIOnHttpResponseHeadersTests(t)
|
||||
test.RunOpenAIOnHttpResponseBodyTests(t)
|
||||
test.RunOpenAIOnStreamingResponseBodyTests(t)
|
||||
test.RunOpenAIPromoteThinkingOnEmptyTests(t)
|
||||
test.RunOpenAIPromoteThinkingOnEmptyStreamingTests(t)
|
||||
}
|
||||
|
||||
func TestQwen(t *testing.T) {
|
||||
@@ -225,228 +225,3 @@ func TestConsumerAffinity(t *testing.T) {
|
||||
test.RunConsumerAffinityParseConfigTests(t)
|
||||
test.RunConsumerAffinityOnHttpRequestHeadersTests(t)
|
||||
}
|
||||
|
||||
// mockHttpContext is a minimal mock for wrapper.HttpContext used in streaming tests.
|
||||
type mockHttpContext struct {
|
||||
contextMap map[string]interface{}
|
||||
}
|
||||
|
||||
func newMockHttpContext() *mockHttpContext {
|
||||
return &mockHttpContext{contextMap: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
func (m *mockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value }
|
||||
func (m *mockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] }
|
||||
func (m *mockHttpContext) GetBoolContext(key string, def bool) bool { return def }
|
||||
func (m *mockHttpContext) GetStringContext(key, def string) string { return def }
|
||||
func (m *mockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
|
||||
func (m *mockHttpContext) Scheme() string { return "" }
|
||||
func (m *mockHttpContext) Host() string { return "" }
|
||||
func (m *mockHttpContext) Path() string { return "" }
|
||||
func (m *mockHttpContext) Method() string { return "" }
|
||||
func (m *mockHttpContext) GetUserAttribute(key string) interface{} { return nil }
|
||||
func (m *mockHttpContext) SetUserAttribute(key string, value interface{}) {}
|
||||
func (m *mockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {}
|
||||
func (m *mockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil }
|
||||
func (m *mockHttpContext) WriteUserAttributeToLog() error { return nil }
|
||||
func (m *mockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil }
|
||||
func (m *mockHttpContext) WriteUserAttributeToTrace() error { return nil }
|
||||
func (m *mockHttpContext) DontReadRequestBody() {}
|
||||
func (m *mockHttpContext) DontReadResponseBody() {}
|
||||
func (m *mockHttpContext) BufferRequestBody() {}
|
||||
func (m *mockHttpContext) BufferResponseBody() {}
|
||||
func (m *mockHttpContext) NeedPauseStreamingResponse() {}
|
||||
func (m *mockHttpContext) PushBuffer(buffer []byte) {}
|
||||
func (m *mockHttpContext) PopBuffer() []byte { return nil }
|
||||
func (m *mockHttpContext) BufferQueueSize() int { return 0 }
|
||||
func (m *mockHttpContext) DisableReroute() {}
|
||||
func (m *mockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *mockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *mockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
|
||||
func (m *mockHttpContext) HasRequestBody() bool { return false }
|
||||
func (m *mockHttpContext) HasResponseBody() bool { return false }
|
||||
func (m *mockHttpContext) IsWebsocket() bool { return false }
|
||||
func (m *mockHttpContext) IsBinaryRequestBody() bool { return false }
|
||||
func (m *mockHttpContext) IsBinaryResponseBody() bool { return false }
|
||||
|
||||
func TestPromoteThinkingOnEmptyResponse(t *testing.T) {
|
||||
t.Run("promotes_reasoning_when_content_empty", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// content should now contain the reasoning text
|
||||
if !contains(result, `"content":"这是思考内容"`) {
|
||||
t.Errorf("expected reasoning promoted to content, got: %s", result)
|
||||
}
|
||||
// reasoning_content should be cleared
|
||||
if contains(result, `"reasoning_content":"这是思考内容"`) {
|
||||
t.Errorf("expected reasoning_content to be cleared, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("promotes_reasoning_when_content_nil", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !contains(result, `"content":"思考结果"`) {
|
||||
t.Errorf("expected reasoning promoted to content, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_change_when_content_present", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should return original body unchanged
|
||||
if string(result) != string(body) {
|
||||
t.Errorf("expected body unchanged, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_change_when_no_reasoning", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(result) != string(body) {
|
||||
t.Errorf("expected body unchanged, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_change_when_both_empty", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(result) != string(body) {
|
||||
t.Errorf("expected body unchanged, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_json_returns_original", func(t *testing.T) {
|
||||
body := []byte(`not json`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid json")
|
||||
}
|
||||
if string(result) != string(body) {
|
||||
t.Errorf("expected original body returned on error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPromoteStreamingThinkingOnEmptyChunk(t *testing.T) {
|
||||
t.Run("promotes_reasoning_delta_when_no_content", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
data := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":"思考中"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !contains(result, `"content":"思考中"`) {
|
||||
t.Errorf("expected reasoning promoted to content delta, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_promote_after_content_seen", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
// First chunk: content delta
|
||||
data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)
|
||||
|
||||
// Second chunk: reasoning only — should NOT be promoted
|
||||
data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should return unchanged since content was already seen
|
||||
if string(result) != string(data2) {
|
||||
t.Errorf("expected no promotion after content seen, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("promotes_reasoning_field_when_no_content", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
data := []byte(`{"choices":[{"index":0,"delta":{"reasoning":"流式思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !contains(result, `"content":"流式思考"`) {
|
||||
t.Errorf("expected reasoning promoted to content delta, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_change_when_content_present_in_delta", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
data := []byte(`{"choices":[{"index":0,"delta":{"content":"有内容","reasoning_content":"也有思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(result) != string(data) {
|
||||
t.Errorf("expected no change when content present, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_json_returns_original", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
data := []byte(`not json`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for invalid json: %v", err)
|
||||
}
|
||||
if string(result) != string(data) {
|
||||
t.Errorf("expected original data returned")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPromoteThinkingInStreamingChunk(t *testing.T) {
|
||||
t.Run("promotes_in_sse_format", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
chunk := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"思考\"}}]}\n\n")
|
||||
result := promoteThinkingInStreamingChunk(ctx, chunk)
|
||||
if !contains(result, `"content":"思考"`) {
|
||||
t.Errorf("expected reasoning promoted in SSE chunk, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips_done_marker", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
chunk := []byte("data: [DONE]\n\n")
|
||||
result := promoteThinkingInStreamingChunk(ctx, chunk)
|
||||
if string(result) != string(chunk) {
|
||||
t.Errorf("expected [DONE] unchanged, got: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles_multiple_events", func(t *testing.T) {
|
||||
ctx := newMockHttpContext()
|
||||
chunk := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"第一段\"}}]}\n\ndata: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"第二段\"}}]}\n\n")
|
||||
result := promoteThinkingInStreamingChunk(ctx, chunk)
|
||||
if !contains(result, `"content":"第一段"`) {
|
||||
t.Errorf("expected first reasoning promoted, got: %s", result)
|
||||
}
|
||||
if !contains(result, `"content":"第二段"`) {
|
||||
t.Errorf("expected second reasoning promoted, got: %s", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// contains checks if s contains substr
|
||||
func contains(b []byte, substr string) bool {
|
||||
return strings.Contains(string(b), substr)
|
||||
}
|
||||
|
||||
@@ -273,9 +273,12 @@ func (r *chatCompletionResponse) promoteThinkingOnEmpty() {
|
||||
}
|
||||
}
|
||||
|
||||
// promoteStreamingThinkingOnEmpty promotes reasoning delta to content delta when no content
|
||||
// has been seen in the stream so far. Uses context to track state across chunks.
|
||||
// Returns true if a promotion was performed.
|
||||
// promoteStreamingThinkingOnEmpty accumulates reasoning content during streaming.
|
||||
// It strips reasoning from chunks and buffers it. When content is seen, it marks
|
||||
// the stream as having content so no promotion will happen.
|
||||
// Call PromoteStreamingThinkingFlush at the end of the stream to emit buffered
|
||||
// reasoning as content if no content was ever seen.
|
||||
// Returns true if the chunk was modified (reasoning stripped).
|
||||
func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
@@ -290,12 +293,14 @@ func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage)
|
||||
return false
|
||||
}
|
||||
|
||||
// Buffer reasoning content and strip it from the chunk
|
||||
reasoning := msg.ReasoningContent
|
||||
if reasoning == "" {
|
||||
reasoning = msg.Reasoning
|
||||
}
|
||||
if reasoning != "" {
|
||||
msg.Content = reasoning
|
||||
buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string)
|
||||
ctx.SetContext(ctxKeyBufferedReasoning, buffered+reasoning)
|
||||
msg.ReasoningContent = ""
|
||||
msg.Reasoning = ""
|
||||
return true
|
||||
@@ -736,25 +741,58 @@ func PromoteThinkingOnEmptyResponse(body []byte) ([]byte, error) {
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
// PromoteStreamingThinkingOnEmptyChunk promotes reasoning delta to content delta in a
|
||||
// streaming SSE data payload when no content has been seen in the stream so far.
|
||||
// PromoteStreamingThinkingOnEmptyChunk buffers reasoning deltas and strips them from
|
||||
// the chunk during streaming. Call PromoteStreamingThinkingFlush on the last chunk
|
||||
// to emit buffered reasoning as content if no real content was ever seen.
|
||||
func PromoteStreamingThinkingOnEmptyChunk(ctx wrapper.HttpContext, data []byte) ([]byte, error) {
|
||||
var resp chatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return data, nil // not a valid chat completion chunk, skip
|
||||
}
|
||||
promoted := false
|
||||
modified := false
|
||||
for i := range resp.Choices {
|
||||
msg := resp.Choices[i].Delta
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if promoteStreamingThinkingOnEmpty(ctx, msg) {
|
||||
promoted = true
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if !promoted {
|
||||
if !modified {
|
||||
return data, nil
|
||||
}
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
// PromoteStreamingThinkingFlush checks if the stream had no content and returns
|
||||
// an SSE chunk that emits the buffered reasoning as content. Returns nil if
|
||||
// content was already seen or no reasoning was buffered.
|
||||
func PromoteStreamingThinkingFlush(ctx wrapper.HttpContext) []byte {
|
||||
hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool)
|
||||
if hasContentDelta {
|
||||
return nil
|
||||
}
|
||||
buffered, _ := ctx.GetContext(ctxKeyBufferedReasoning).(string)
|
||||
if buffered == "" {
|
||||
return nil
|
||||
}
|
||||
// Build a minimal chat.completion.chunk with the buffered reasoning as content
|
||||
resp := chatCompletionResponse{
|
||||
Object: objectChatCompletionChunk,
|
||||
Choices: []chatCompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: &chatMessage{
|
||||
Content: buffered,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// Format as SSE
|
||||
return []byte("data: " + string(data) + "\n\n")
|
||||
}
|
||||
|
||||
@@ -179,6 +179,7 @@ const (
|
||||
ctxKeyContentPushed = "contentPushed"
|
||||
ctxKeyReasoningContentPushed = "reasoningContentPushed"
|
||||
ctxKeyHasContentDelta = "hasContentDelta"
|
||||
ctxKeyBufferedReasoning = "bufferedReasoning"
|
||||
|
||||
objectChatCompletion = "chat.completion"
|
||||
objectChatCompletionChunk = "chat.completion.chunk"
|
||||
|
||||
50
plugins/wasm-go/extensions/ai-proxy/test/mock_context.go
Normal file
50
plugins/wasm-go/extensions/ai-proxy/test/mock_context.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package test
|
||||
|
||||
import "github.com/higress-group/wasm-go/pkg/iface"
|
||||
|
||||
// MockHttpContext is a minimal mock for wrapper.HttpContext used in unit tests
|
||||
// that call provider functions directly (e.g. streaming thinking promotion).
|
||||
type MockHttpContext struct {
|
||||
contextMap map[string]interface{}
|
||||
}
|
||||
|
||||
func NewMockHttpContext() *MockHttpContext {
|
||||
return &MockHttpContext{contextMap: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
func (m *MockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value }
|
||||
func (m *MockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] }
|
||||
func (m *MockHttpContext) GetBoolContext(key string, def bool) bool { return def }
|
||||
func (m *MockHttpContext) GetStringContext(key, def string) string { return def }
|
||||
func (m *MockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def }
|
||||
func (m *MockHttpContext) Scheme() string { return "" }
|
||||
func (m *MockHttpContext) Host() string { return "" }
|
||||
func (m *MockHttpContext) Path() string { return "" }
|
||||
func (m *MockHttpContext) Method() string { return "" }
|
||||
func (m *MockHttpContext) GetUserAttribute(key string) interface{} { return nil }
|
||||
func (m *MockHttpContext) SetUserAttribute(key string, value interface{}) {}
|
||||
func (m *MockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {}
|
||||
func (m *MockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil }
|
||||
func (m *MockHttpContext) WriteUserAttributeToLog() error { return nil }
|
||||
func (m *MockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil }
|
||||
func (m *MockHttpContext) WriteUserAttributeToTrace() error { return nil }
|
||||
func (m *MockHttpContext) DontReadRequestBody() {}
|
||||
func (m *MockHttpContext) DontReadResponseBody() {}
|
||||
func (m *MockHttpContext) BufferRequestBody() {}
|
||||
func (m *MockHttpContext) BufferResponseBody() {}
|
||||
func (m *MockHttpContext) NeedPauseStreamingResponse() {}
|
||||
func (m *MockHttpContext) PushBuffer(buffer []byte) {}
|
||||
func (m *MockHttpContext) PopBuffer() []byte { return nil }
|
||||
func (m *MockHttpContext) BufferQueueSize() int { return 0 }
|
||||
func (m *MockHttpContext) DisableReroute() {}
|
||||
func (m *MockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *MockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *MockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
|
||||
return nil
|
||||
}
|
||||
func (m *MockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
|
||||
func (m *MockHttpContext) HasRequestBody() bool { return false }
|
||||
func (m *MockHttpContext) HasResponseBody() bool { return false }
|
||||
func (m *MockHttpContext) IsWebsocket() bool { return false }
|
||||
func (m *MockHttpContext) IsBinaryRequestBody() bool { return false }
|
||||
func (m *MockHttpContext) IsBinaryResponseBody() bool { return false }
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -997,3 +998,158 @@ func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 测试配置:OpenAI配置 + promoteThinkingOnEmpty
|
||||
var openAIPromoteThinkingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-test123456789"},
|
||||
"promoteThinkingOnEmpty": true,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:OpenAI配置 + hiclawMode
|
||||
var openAIHiclawModeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-test123456789"},
|
||||
"hiclawMode": true,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunOpenAIPromoteThinkingOnEmptyTests(t *testing.T) {
|
||||
// Config parsing tests via host framework
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("promoteThinkingOnEmpty config parses", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAIPromoteThinkingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
t.Run("hiclawMode config parses", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAIHiclawModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
|
||||
// Non-streaming promote logic tests via provider functions directly
|
||||
t.Run("promotes reasoning_content when content is empty string", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(result), `"content":"这是思考内容"`)
|
||||
require.NotContains(t, string(result), `"reasoning_content":"这是思考内容"`)
|
||||
})
|
||||
|
||||
t.Run("promotes reasoning_content when content is nil", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(result), `"content":"思考结果"`)
|
||||
})
|
||||
|
||||
t.Run("no promotion when content is present", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
|
||||
t.Run("no promotion when no reasoning", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
|
||||
t.Run("no promotion when both empty", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
|
||||
t.Run("invalid json returns error", func(t *testing.T) {
|
||||
body := []byte(`not json`)
|
||||
result, err := provider.PromoteThinkingOnEmptyResponse(body)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, string(body), string(result))
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIPromoteThinkingOnEmptyStreamingTests(t *testing.T) {
|
||||
// Streaming tests use provider functions directly since the test framework
|
||||
// does not expose GetStreamingResponseBody.
|
||||
t.Run("streaming: buffers reasoning and flushes on end when no content", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
// Chunk with only reasoning_content
|
||||
data := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"流式思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
require.NoError(t, err)
|
||||
// Reasoning should be stripped (not promoted inline)
|
||||
require.NotContains(t, string(result), `"content":"流式思考"`)
|
||||
|
||||
// Flush should emit buffered reasoning as content
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.NotNil(t, flush)
|
||||
require.Contains(t, string(flush), `"content":"流式思考"`)
|
||||
})
|
||||
|
||||
t.Run("streaming: no flush when content was seen", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
// First chunk: content delta
|
||||
data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)
|
||||
|
||||
// Second chunk: reasoning only
|
||||
data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)
|
||||
require.NoError(t, err)
|
||||
// Should be unchanged since content was already seen
|
||||
require.Equal(t, string(data2), string(result))
|
||||
|
||||
// Flush should return nil since content was seen
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.Nil(t, flush)
|
||||
})
|
||||
|
||||
t.Run("streaming: accumulates multiple reasoning chunks", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
data1 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第一段"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1)
|
||||
|
||||
data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"第二段"}}]}`)
|
||||
_, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2)
|
||||
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.NotNil(t, flush)
|
||||
require.Contains(t, string(flush), `"content":"第一段第二段"`)
|
||||
})
|
||||
|
||||
t.Run("streaming: no flush when no reasoning buffered", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
flush := provider.PromoteStreamingThinkingFlush(ctx)
|
||||
require.Nil(t, flush)
|
||||
})
|
||||
|
||||
t.Run("streaming: invalid json returns original", func(t *testing.T) {
|
||||
ctx := NewMockHttpContext()
|
||||
data := []byte(`not json`)
|
||||
result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(data), string(result))
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user