From fabc22f21849600d575bcf6373509d58c911a5da Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Fri, 21 Feb 2025 17:32:02 +0800 Subject: [PATCH] feat: Support transforming reasoning_content returned by Qwen to OpenAI contract (#1791) --- plugins/wasm-go/extensions/ai-proxy/README.md | 25 ++++++------ .../extensions/ai-proxy/provider/model.go | 39 +++++++++++++++++-- .../extensions/ai-proxy/provider/provider.go | 21 ++++++++++ .../extensions/ai-proxy/provider/qwen.go | 25 ++++++++---- 4 files changed, 86 insertions(+), 24 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 8f281ffd2..cb685e6e0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -31,18 +31,19 @@ description: AI 代理插件配置参考 `provider`的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------| -| `type` | string | 必填 | - | AI 服务提供商名称 | -| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | -| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | -| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | -| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | -| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | -| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | -| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | -| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | -| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------| --------------- | -------- | ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `type` | string | 必填 | - | AI 服务提供商名称 | +| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | +| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | +| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | +| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | +| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | +| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | +| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | +| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | +| `reasoningContentMode` | string | 非必填 | - | 如何处理大模型服务返回的推理内容。目前支持以下取值:passthrough(正常输出推理内容)、ignore(不输出推理内容)、concat(将推理内容拼接在常规输出内容之前)。默认为 passthrough。仅支持通义千问服务。 | +| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank | `context`的配置字段说明如下: diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 726a18fca..b38b4fde8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -1,6 +1,9 @@ package provider -import "strings" +import ( + "fmt" + "strings" +) const ( streamEventIdItemKey = "id:" @@ -110,9 +113,16 @@ type chatCompletionChoice struct { } type usage struct { - PromptTokens int `json:"prompt_tokens,omitempty"` - CompletionTokens int `json:"completion_tokens,omitempty"` - TotalTokens int `json:"total_tokens,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` + CompletionTokensDetails *completionTokensDetails `json:"completion_tokens_details,omitempty"` +} + +type completionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` } type chatMessage struct { @@ -126,6 +136,24 @@ type chatMessage struct { Refusal string `json:"refusal,omitempty"` } +func (m *chatMessage) handleReasoningContent(reasoningContentMode string) { + if m.ReasoningContent == "" { + return + } + switch reasoningContentMode { + case reasoningBehaviorIgnore: + m.ReasoningContent = "" + break + case reasoningBehaviorConcat: + m.Content = fmt.Sprintf("%v\n%v", m.ReasoningContent, m.Content) + m.ReasoningContent = "" + break + case reasoningBehaviorPassThrough: + default: + break + } +} + type messageContent struct { Type string `json:"type,omitempty"` Text string `json:"text"` @@ -138,6 +166,9 @@ type imageUrl struct { } func (m *chatMessage) IsEmpty() bool { + if m.ReasoningContent != "" { + return false + } if m.IsStringContent() && m.Content != "" { return false } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 67cce2888..46e6b4ed7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -85,6 +85,10 @@ const ( objectChatCompletion = "chat.completion" objectChatCompletionChunk = "chat.completion.chunk" + reasoningBehaviorPassThrough = "passthrough" + reasoningBehaviorIgnore = "ignore" + reasoningBehaviorConcat = "concat" + wildcard = "*" defaultTimeout = 2 * 60 * 1000 // ms @@ -190,6 +194,9 @@ type ProviderConfig struct { // @Title zh-CN 失败请求重试 // @Description zh-CN 对失败的请求立即进行重试 retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"` + // @Title zh-CN 推理内容处理方式 + // @Description zh-CN 如何处理大模型服务返回的推理内容。目前支持以下取值:passthrough(正常输出推理内容)、ignore(不输出推理内容)、concat(将推理内容拼接在常规输出内容之前)。默认为 normal。仅支持通义千问服务。 + reasoningContentMode string `required:"false" yaml:"reasoningContentMode" json:"reasoningContentMode"` // @Title zh-CN 基于OpenAI协议的自定义后端URL // @Description zh-CN 仅适用于支持 openai 协议的服务。 openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"` @@ -359,6 +366,20 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } } + c.reasoningContentMode = json.Get("reasoningContentMode").String() + if c.reasoningContentMode == "" { + c.reasoningContentMode = reasoningBehaviorPassThrough + } else { + c.reasoningContentMode = strings.ToLower(c.reasoningContentMode) + switch c.reasoningContentMode { + case reasoningBehaviorPassThrough, reasoningBehaviorIgnore, reasoningBehaviorConcat: + break + default: + c.reasoningContentMode = reasoningBehaviorPassThrough + break + } + } + failoverJson := json.Get("failover") c.failover = &failover{ enabled: false, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index 2f757c683..fd55eee22 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -357,7 +357,7 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, o func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse { choices := make([]chatCompletionChoice, 0, len(qwenResponse.Output.Choices)) for _, qwenChoice := range qwenResponse.Output.Choices { - message := qwenMessageToChatMessage(qwenChoice.Message) + message := qwenMessageToChatMessage(qwenChoice.Message, m.config.reasoningContentMode) choices = append(choices, chatCompletionChoice{ Message: &message, FinishReason: qwenChoice.FinishReason, @@ -395,7 +395,8 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont finished := qwenChoice.FinishReason != "" && qwenChoice.FinishReason != "null" message := qwenChoice.Message - deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content} + deltaContentMessage := &chatMessage{Role: message.Role, Content: message.Content, ReasoningContent: message.ReasoningContent} + deltaContentMessage.handleReasoningContent(m.config.reasoningContentMode) deltaToolCallsMessage := &chatMessage{Role: message.Role, ToolCalls: append([]toolCall{}, message.ToolCalls...)} if !incrementalStreaming { for _, tc := range message.ToolCalls { @@ -430,6 +431,11 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont } } } + if message.ReasoningContent == "" { + message.ReasoningContent = pushedMessage.ReasoningContent + } else { + deltaContentMessage.ReasoningContent = util.StripPrefix(deltaContentMessage.ReasoningContent, pushedMessage.ReasoningContent) + } if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil { for i, tc := range deltaToolCallsMessage.ToolCalls { if i >= len(pushedMessage.ToolCalls) { @@ -690,13 +696,16 @@ type qwenTextEmbeddings struct { Embedding []float64 `json:"embedding"` } -func qwenMessageToChatMessage(qwenMessage qwenMessage) chatMessage { - return chatMessage{ - Name: qwenMessage.Name, - Role: qwenMessage.Role, - Content: qwenMessage.Content, - ToolCalls: qwenMessage.ToolCalls, +func qwenMessageToChatMessage(qwenMessage qwenMessage, reasoningContentMode string) chatMessage { + msg := chatMessage{ + Name: qwenMessage.Name, + Role: qwenMessage.Role, + Content: qwenMessage.Content, + ReasoningContent: qwenMessage.ReasoningContent, + ToolCalls: qwenMessage.ToolCalls, } + msg.handleReasoningContent(reasoningContentMode) + return msg } func (m *qwenMessage) IsStringContent() bool {