diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md
index 96f0e4b84..a54a404a1 100644
--- a/plugins/wasm-go/extensions/ai-proxy/README.md
+++ b/plugins/wasm-go/extensions/ai-proxy/README.md
@@ -21,7 +21,7 @@ description: AI 代理插件配置参考
| Name | Data Type | Requirement | Default | Description |
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
-| `type` | string | Required | - | Name of the AI service provider. Currently supported values: openai, azure, moonshot, qwen, zhipuai |
+| `type` | string | Required | - | Name of the AI service provider. Currently supported values: openai, azure, moonshot, qwen, zhipuai, baidu |
| `apiTokens` | array of string | Required | - | Tokens used for authentication when accessing the AI service. If multiple tokens are configured, the plugin picks one at random for each request. Some providers only support a single token. |
| `timeout` | number | Optional | - | Timeout for accessing the AI service, in milliseconds. Defaults to 120000, i.e. 2 minutes. |
| `modelMapping` | map of string | Optional | - | Model mapping table, used to map model names in requests to model names supported by the provider. A key of "*" can be used to configure a catch-all fallback mapping. |
@@ -89,6 +89,10 @@ DeepSeek所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
The `type` corresponding to Groq is `groq`. It has no provider-specific configuration fields.
+#### ERNIE Bot (Baidu)
+
+The `type` corresponding to Baidu ERNIE Bot (Wenxin Yiyan) is `baidu`. It has no provider-specific configuration fields.
+
#### Anthropic Claude
The `type` corresponding to Anthropic Claude is `claude`. Its provider-specific configuration fields are as follows:
@@ -621,6 +625,61 @@ curl --location 'http:///v1/chat/completions' \
}
```
+### Using the OpenAI Protocol to Proxy Baidu ERNIE Bot Services
+
+**Configuration**
+
+```yaml
+provider:
+  type: baidu
+  apiTokens:
+    - "YOUR_BAIDU_API_TOKEN"
+  modelMapping:
+    'gpt-3': "ERNIE-4.0-8K"
+    '*': "ERNIE-4.0-8K"
+```
+
+**Request Example**
+
+```json
+{
+ "model": "gpt-4-turbo",
+ "messages": [
+ {
+ "role": "user",
+ "content": "你好,你是谁?"
+ }
+ ],
+ "stream": false
+}
+```
+
+**Response Example**
+
+```json
+{
+ "id": "as-e90yfg1pk1",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "你好,我是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。"
+ },
+ "finish_reason": "stop"
+ }
+ ],
+ "created": 1717251488,
+ "model": "ERNIE-4.0",
+ "object": "chat.completion",
+ "usage": {
+ "prompt_tokens": 4,
+ "completion_tokens": 33,
+ "total_tokens": 37
+ }
+}
+```
+
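+If you prefer to expose Baidu's native API instead of the OpenAI-style one, the provider also handles the original protocol. Below is a minimal, illustrative sketch that assumes the shared `protocol: original` setting applies to `baidu` as it does to other providers; in this mode the request body follows Baidu's own schema and its `model` field selects the endpoint path.
+
+```yaml
+provider:
+  type: baidu
+  apiTokens:
+    - "YOUR_BAIDU_API_TOKEN"
+  protocol: original
+```
+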
## Complete Configuration Example
The following shows a complete plugin configuration example, using the OpenAI protocol to proxy the Groq service.
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
new file mode 100644
index 000000000..c2a91b5e8
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
@@ -0,0 +1,340 @@
+package provider
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+ "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+ "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+)
+
+// baiduProvider is the provider for the Baidu ERNIE Bot (Wenxin Yiyan) service.
+
+const (
+ baiduDomain = "aip.baidubce.com"
+)
+
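+// baiduModelToPathSuffixMap maps ERNIE model names to the path suffixes of the
+// corresponding Wenxin Workshop chat completion endpoints (see GetRequestPath).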
+var baiduModelToPathSuffixMap = map[string]string{
+ "ERNIE-4.0-8K": "completions_pro",
+ "ERNIE-3.5-8K": "completions",
+ "ERNIE-3.5-128K": "ernie-3.5-128k",
+ "ERNIE-Speed-8K": "ernie_speed",
+ "ERNIE-Speed-128K": "ernie-speed-128k",
+ "ERNIE-Tiny-8K": "ernie-tiny-8k",
+ "ERNIE-Bot-8K": "ernie_bot_8k",
+ "BLOOMZ-7B": "bloomz_7b1",
+}
+
+type baiduProviderInitializer struct {
+}
+
+func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
+ return nil
+}
+
+func (b *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
+ return &baiduProvider{
+ config: config,
+ contextCache: createContextCache(&config),
+ }, nil
+}
+
+type baiduProvider struct {
+ config ProviderConfig
+ contextCache *contextCache
+}
+
+func (b *baiduProvider) GetProviderType() string {
+ return providerTypeBaidu
+}
+
+func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+ if apiName != ApiNameChatCompletion {
+ return types.ActionContinue, errUnsupportedApiName
+ }
+ _ = util.OverwriteRequestHost(baiduDomain)
+
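+ // The API token is sent as the access_token query parameter built in GetRequestPath,
+ // so no Authorization header is added here.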
+ _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
+ _ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+
+ // Delay the header processing to allow changing streaming mode in OnRequestBody
+ return types.HeaderStopIteration, nil
+}
+
+func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
+ if apiName != ApiNameChatCompletion {
+ return types.ActionContinue, errUnsupportedApiName
+ }
+ // Using Baidu's native ERNIE Bot API protocol
+ if b.config.protocol == protocolOriginal {
+ request := &baiduTextGenRequest{}
+ if err := json.Unmarshal(body, request); err != nil {
+ return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
+ }
+ if request.Model == "" {
+ return types.ActionContinue, errors.New("request model is empty")
+ }
+ // Rewrite the request path according to the model
+ path := b.GetRequestPath(request.Model)
+ _ = util.OverwriteRequestPath(path)
+
+ if b.config.context == nil {
+ return types.ActionContinue, nil
+ }
+
+ err := b.contextCache.GetContent(func(content string, err error) {
+ defer func() {
+ _ = proxywasm.ResumeHttpRequest()
+ }()
+
+ if err != nil {
+ log.Errorf("failed to load context file: %v", err)
+ _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
+ }
+ b.setSystemContent(request, content)
+ if err := replaceJsonRequestBody(request, log); err != nil {
+ _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
+ }
+ }, log)
+ if err == nil {
+ return types.ActionPause, nil
+ }
+ return types.ActionContinue, err
+ }
+ request := &chatCompletionRequest{}
+ if err := decodeChatCompletionRequest(body, request); err != nil {
+ return types.ActionContinue, err
+ }
+
+ // Map the model and rewrite the request path accordingly
+ model := request.Model
+ if model == "" {
+ return types.ActionContinue, errors.New("missing model in chat completion request")
+ }
+ ctx.SetContext(ctxKeyOriginalRequestModel, model)
+ mappedModel := getMappedModel(model, b.config.modelMapping, log)
+ if mappedModel == "" {
+ return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
+ }
+ request.Model = mappedModel
+ ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
+ path := b.GetRequestPath(mappedModel)
+ _ = util.OverwriteRequestPath(path)
+
+ if b.config.context == nil {
+ baiduRequest := b.baiduTextGenRequest(request)
+ return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log)
+ }
+
+ err := b.contextCache.GetContent(func(content string, err error) {
+ defer func() {
+ _ = proxywasm.ResumeHttpRequest()
+ }()
+ if err != nil {
+ log.Errorf("failed to load context file: %v", err)
+ _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
+ }
+ insertContextMessage(request, content)
+ baiduRequest := b.baiduTextGenRequest(request)
+ if err := replaceJsonRequestBody(baiduRequest, log); err != nil {
+ _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
+ }
+ }, log)
+ if err == nil {
+ return types.ActionPause, nil
+ }
+ return types.ActionContinue, err
+}
+
+func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+ // When using Baidu's native API protocol, skip OnStreamingResponseBody() and OnResponseBody()
+ if b.config.protocol == protocolOriginal {
+ ctx.DontReadResponseBody()
+ return types.ActionContinue, nil
+ }
+
+ _ = proxywasm.RemoveHttpResponseHeader("Content-Length")
+ return types.ActionContinue, nil
+}
+
+func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
+ if isLastChunk || len(chunk) == 0 {
+ return nil, nil
+ }
+ // sample event response:
+ // data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}
+
+ // sample end event response:
+ // data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}}
+ responseBuilder := &strings.Builder{}
+ lines := strings.Split(string(chunk), "\n")
+ for _, data := range lines {
+ if len(data) < 6 {
+ // ignore blank line or wrong format
+ continue
+ }
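+ // Strip the leading "data: " prefix (6 bytes) of the SSE event line.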
+ data = data[6:]
+ var baiduResponse baiduTextGenStreamResponse
+ if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil {
+ log.Errorf("unable to unmarshal baidu response: %v", err)
+ continue
+ }
+ response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse)
+ responseBody, err := json.Marshal(response)
+ if err != nil {
+ log.Errorf("unable to marshal response: %v", err)
+ return nil, err
+ }
+ b.appendResponse(responseBuilder, string(responseBody))
+ }
+ modifiedResponseChunk := responseBuilder.String()
+ log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
+ return []byte(modifiedResponseChunk), nil
+}
+
+func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
+ baiduResponse := &baiduTextGenResponse{}
+ if err := json.Unmarshal(body, baiduResponse); err != nil {
+ return types.ActionContinue, fmt.Errorf("unable to unmarshal baidu response: %v", err)
+ }
+ if baiduResponse.ErrorMsg != "" {
+ return types.ActionContinue, fmt.Errorf("baidu response error, error_code: %d, error_message: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg)
+ }
+ response := b.responseBaidu2OpenAI(ctx, baiduResponse)
+ return types.ActionContinue, replaceJsonResponseBody(response, log)
+}
+
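+// baiduTextGenRequest mirrors the request schema of Baidu's ERNIE Bot chat API;
+// the Model field is used by the plugin to select the endpoint path.
+// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t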
+type baiduTextGenRequest struct {
+ Model string `json:"model"`
+ Messages []chatMessage `json:"messages"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ PenaltyScore float64 `json:"penalty_score,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System string `json:"system,omitempty"`
+ DisableSearch bool `json:"disable_search,omitempty"`
+ EnableCitation bool `json:"enable_citation,omitempty"`
+ MaxOutputTokens int `json:"max_output_tokens,omitempty"`
+ UserId string `json:"user_id,omitempty"`
+}
+
+func (b *baiduProvider) GetRequestPath(baiduModel string) string {
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
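+ // Models without a known suffix mapping fall back to using the model name itself as
+ // the path suffix; the selected API token is passed via the access_token query parameter.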
+ suffix, ok := baiduModelToPathSuffixMap[baiduModel]
+ if !ok {
+ suffix = baiduModel
+ }
+ return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
+}
+
+func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
+ request.System = content
+}
+
+func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *baiduTextGenRequest {
+ baiduRequest := baiduTextGenRequest{
+ Messages: make([]chatMessage, 0, len(request.Messages)),
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ PenaltyScore: request.FrequencyPenalty,
+ Stream: request.Stream,
+ DisableSearch: false,
+ EnableCitation: false,
+ MaxOutputTokens: request.MaxTokens,
+ UserId: request.User,
+ }
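+ // Baidu carries the system prompt in a dedicated `system` field, so OpenAI-style
+ // system messages are lifted out of the messages array.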
+ for _, message := range request.Messages {
+ if message.Role == roleSystem {
+ baiduRequest.System = message.Content
+ } else {
+ baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{
+ Role: message.Role,
+ Content: message.Content,
+ })
+ }
+ }
+ return &baiduRequest
+}
+
+type baiduTextGenResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Result string `json:"result"`
+ IsTruncated bool `json:"is_truncated"`
+ NeedClearHistory bool `json:"need_clear_history"`
+ Usage baiduTextGenResponseUsage `json:"usage"`
+ baiduTextGenResponseError
+}
+
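+// baiduTextGenResponseError is embedded in baiduTextGenResponse so that error payloads
+// (error_code / error_msg) can be detected on the same unmarshal target.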
+type baiduTextGenResponseError struct {
+ ErrorCode int `json:"error_code"`
+ ErrorMsg string `json:"error_msg"`
+}
+
+type baiduTextGenStreamResponse struct {
+ baiduTextGenResponse
+ SentenceId int `json:"sentence_id"`
+ IsEnd bool `json:"is_end"`
+}
+
+type baiduTextGenResponseUsage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenResponse) *chatCompletionResponse {
+ choice := chatCompletionChoice{
+ Index: 0,
+ Message: &chatMessage{Role: roleAssistant, Content: response.Result},
+ FinishReason: finishReasonStop,
+ }
+ return &chatCompletionResponse{
+ Id: response.Id,
+ Created: time.Now().Unix(),
+ Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
+ SystemFingerprint: "",
+ Object: objectChatCompletion,
+ Choices: []chatCompletionChoice{choice},
+ Usage: chatCompletionUsage{
+ PromptTokens: response.Usage.PromptTokens,
+ CompletionTokens: response.Usage.CompletionTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ },
+ }
+}
+
+func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse {
+ choice := chatCompletionChoice{
+ Index: 0,
+ Message: &chatMessage{Role: roleAssistant, Content: response.Result},
+ }
+ if response.IsEnd {
+ choice.FinishReason = finishReasonStop
+ }
+ return &chatCompletionResponse{
+ Id: response.Id,
+ Created: time.Now().Unix(),
+ Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
+ SystemFingerprint: "",
+ Object: objectChatCompletion,
+ Choices: []chatCompletionChoice{choice},
+ Usage: chatCompletionUsage{
+ PromptTokens: response.Usage.PromptTokens,
+ CompletionTokens: response.Usage.CompletionTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ },
+ }
+}
+
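+// appendResponse re-frames a converted chunk as an SSE item: the streamDataItemKey
+// prefix, the JSON payload, and a trailing blank line.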
+func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
+ responseBuilder.WriteString(streamDataItemKey)
+ responseBuilder.WriteString(responseBody)
+ responseBuilder.WriteString("\n\n")
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
index 5ac17fee0..d5ef4a6e8 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -25,6 +25,7 @@ const (
providerTypeDeepSeek = "deepseek"
providerTypeZhipuAi = "zhipuai"
providerTypeOllama = "ollama"
+ providerTypeBaidu = "baidu"
providerTypeHunyuan = "hunyuan"
protocolOpenAI = "openai"
@@ -34,6 +35,8 @@ const (
roleUser = "user"
roleAssistant = "assistant"
+ finishReasonStop = "stop"
+
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiName = "apiKey"
ctxKeyStreamingBody = "streamingBody"
@@ -68,6 +71,7 @@ var (
providerTypeDeepSeek: &deepseekProviderInitializer{},
providerTypeZhipuAi: &zhipuAiProviderInitializer{},
providerTypeOllama: &ollamaProviderInitializer{},
+ providerTypeBaidu: &baiduProviderInitializer{},
providerTypeHunyuan: &hunyuanProviderInitializer{},
}
)
@@ -98,7 +102,7 @@ type ResponseBodyHandler interface {
type ProviderConfig struct {
// @Title zh-CN AI服务提供商
- // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"zhipuai"、"ollama"
+ // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"zhipuai"、"ollama"、"baidu"
typ string `required:"true" yaml:"type" json:"type"`
// @Title zh-CN API Tokens
// @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。