diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 7de9cfe2f..6db58d126 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -3,6 +3,9 @@ package provider import ( "fmt" "strings" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ) const ( @@ -19,6 +22,9 @@ const ( contentTypeText = "text" contentTypeImageUrl = "image_url" + + reasoningStartTag = "" + reasoningEndTag = "" ) type chatCompletionRequest struct { @@ -136,7 +142,7 @@ type chatMessage struct { Refusal string `json:"refusal,omitempty"` } -func (m *chatMessage) handleReasoningContent(reasoningContentMode string) { +func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) { if m.ReasoningContent == "" { return } @@ -145,7 +151,7 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) { m.ReasoningContent = "" break case reasoningBehaviorConcat: - m.Content = fmt.Sprintf("%v\n%v", m.ReasoningContent, m.Content) + m.Content = fmt.Sprintf("%s%v%s\n%v", reasoningStartTag, m.ReasoningContent, reasoningEndTag, m.Content) m.ReasoningContent = "" break case reasoningBehaviorPassThrough: @@ -154,6 +160,46 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) { } } +func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, reasoningContentMode string) { + switch reasoningContentMode { + case reasoningBehaviorIgnore: + m.ReasoningContent = "" + break + case reasoningBehaviorConcat: + contentPushed, _ := ctx.GetContext(ctxKeyContentPushed).(bool) + reasoningContentPushed, _ := ctx.GetContext(ctxKeyReasoningContentPushed).(bool) + + if contentPushed { + if m.ReasoningContent != "" { + // This shouldn't happen, but if it does, we can add a log here. + proxywasm.LogWarnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m) + } + return + } + + if m.ReasoningContent != "" && !reasoningContentPushed { + m.ReasoningContent = reasoningStartTag + m.ReasoningContent + reasoningContentPushed = true + } + if m.Content != "" { + if reasoningContentPushed && !contentPushed /* Keep the second part just to make it easy to understand*/ { + m.ReasoningContent += reasoningEndTag + } + contentPushed = true + } + + m.Content = fmt.Sprintf("%s\n%v", m.ReasoningContent, m.Content) + m.ReasoningContent = "" + + ctx.SetContext(ctxKeyContentPushed, contentPushed) + ctx.SetContext(ctxKeyReasoningContentPushed, reasoningContentPushed) + break + case reasoningBehaviorPassThrough: + default: + break + } +} + type messageContent struct { Type string `json:"type,omitempty"` Text string `json:"text"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 6552f1da1..22ec01baa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1,6 +1,7 @@ package provider import ( + "bytes" "errors" "math/rand" "net/http" @@ -73,14 +74,16 @@ const ( finishReasonStop = "stop" finishReasonLength = "length" - ctxKeyIncrementalStreaming = "incrementalStreaming" - ctxKeyApiKey = "apiKey" - CtxKeyApiName = "apiName" - ctxKeyIsStreaming = "isStreaming" - ctxKeyStreamingBody = "streamingBody" - ctxKeyOriginalRequestModel = "originalRequestModel" - ctxKeyFinalRequestModel = "finalRequestModel" - ctxKeyPushedMessage = "pushedMessage" + ctxKeyIncrementalStreaming = "incrementalStreaming" + ctxKeyApiKey = "apiKey" + CtxKeyApiName = "apiName" + ctxKeyIsStreaming = "isStreaming" + ctxKeyStreamingBody = "streamingBody" + ctxKeyOriginalRequestModel = "originalRequestModel" + ctxKeyFinalRequestModel = "finalRequestModel" + ctxKeyPushedMessage = "pushedMessage" + ctxKeyContentPushed = "contentPushed" + ctxKeyReasoningContentPushed = "reasoningContentPushed" objectChatCompletion = "chat.completion" objectChatCompletionChunk = "chat.completion.chunk" @@ -588,6 +591,8 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.L if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { body = append(bufferedStreamingBody, chunk...) } + body = bytes.ReplaceAll(body, []byte("\r\n"), []byte("\n")) + body = bytes.ReplaceAll(body, []byte("\r"), []byte("\n")) eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 4bb39c121..a842cad63 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -338,10 +338,14 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont finished := qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null" message := qwenChoice.Message + reasoningContentMode := m.config.reasoningContentMode + + log.Warnf("incrementalStreaming: %v", incrementalStreaming) deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content, ReasoningContent: message.ReasoningContent} - deltaContentMessage.handleReasoningContent(m.config.reasoningContentMode) deltaToolCallsMessage := &chatMessage{Role: message.Role, ToolCalls: append([]toolCall{}, message.ToolCalls...)} - if !incrementalStreaming { + if incrementalStreaming { + deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode) + } else { for _, tc := range message.ToolCalls { if tc.Function.Arguments == "" && !finished { // We don't push any tool call until its arguments are available. @@ -379,6 +383,8 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont } else { deltaContentMessage.ReasoningContent = util.StripPrefix(deltaContentMessage.ReasoningContent, pushedMessage.ReasoningContent) } + deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode) + if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil { for i, tc := range deltaToolCallsMessage.ToolCalls { if i >= len(pushedMessage.ToolCalls) { @@ -614,7 +620,7 @@ func qwenMessageToChatMessage(qwenMessage qwenMessage, reasoningContentMode stri ReasoningContent: qwenMessage.ReasoningContent, ToolCalls: qwenMessage.ToolCalls, } - msg.handleReasoningContent(reasoningContentMode) + msg.handleNonStreamingReasoningContent(reasoningContentMode) return msg }