diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 1e2e574aa..49a8ea50a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -385,6 +385,8 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin return chunk } + promoteThinking := pluginConfig.GetProviderConfig().GetPromoteThinkingOnEmpty() + log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType()) log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) @@ -392,6 +394,9 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk) if err == nil && modifiedChunk != nil { + if promoteThinking { + modifiedChunk = promoteThinkingInStreamingChunk(ctx, modifiedChunk) + } // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk) if convertErr != nil { @@ -435,6 +440,10 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin result := []byte(responseBuilder.String()) + if promoteThinking { + result = promoteThinkingInStreamingChunk(ctx, result) + } + // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result) if convertErr != nil { @@ -443,11 +452,12 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin return claudeChunk } - if !needsClaudeResponseConversion(ctx) { + if !needsClaudeResponseConversion(ctx) && !promoteThinking { return chunk } // If provider doesn't implement any streaming handlers but we need Claude conversion + // or thinking promotion // First extract complete events from the chunk events := provider.ExtractStreamingEvents(ctx, chunk) log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events)) @@ -464,6 +474,10 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin result := []byte(responseBuilder.String()) + if promoteThinking { + result = promoteThinkingInStreamingChunk(ctx, result) + } + // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result) if convertErr != nil { @@ -496,6 +510,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi finalBody = body } + // Promote thinking/reasoning to content when content is empty + if pluginConfig.GetProviderConfig().GetPromoteThinkingOnEmpty() { + promoted, err := provider.PromoteThinkingOnEmptyResponse(finalBody) + if err != nil { + log.Warnf("[promoteThinkingOnEmpty] failed: %v", err) + } else { + finalBody = promoted + } + } + // Convert to Claude format if needed (applies to both branches) convertedBody, err := convertResponseBodyToClaude(ctx, finalBody) if err != nil { @@ -544,6 +568,37 @@ func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]b return claudeChunk, nil } +// promoteThinkingInStreamingChunk processes SSE-formatted streaming data and promotes +// reasoning/thinking deltas to content deltas when no content has been seen. +func promoteThinkingInStreamingChunk(ctx wrapper.HttpContext, data []byte) []byte { + // SSE data contains lines like "data: {...}\n\n" + // We need to find and process each data line + lines := strings.Split(string(data), "\n") + modified := false + for i, line := range lines { + if !strings.HasPrefix(line, "data: ") { + continue + } + payload := strings.TrimPrefix(line, "data: ") + if payload == "[DONE]" || payload == "" { + continue + } + promoted, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, []byte(payload)) + if err != nil { + continue + } + newLine := "data: " + string(promoted) + if newLine != line { + lines[i] = newLine + modified = true + } + } + if !modified { + return data + } + return []byte(strings.Join(lines, "\n")) +} + // Helper function to convert OpenAI response body to Claude format func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) { if !needsClaudeResponseConversion(ctx) { diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index e3ef5842f..693f5d1c5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -1,10 +1,12 @@ package main import ( + "strings" "testing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test" + "github.com/higress-group/wasm-go/pkg/iface" ) func Test_getApiName(t *testing.T) { @@ -223,3 +225,228 @@ func TestConsumerAffinity(t *testing.T) { test.RunConsumerAffinityParseConfigTests(t) test.RunConsumerAffinityOnHttpRequestHeadersTests(t) } + +// mockHttpContext is a minimal mock for wrapper.HttpContext used in streaming tests. +type mockHttpContext struct { + contextMap map[string]interface{} +} + +func newMockHttpContext() *mockHttpContext { + return &mockHttpContext{contextMap: make(map[string]interface{})} +} + +func (m *mockHttpContext) SetContext(key string, value interface{}) { m.contextMap[key] = value } +func (m *mockHttpContext) GetContext(key string) interface{} { return m.contextMap[key] } +func (m *mockHttpContext) GetBoolContext(key string, def bool) bool { return def } +func (m *mockHttpContext) GetStringContext(key, def string) string { return def } +func (m *mockHttpContext) GetByteSliceContext(key string, def []byte) []byte { return def } +func (m *mockHttpContext) Scheme() string { return "" } +func (m *mockHttpContext) Host() string { return "" } +func (m *mockHttpContext) Path() string { return "" } +func (m *mockHttpContext) Method() string { return "" } +func (m *mockHttpContext) GetUserAttribute(key string) interface{} { return nil } +func (m *mockHttpContext) SetUserAttribute(key string, value interface{}) {} +func (m *mockHttpContext) SetUserAttributeMap(kvmap map[string]interface{}) {} +func (m *mockHttpContext) GetUserAttributeMap() map[string]interface{} { return nil } +func (m *mockHttpContext) WriteUserAttributeToLog() error { return nil } +func (m *mockHttpContext) WriteUserAttributeToLogWithKey(key string) error { return nil } +func (m *mockHttpContext) WriteUserAttributeToTrace() error { return nil } +func (m *mockHttpContext) DontReadRequestBody() {} +func (m *mockHttpContext) DontReadResponseBody() {} +func (m *mockHttpContext) BufferRequestBody() {} +func (m *mockHttpContext) BufferResponseBody() {} +func (m *mockHttpContext) NeedPauseStreamingResponse() {} +func (m *mockHttpContext) PushBuffer(buffer []byte) {} +func (m *mockHttpContext) PopBuffer() []byte { return nil } +func (m *mockHttpContext) BufferQueueSize() int { return 0 } +func (m *mockHttpContext) DisableReroute() {} +func (m *mockHttpContext) SetRequestBodyBufferLimit(byteSize uint32) {} +func (m *mockHttpContext) SetResponseBodyBufferLimit(byteSize uint32) {} +func (m *mockHttpContext) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error { + return nil +} +func (m *mockHttpContext) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 } +func (m *mockHttpContext) HasRequestBody() bool { return false } +func (m *mockHttpContext) HasResponseBody() bool { return false } +func (m *mockHttpContext) IsWebsocket() bool { return false } +func (m *mockHttpContext) IsBinaryRequestBody() bool { return false } +func (m *mockHttpContext) IsBinaryResponseBody() bool { return false } + +func TestPromoteThinkingOnEmptyResponse(t *testing.T) { + t.Run("promotes_reasoning_when_content_empty", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"","reasoning_content":"这是思考内容"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // content should now contain the reasoning text + if !contains(result, `"content":"这是思考内容"`) { + t.Errorf("expected reasoning promoted to content, got: %s", result) + } + // reasoning_content should be cleared + if contains(result, `"reasoning_content":"这是思考内容"`) { + t.Errorf("expected reasoning_content to be cleared, got: %s", result) + } + }) + + t.Run("promotes_reasoning_when_content_nil", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"思考结果"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !contains(result, `"content":"思考结果"`) { + t.Errorf("expected reasoning promoted to content, got: %s", result) + } + }) + + t.Run("no_change_when_content_present", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复","reasoning_content":"思考过程"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should return original body unchanged + if string(result) != string(body) { + t.Errorf("expected body unchanged, got: %s", result) + } + }) + + t.Run("no_change_when_no_reasoning", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"正常回复"},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(result) != string(body) { + t.Errorf("expected body unchanged, got: %s", result) + } + }) + + t.Run("no_change_when_both_empty", func(t *testing.T) { + body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}]}`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(result) != string(body) { + t.Errorf("expected body unchanged, got: %s", result) + } + }) + + t.Run("invalid_json_returns_original", func(t *testing.T) { + body := []byte(`not json`) + result, err := provider.PromoteThinkingOnEmptyResponse(body) + if err == nil { + t.Fatal("expected error for invalid json") + } + if string(result) != string(body) { + t.Errorf("expected original body returned on error") + } + }) +} + +func TestPromoteStreamingThinkingOnEmptyChunk(t *testing.T) { + t.Run("promotes_reasoning_delta_when_no_content", func(t *testing.T) { + ctx := newMockHttpContext() + data := []byte(`{"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":"思考中"}}]}`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !contains(result, `"content":"思考中"`) { + t.Errorf("expected reasoning promoted to content delta, got: %s", result) + } + }) + + t.Run("no_promote_after_content_seen", func(t *testing.T) { + ctx := newMockHttpContext() + // First chunk: content delta + data1 := []byte(`{"choices":[{"index":0,"delta":{"content":"正文"}}]}`) + _, _ = provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data1) + + // Second chunk: reasoning only — should NOT be promoted + data2 := []byte(`{"choices":[{"index":0,"delta":{"reasoning_content":"后续思考"}}]}`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should return unchanged since content was already seen + if string(result) != string(data2) { + t.Errorf("expected no promotion after content seen, got: %s", result) + } + }) + + t.Run("promotes_reasoning_field_when_no_content", func(t *testing.T) { + ctx := newMockHttpContext() + data := []byte(`{"choices":[{"index":0,"delta":{"reasoning":"流式思考"}}]}`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !contains(result, `"content":"流式思考"`) { + t.Errorf("expected reasoning promoted to content delta, got: %s", result) + } + }) + + t.Run("no_change_when_content_present_in_delta", func(t *testing.T) { + ctx := newMockHttpContext() + data := []byte(`{"choices":[{"index":0,"delta":{"content":"有内容","reasoning_content":"也有思考"}}]}`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(result) != string(data) { + t.Errorf("expected no change when content present, got: %s", result) + } + }) + + t.Run("invalid_json_returns_original", func(t *testing.T) { + ctx := newMockHttpContext() + data := []byte(`not json`) + result, err := provider.PromoteStreamingThinkingOnEmptyChunk(ctx, data) + if err != nil { + t.Fatalf("unexpected error for invalid json: %v", err) + } + if string(result) != string(data) { + t.Errorf("expected original data returned") + } + }) +} + +func TestPromoteThinkingInStreamingChunk(t *testing.T) { + t.Run("promotes_in_sse_format", func(t *testing.T) { + ctx := newMockHttpContext() + chunk := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"思考\"}}]}\n\n") + result := promoteThinkingInStreamingChunk(ctx, chunk) + if !contains(result, `"content":"思考"`) { + t.Errorf("expected reasoning promoted in SSE chunk, got: %s", result) + } + }) + + t.Run("skips_done_marker", func(t *testing.T) { + ctx := newMockHttpContext() + chunk := []byte("data: [DONE]\n\n") + result := promoteThinkingInStreamingChunk(ctx, chunk) + if string(result) != string(chunk) { + t.Errorf("expected [DONE] unchanged, got: %s", result) + } + }) + + t.Run("handles_multiple_events", func(t *testing.T) { + ctx := newMockHttpContext() + chunk := []byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"第一段\"}}]}\n\ndata: {\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"第二段\"}}]}\n\n") + result := promoteThinkingInStreamingChunk(ctx, chunk) + if !contains(result, `"content":"第一段"`) { + t.Errorf("expected first reasoning promoted, got: %s", result) + } + if !contains(result, `"content":"第二段"`) { + t.Errorf("expected second reasoning promoted, got: %s", result) + } + }) +} + +// contains checks if s contains substr +func contains(b []byte, substr string) bool { + return strings.Contains(string(b), substr) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 881e4cbc4..f687c4e65 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -255,6 +255,65 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r } } +// promoteThinkingOnEmpty promotes reasoning_content to content when content is empty. +// This handles models that put user-facing replies into thinking blocks instead of text blocks. +func (r *chatCompletionResponse) promoteThinkingOnEmpty() { + for i := range r.Choices { + msg := r.Choices[i].Message + if msg == nil { + continue + } + if !isContentEmpty(msg.Content) { + continue + } + if msg.ReasoningContent != "" { + msg.Content = msg.ReasoningContent + msg.ReasoningContent = "" + } + } +} + +// promoteStreamingThinkingOnEmpty promotes reasoning delta to content delta when no content +// has been seen in the stream so far. Uses context to track state across chunks. +// Returns true if a promotion was performed. +func promoteStreamingThinkingOnEmpty(ctx wrapper.HttpContext, msg *chatMessage) bool { + if msg == nil { + return false + } + hasContentDelta, _ := ctx.GetContext(ctxKeyHasContentDelta).(bool) + if hasContentDelta { + return false + } + + if !isContentEmpty(msg.Content) { + ctx.SetContext(ctxKeyHasContentDelta, true) + return false + } + + reasoning := msg.ReasoningContent + if reasoning == "" { + reasoning = msg.Reasoning + } + if reasoning != "" { + msg.Content = reasoning + msg.ReasoningContent = "" + msg.Reasoning = "" + return true + } + return false +} + +func isContentEmpty(content any) bool { + switch v := content.(type) { + case nil: + return true + case string: + return strings.TrimSpace(v) == "" + default: + return false + } +} + type chatMessageContent struct { CacheControl map[string]interface{} `json:"cache_control,omitempty"` Type string `json:"type,omitempty"` @@ -648,3 +707,54 @@ func (r embeddingsRequest) ParseInput() []string { } return input } + +// PromoteThinkingOnEmptyResponse promotes reasoning_content to content in a non-streaming +// response body when content is empty. Returns the original body if no promotion is needed. +func PromoteThinkingOnEmptyResponse(body []byte) ([]byte, error) { + var resp chatCompletionResponse + if err := json.Unmarshal(body, &resp); err != nil { + return body, fmt.Errorf("unable to unmarshal response for thinking promotion: %v", err) + } + promoted := false + for i := range resp.Choices { + msg := resp.Choices[i].Message + if msg == nil { + continue + } + if !isContentEmpty(msg.Content) { + continue + } + if msg.ReasoningContent != "" { + msg.Content = msg.ReasoningContent + msg.ReasoningContent = "" + promoted = true + } + } + if !promoted { + return body, nil + } + return json.Marshal(resp) +} + +// PromoteStreamingThinkingOnEmptyChunk promotes reasoning delta to content delta in a +// streaming SSE data payload when no content has been seen in the stream so far. +func PromoteStreamingThinkingOnEmptyChunk(ctx wrapper.HttpContext, data []byte) ([]byte, error) { + var resp chatCompletionResponse + if err := json.Unmarshal(data, &resp); err != nil { + return data, nil // not a valid chat completion chunk, skip + } + promoted := false + for i := range resp.Choices { + msg := resp.Choices[i].Delta + if msg == nil { + continue + } + if promoteStreamingThinkingOnEmpty(ctx, msg) { + promoted = true + } + } + if !promoted { + return data, nil + } + return json.Marshal(resp) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index fec47e148..5c8f63ffa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -178,6 +178,7 @@ const ( ctxKeyPushedMessage = "pushedMessage" ctxKeyContentPushed = "contentPushed" ctxKeyReasoningContentPushed = "reasoningContentPushed" + ctxKeyHasContentDelta = "hasContentDelta" objectChatCompletion = "chat.completion" objectChatCompletionChunk = "chat.completion.chunk" @@ -474,6 +475,12 @@ type ProviderConfig struct { // @Title zh-CN 合并连续同角色消息 // @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。 mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"` + // @Title zh-CN 空内容时提升思考为正文 + // @Description zh-CN 开启后,若模型响应只包含 reasoning_content/thinking 而没有正文内容,将 reasoning 内容提升为正文内容返回,避免客户端收到空回复。 + promoteThinkingOnEmpty bool `required:"false" yaml:"promoteThinkingOnEmpty" json:"promoteThinkingOnEmpty"` + // @Title zh-CN HiClaw 模式 + // @Description zh-CN 开启后同时启用 mergeConsecutiveMessages 和 promoteThinkingOnEmpty,适用于 HiClaw 多 Agent 协作场景。 + hiclawMode bool `required:"false" yaml:"hiclawMode" json:"hiclawMode"` } func (c *ProviderConfig) GetId() string { @@ -699,6 +706,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } } c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool() + c.promoteThinkingOnEmpty = json.Get("promoteThinkingOnEmpty").Bool() + c.hiclawMode = json.Get("hiclawMode").Bool() + if c.hiclawMode { + c.mergeConsecutiveMessages = true + c.promoteThinkingOnEmpty = true + } } func (c *ProviderConfig) Validate() error { @@ -833,6 +846,10 @@ func (c *ProviderConfig) IsOriginal() bool { return c.protocol == protocolOriginal } +func (c *ProviderConfig) GetPromoteThinkingOnEmpty() bool { + return c.promoteThinkingOnEmpty +} + func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) { return ReplaceByCustomSettings(body, c.customSettings) }