mirror of
https://github.com/alibaba/higress.git
synced 2026-03-08 10:40:48 +08:00
fix: Fix the incorrect reasoning content concat logic in ai-proxy (#1842)
This commit is contained in:
@@ -3,6 +3,9 @@ package provider
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,6 +22,9 @@ const (
|
||||
|
||||
contentTypeText = "text"
|
||||
contentTypeImageUrl = "image_url"
|
||||
|
||||
reasoningStartTag = "<think>"
|
||||
reasoningEndTag = "</think>"
|
||||
)
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
@@ -136,7 +142,7 @@ type chatMessage struct {
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
}
|
||||
|
||||
func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
|
||||
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
|
||||
if m.ReasoningContent == "" {
|
||||
return
|
||||
}
|
||||
@@ -145,7 +151,7 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
|
||||
m.ReasoningContent = ""
|
||||
break
|
||||
case reasoningBehaviorConcat:
|
||||
m.Content = fmt.Sprintf("%v\n%v", m.ReasoningContent, m.Content)
|
||||
m.Content = fmt.Sprintf("%s%v%s\n%v", reasoningStartTag, m.ReasoningContent, reasoningEndTag, m.Content)
|
||||
m.ReasoningContent = ""
|
||||
break
|
||||
case reasoningBehaviorPassThrough:
|
||||
@@ -154,6 +160,46 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, reasoningContentMode string) {
|
||||
switch reasoningContentMode {
|
||||
case reasoningBehaviorIgnore:
|
||||
m.ReasoningContent = ""
|
||||
break
|
||||
case reasoningBehaviorConcat:
|
||||
contentPushed, _ := ctx.GetContext(ctxKeyContentPushed).(bool)
|
||||
reasoningContentPushed, _ := ctx.GetContext(ctxKeyReasoningContentPushed).(bool)
|
||||
|
||||
if contentPushed {
|
||||
if m.ReasoningContent != "" {
|
||||
// This shouldn't happen, but if it does, we can add a log here.
|
||||
proxywasm.LogWarnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if m.ReasoningContent != "" && !reasoningContentPushed {
|
||||
m.ReasoningContent = reasoningStartTag + m.ReasoningContent
|
||||
reasoningContentPushed = true
|
||||
}
|
||||
if m.Content != "" {
|
||||
if reasoningContentPushed && !contentPushed /* Keep the second part just to make it easy to understand*/ {
|
||||
m.ReasoningContent += reasoningEndTag
|
||||
}
|
||||
contentPushed = true
|
||||
}
|
||||
|
||||
m.Content = fmt.Sprintf("%s\n%v", m.ReasoningContent, m.Content)
|
||||
m.ReasoningContent = ""
|
||||
|
||||
ctx.SetContext(ctxKeyContentPushed, contentPushed)
|
||||
ctx.SetContext(ctxKeyReasoningContentPushed, reasoningContentPushed)
|
||||
break
|
||||
case reasoningBehaviorPassThrough:
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
type messageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text"`
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
@@ -73,14 +74,16 @@ const (
|
||||
finishReasonStop = "stop"
|
||||
finishReasonLength = "length"
|
||||
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiKey = "apiKey"
|
||||
CtxKeyApiName = "apiName"
|
||||
ctxKeyIsStreaming = "isStreaming"
|
||||
ctxKeyStreamingBody = "streamingBody"
|
||||
ctxKeyOriginalRequestModel = "originalRequestModel"
|
||||
ctxKeyFinalRequestModel = "finalRequestModel"
|
||||
ctxKeyPushedMessage = "pushedMessage"
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiKey = "apiKey"
|
||||
CtxKeyApiName = "apiName"
|
||||
ctxKeyIsStreaming = "isStreaming"
|
||||
ctxKeyStreamingBody = "streamingBody"
|
||||
ctxKeyOriginalRequestModel = "originalRequestModel"
|
||||
ctxKeyFinalRequestModel = "finalRequestModel"
|
||||
ctxKeyPushedMessage = "pushedMessage"
|
||||
ctxKeyContentPushed = "contentPushed"
|
||||
ctxKeyReasoningContentPushed = "reasoningContentPushed"
|
||||
|
||||
objectChatCompletion = "chat.completion"
|
||||
objectChatCompletionChunk = "chat.completion.chunk"
|
||||
@@ -588,6 +591,8 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.L
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
body = append(bufferedStreamingBody, chunk...)
|
||||
}
|
||||
body = bytes.ReplaceAll(body, []byte("\r\n"), []byte("\n"))
|
||||
body = bytes.ReplaceAll(body, []byte("\r"), []byte("\n"))
|
||||
|
||||
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
|
||||
|
||||
|
||||
@@ -338,10 +338,14 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
finished := qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null"
|
||||
message := qwenChoice.Message
|
||||
|
||||
reasoningContentMode := m.config.reasoningContentMode
|
||||
|
||||
log.Warnf("incrementalStreaming: %v", incrementalStreaming)
|
||||
deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content, ReasoningContent: message.ReasoningContent}
|
||||
deltaContentMessage.handleReasoningContent(m.config.reasoningContentMode)
|
||||
deltaToolCallsMessage := &chatMessage{Role: message.Role, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
|
||||
if !incrementalStreaming {
|
||||
if incrementalStreaming {
|
||||
deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode)
|
||||
} else {
|
||||
for _, tc := range message.ToolCalls {
|
||||
if tc.Function.Arguments == "" && !finished {
|
||||
// We don't push any tool call until its arguments are available.
|
||||
@@ -379,6 +383,8 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
} else {
|
||||
deltaContentMessage.ReasoningContent = util.StripPrefix(deltaContentMessage.ReasoningContent, pushedMessage.ReasoningContent)
|
||||
}
|
||||
deltaContentMessage.handleStreamingReasoningContent(ctx, reasoningContentMode)
|
||||
|
||||
if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
|
||||
for i, tc := range deltaToolCallsMessage.ToolCalls {
|
||||
if i >= len(pushedMessage.ToolCalls) {
|
||||
@@ -614,7 +620,7 @@ func qwenMessageToChatMessage(qwenMessage qwenMessage, reasoningContentMode stri
|
||||
ReasoningContent: qwenMessage.ReasoningContent,
|
||||
ToolCalls: qwenMessage.ToolCalls,
|
||||
}
|
||||
msg.handleReasoningContent(reasoningContentMode)
|
||||
msg.handleNonStreamingReasoningContent(reasoningContentMode)
|
||||
return msg
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user