fix: Fix the incorrect reasoning content concat logic in ai-proxy (#1842)

This commit is contained in:
Kent Dong
2025-03-07 10:33:45 +08:00
committed by GitHub
parent d4155411ee
commit ab419efda4
3 changed files with 70 additions and 13 deletions

View File

@@ -3,6 +3,9 @@ package provider
import (
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
const (
@@ -19,6 +22,9 @@ const (
contentTypeText = "text"
contentTypeImageUrl = "image_url"
reasoningStartTag = "<think>"
reasoningEndTag = "</think>"
)
type chatCompletionRequest struct {
@@ -136,7 +142,7 @@ type chatMessage struct {
Refusal string `json:"refusal,omitempty"`
}
func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
if m.ReasoningContent == "" {
return
}
@@ -145,7 +151,7 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
m.ReasoningContent = ""
break
case reasoningBehaviorConcat:
m.Content = fmt.Sprintf("%v\n%v", m.ReasoningContent, m.Content)
m.Content = fmt.Sprintf("%s%v%s\n%v", reasoningStartTag, m.ReasoningContent, reasoningEndTag, m.Content)
m.ReasoningContent = ""
break
case reasoningBehaviorPassThrough:
@@ -154,6 +160,46 @@ func (m *chatMessage) handleReasoningContent(reasoningContentMode string) {
}
}
func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, reasoningContentMode string) {
switch reasoningContentMode {
case reasoningBehaviorIgnore:
m.ReasoningContent = ""
break
case reasoningBehaviorConcat:
contentPushed, _ := ctx.GetContext(ctxKeyContentPushed).(bool)
reasoningContentPushed, _ := ctx.GetContext(ctxKeyReasoningContentPushed).(bool)
if contentPushed {
if m.ReasoningContent != "" {
// This shouldn't happen, but if it does, we can add a log here.
proxywasm.LogWarnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
}
return
}
if m.ReasoningContent != "" && !reasoningContentPushed {
m.ReasoningContent = reasoningStartTag + m.ReasoningContent
reasoningContentPushed = true
}
if m.Content != "" {
if reasoningContentPushed && !contentPushed /* Keep the second part just to make it easy to understand*/ {
m.ReasoningContent += reasoningEndTag
}
contentPushed = true
}
m.Content = fmt.Sprintf("%s\n%v", m.ReasoningContent, m.Content)
m.ReasoningContent = ""
ctx.SetContext(ctxKeyContentPushed, contentPushed)
ctx.SetContext(ctxKeyReasoningContentPushed, reasoningContentPushed)
break
case reasoningBehaviorPassThrough:
default:
break
}
}
type messageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`