fix: Fix the incorrect reasoning content concat logic in ai-proxy (#1842)

This commit is contained in:
Kent Dong
2025-03-07 10:33:45 +08:00
committed by GitHub
parent d4155411ee
commit ab419efda4
3 changed files with 70 additions and 13 deletions

View File

@@ -3,6 +3,9 @@ package provider
import (
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
const (
@@ -19,6 +22,9 @@ const (
contentTypeText = "text"
contentTypeImageUrl = "image_url"
reasoningStartTag = "<think>"
reasoningEndTag = "</think>"
)
type chatCompletionRequest struct {
@@ -136,7 +142,7 @@ type chatMessage struct {
Refusal string `json:"refusal,omitempty"`
}
func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
if m.ReasoningContent == "" {
return
}
@@ -145,7 +151,7 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
m.ReasoningContent = ""
break
case reasoningBehaviorConcat:
m.Content = fmt.Sprintf("%v\n%v", m.ReasoningContent, m.Content)
m.Content = fmt.Sprintf("%s%v%s\n%v", reasoningStartTag, m.ReasoningContent, reasoningEndTag, m.Content)
m.ReasoningContent = ""
break
case reasoningBehaviorPassThrough:
@@ -154,6 +160,46 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
}
}
// handleStreamingReasoningContent merges the reasoning content of a streaming
// delta message into its regular content according to reasoningContentMode.
// State is tracked across chunks via two per-request context flags:
//   - ctxKeyReasoningContentPushed: the opening <think> tag has already been emitted;
//   - ctxKeyContentPushed: regular content has started streaming, so no further
//     reasoning content is expected for this response.
func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, reasoningContentMode string) {
switch reasoningContentMode {
case reasoningBehaviorIgnore:
// Drop the reasoning content entirely.
m.ReasoningContent = ""
break
case reasoningBehaviorConcat:
// Load the cross-chunk state; a missing context value defaults to false.
contentPushed, _ := ctx.GetContext(ctxKeyContentPushed).(bool)
reasoningContentPushed, _ := ctx.GetContext(ctxKeyReasoningContentPushed).(bool)
if contentPushed {
// Regular content has already started, so this chunk needs no rewriting.
if m.ReasoningContent != "" {
// This shouldn't happen: once regular content has started, reasoning
// content is not expected anymore. Log it so the anomaly is visible.
proxywasm.LogWarnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
}
return
}
// First chunk that carries reasoning content: open the <think> section.
if m.ReasoningContent != "" && !reasoningContentPushed {
m.ReasoningContent = reasoningStartTag + m.ReasoningContent
reasoningContentPushed = true
}
// First chunk that carries regular content: close the <think> section,
// but only if one was ever opened.
if m.Content != "" {
if reasoningContentPushed && !contentPushed /* Keep the second part just to make it easy to understand*/ {
m.ReasoningContent += reasoningEndTag
}
contentPushed = true
}
// Prepend the (possibly tagged) reasoning content to the regular content.
// NOTE(review): when a response has no reasoning content at all, the first
// content chunk still gains a leading "\n" here — confirm this is intended.
m.Content = fmt.Sprintf("%s\n%v", m.ReasoningContent, m.Content)
m.ReasoningContent = ""
// Persist the updated state for subsequent chunks of this stream.
ctx.SetContext(ctxKeyContentPushed, contentPushed)
ctx.SetContext(ctxKeyReasoningContentPushed, reasoningContentPushed)
break
case reasoningBehaviorPassThrough:
default:
// Pass-through (and any unknown mode): leave the message untouched.
break
}
}
type messageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`

View File

@@ -1,6 +1,7 @@
package provider
import (
"bytes"
"errors"
"math/rand"
"net/http"
@@ -73,14 +74,16 @@ const (
finishReasonStop = "stop"
finishReasonLength = "length"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiKey = "apiKey"
CtxKeyApiName = "apiName"
ctxKeyIsStreaming = "isStreaming"
ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel"
ctxKeyPushedMessage = "pushedMessage"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiKey = "apiKey"
CtxKeyApiName = "apiName"
ctxKeyIsStreaming = "isStreaming"
ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel"
ctxKeyPushedMessage = "pushedMessage"
ctxKeyContentPushed = "contentPushed"
ctxKeyReasoningContentPushed = "reasoningContentPushed"
objectChatCompletion = "chat.completion"
objectChatCompletionChunk = "chat.completion.chunk"
@@ -588,6 +591,8 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.L
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
body = append(bufferedStreamingBody, chunk...)
}
body = bytes.ReplaceAll(body, []byte("\r\n"), []byte("\n"))
body = bytes.ReplaceAll(body, []byte("\r"), []byte("\n"))
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1

View File

@@ -338,10 +338,14 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
finished := qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null"
message := qwenChoice.Message
reasoningContentMode := m.config.reasoningContentMode
log.Warnf("incrementalStreaming: %v", incrementalStreaming)
deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content, ReasoningContent: message.ReasoningContent}
deltaContentMessage.handleReasoningContent(m.config.reasoningContentMode)
deltaToolCallsMessage := &chatMessage{Role: message.Role, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
if !incrementalStreaming {
if incrementalStreaming {
deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode)
} else {
for _, tc := range message.ToolCalls {
if tc.Function.Arguments == "" && !finished {
// We don't push any tool call until its arguments are available.
@@ -379,6 +383,8 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
} else {
deltaContentMessage.ReasoningContent = util.StripPrefix(deltaContentMessage.ReasoningContent, pushedMessage.ReasoningContent)
}
deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode)
if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
for i, tc := range deltaToolCallsMessage.ToolCalls {
if i >= len(pushedMessage.ToolCalls) {
@@ -614,7 +620,7 @@ func qwenMessageToChatMessage(qwenMessage qwenMessage, reasoningContentMode stri
ReasoningContent: qwenMessage.ReasoningContent,
ToolCalls: qwenMessage.ToolCalls,
}
msg.handleReasoningContent(reasoningContentMode)
msg.handleNonStreamingReasoningContent(reasoningContentMode)
return msg
}