From 2db0b60a98f982d834922eb3ff667246049c4274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E8=B4=A4=E6=B6=9B?= <601803023@qq.com> Date: Thu, 6 Jun 2024 18:19:55 +0800 Subject: [PATCH] feat: support baidu ernie bot ai model (#1024) Co-authored-by: Kent Dong --- plugins/wasm-go/extensions/ai-proxy/README.md | 61 +++- .../extensions/ai-proxy/provider/baidu.go | 340 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 6 +- 3 files changed, 405 insertions(+), 2 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/baidu.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 96f0e4b84..a54a404a1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -21,7 +21,7 @@ description: AI 代理插件配置参考 | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |----------------|-----------------|------|-----|----------------------------------------------------------------------------------| -| `type` | string | 必填 | - | AI 服务提供商名称。目前支持以下取值:openai, azure, moonshot, qwen, zhipuai | +| `type` | string | 必填 | - | AI 服务提供商名称。目前支持以下取值:openai, azure, moonshot, qwen, zhipuai, baidu | | `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | | `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | | `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
可以使用 "*" 为键来配置通用兜底映射关系 | @@ -89,6 +89,10 @@ DeepSeek所对应的 `type` 为 `deepseek`。它并无特有的配置字段。 Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。 +#### 文心一言(Baidu) + +文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。 + #### Anthropic Claude Anthropic Claude 所对应的 `type` 为 `claude`。它特有的配置字段如下: @@ -621,6 +625,61 @@ curl --location 'http:///v1/chat/completions' \ } ``` +### 使用 OpenAI 协议代理百度文心一言服务 + +**配置信息** + +```yaml +provider: + type: baidu + apiTokens: + - "YOUR_BAIDU_API_TOKEN" + modelMapping: + 'gpt-3': "ERNIE-4.0" + '*': "ERNIE-4.0" +``` + +**请求示例** + +```json +{ + "model": "gpt-4-turbo", + "messages": [ + { + "role": "user", + "content": "你好,你是谁?" + } + ], + "stream": false +} +``` + +**响应示例** + +```json +{ + "id": "as-e90yfg1pk1", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "你好,我是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。" + }, + "finish_reason": "stop" + } + ], + "created": 1717251488, + "model": "ERNIE-4.0", + "object": "chat.completion", + "usage": { + "prompt_tokens": 4, + "completion_tokens": 33, + "total_tokens": 37 + } +} +``` + ## 完整配置示例 以下以使用 OpenAI 协议代理 Groq 服务为例,展示完整的插件配置示例。 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go new file mode 100644 index 000000000..c2a91b5e8 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -0,0 +1,340 @@ +package provider + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" +) + +// baiduProvider is the provider for baidu ernie bot service. + +const ( + baiduDomain = "aip.baidubce.com" +) + +var baiduModelToPathSuffixMap = map[string]string{ + "ERNIE-4.0-8K": "completions_pro", + "ERNIE-3.5-8K": "completions", + "ERNIE-3.5-128K": "ernie-3.5-128k", + "ERNIE-Speed-8K": "ernie_speed", + "ERNIE-Speed-128K": "ernie-speed-128k", + "ERNIE-Tiny-8K": "ernie-tiny-8k", + "ERNIE-Bot-8K": "ernie_bot_8k", + "BLOOMZ-7B": "bloomz_7b1", +} + +type baiduProviderInitializer struct { +} + +func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error { + return nil +} + +func (b *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &baiduProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +type baiduProvider struct { + config ProviderConfig + contextCache *contextCache +} + +func (b *baiduProvider) GetProviderType() string { + return providerTypeBaidu +} + +func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + _ = util.OverwriteRequestHost(baiduDomain) + + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + + // Delay the header processing to allow changing streaming mode in OnRequestBody + return types.HeaderStopIteration, nil +} + +func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + // 使用文心一言接口协议 + if b.config.protocol == protocolOriginal { + request := &baiduTextGenRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + } + if request.Model == "" { + return types.ActionContinue, errors.New("request model is empty") + } + // 根据模型重写requestPath + path := b.GetRequestPath(request.Model) + _ = util.OverwriteRequestPath(path) + + if b.config.context == nil { + return types.ActionContinue, nil + } + + err := b.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + b.setSystemContent(request, content) + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return types.ActionContinue, err + } + + // 映射模型重写requestPath + model := request.Model + if model == "" { + return types.ActionContinue, errors.New("missing model in chat completion request") + } + ctx.SetContext(ctxKeyOriginalRequestModel, model) + mappedModel := getMappedModel(model, b.config.modelMapping, log) + if mappedModel == "" { + return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + } + request.Model = mappedModel + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + path := b.GetRequestPath(mappedModel) + _ = util.OverwriteRequestPath(path) + + if b.config.context == nil { + baiduRequest := b.baiduTextGenRequest(request) + return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log) + } + + err := b.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + insertContextMessage(request, content) + baiduRequest := b.baiduTextGenRequest(request) + if err := replaceJsonRequestBody(baiduRequest, log); err != nil { + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err +} + +func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + // 使用文心一言接口协议,跳过OnStreamingResponseBody()和OnResponseBody() + if b.config.protocol == protocolOriginal { + ctx.DontReadResponseBody() + return types.ActionContinue, nil + } + + _ = proxywasm.RemoveHttpResponseHeader("Content-Length") + return types.ActionContinue, nil +} + +func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + if isLastChunk || len(chunk) == 0 { + return nil, nil + } + // sample event response: + // data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}} + + // sample end event response: + // data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}} + responseBuilder := &strings.Builder{} + lines := strings.Split(string(chunk), "\n") + for _, data := range lines { + if len(data) < 6 { + // ignore blank line or wrong format + continue + } + data = data[6:] + var baiduResponse baiduTextGenStreamResponse + if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil { + log.Errorf("unable to unmarshal baidu response: %v", err) + continue + } + response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse) + responseBody, err := json.Marshal(response) + if err != nil { + log.Errorf("unable to marshal response: %v", err) + return nil, err + } + b.appendResponse(responseBuilder, string(responseBody)) + } + modifiedResponseChunk := responseBuilder.String() + log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) + return []byte(modifiedResponseChunk), nil +} + +func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + baiduResponse := &baiduTextGenResponse{} + if err := json.Unmarshal(body, baiduResponse); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal baidu response: %v", err) + } + if baiduResponse.ErrorMsg != "" { + return types.ActionContinue, fmt.Errorf("baidu response error, error_code: %d, error_message: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg) + } + response := b.responseBaidu2OpenAI(ctx, baiduResponse) + return types.ActionContinue, replaceJsonResponseBody(response, log) +} + +type baiduTextGenRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PenaltyScore float64 `json:"penalty_score,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + DisableSearch bool `json:"disable_search,omitempty"` + EnableCitation bool `json:"enable_citation,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + UserId string `json:"user_id,omitempty"` +} + +func (b *baiduProvider) GetRequestPath(baiduModel string) string { + // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + suffix, ok := baiduModelToPathSuffixMap[baiduModel] + if !ok { + suffix = baiduModel + } + return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken()) +} + +func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) { + request.System = content +} + +func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *baiduTextGenRequest { + baiduRequest := baiduTextGenRequest{ + Messages: make([]chatMessage, 0, len(request.Messages)), + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + DisableSearch: false, + EnableCitation: false, + MaxOutputTokens: request.MaxTokens, + UserId: request.User, + } + for _, message := range request.Messages { + if message.Role == roleSystem { + baiduRequest.System = message.Content + } else { + baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{ + Role: message.Role, + Content: message.Content, + }) + } + } + return &baiduRequest +} + +type baiduTextGenResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage baiduTextGenResponseUsage `json:"usage"` + baiduTextGenResponseError +} + +type baiduTextGenResponseError struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} + +type baiduTextGenStreamResponse struct { + baiduTextGenResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type baiduTextGenResponseUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenResponse) *chatCompletionResponse { + choice := chatCompletionChoice{ + Index: 0, + Message: &chatMessage{Role: roleAssistant, Content: response.Result}, + FinishReason: finishReasonStop, + } + return &chatCompletionResponse{ + Id: response.Id, + Created: time.Now().UnixMilli() / 1000, + Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + SystemFingerprint: "", + Object: objectChatCompletion, + Choices: []chatCompletionChoice{choice}, + Usage: chatCompletionUsage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } +} + +func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse { + choice := chatCompletionChoice{ + Index: 0, + Message: &chatMessage{Role: roleAssistant, Content: response.Result}, + } + if response.IsEnd { + choice.FinishReason = finishReasonStop + } + return &chatCompletionResponse{ + Id: response.Id, + Created: time.Now().UnixMilli() / 1000, + Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + SystemFingerprint: "", + Object: objectChatCompletion, + Choices: []chatCompletionChoice{choice}, + Usage: chatCompletionUsage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + }, + } +} + +func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { + responseBuilder.WriteString(streamDataItemKey) + responseBuilder.WriteString(responseBody) + responseBuilder.WriteString("\n\n") +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 5ac17fee0..d5ef4a6e8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -25,6 +25,7 @@ const ( providerTypeDeepSeek = "deepseek" providerTypeZhipuAi = "zhipuai" providerTypeOllama = "ollama" + providerTypeBaidu = "baidu" providerTypeHunyuan = "hunyuan" protocolOpenAI = "openai" @@ -34,6 +35,8 @@ const ( roleUser = "user" roleAssistant = "assistant" + finishReasonStop = "stop" + ctxKeyIncrementalStreaming = "incrementalStreaming" ctxKeyApiName = "apiKey" ctxKeyStreamingBody = "streamingBody" @@ -68,6 +71,7 @@ var ( providerTypeDeepSeek: &deepseekProviderInitializer{}, providerTypeZhipuAi: &zhipuAiProviderInitializer{}, providerTypeOllama: &ollamaProviderInitializer{}, + providerTypeBaidu: &baiduProviderInitializer{}, providerTypeHunyuan: &hunyuanProviderInitializer{}, } ) @@ -98,7 +102,7 @@ type ResponseBodyHandler interface { type ProviderConfig struct { // @Title zh-CN AI服务提供商 - // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"zhipuai"、"ollama" + // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"zhipuai"、"ollama"、"baidu" typ string `required:"true" yaml:"type" json:"type"` // @Title zh-CN API Tokens // @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。