From 51b9d9ec4bc50192c33536f8a4eb07d5fa81c973 Mon Sep 17 00:00:00 2001 From: Yang Beining <35399433+Suchun-sv@users.noreply.github.com> Date: Tue, 28 May 2024 11:29:36 +0800 Subject: [PATCH] feat: Add the ZhipuAI (ChatGLM) provider to the ai-proxy wasm plugin #950 (#1007) Co-authored-by: Kent Dong --- plugins/wasm-go/extensions/ai-proxy/README.md | 6 +- .../extensions/ai-proxy/provider/provider.go | 5 +- .../extensions/ai-proxy/provider/zhipuai.go | 84 +++++++++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index d8d30a83d..1c101b381 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -21,7 +21,7 @@ description: AI 代理插件配置参考 | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |----------------|-----------------|------|-----|----------------------------------------------------------------------------------| -| `type` | string | 必填 | - | AI 服务提供商名称。目前支持以下取值:openai, azure, moonshot, qwen | +| `type` | string | 必填 | - | AI 服务提供商名称。目前支持以下取值:openai, azure, moonshot, qwen, zhipuai | | `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | | `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | | `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
可以使用 "*" 为键来配置通用兜底映射关系 | @@ -77,6 +77,10 @@ Azure OpenAI 所对应的 `type` 为 `azure`。它特有的配置字段如下: 零一万物所对应的 `type` 为 `yi`。它并无特有的配置字段。 +#### 智谱AI(Zhipu AI) + +智谱AI所对应的 `type` 为 `zhipuai`。它并无特有的配置字段。 + #### DeepSeek(DeepSeek) DeepSeek所对应的 `type` 为 `deepseek`。它并无特有的配置字段。 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index ff396e28e..cd0a0aba4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -23,6 +23,7 @@ const ( providerTypeBaichuan = "baichuan" providerTypeYi = "yi" providerTypeDeepSeek = "deepseek" + providerTypeZhipuAi = "zhipuai" providerTypeOllama = "ollama" protocolOpenAI = "openai" @@ -61,6 +62,7 @@ var ( providerTypeBaichuan: &baichuanProviderInitializer{}, providerTypeYi: &yiProviderInitializer{}, providerTypeDeepSeek: &deepseekProviderInitializer{}, + providerTypeZhipuAi: &zhipuAiProviderInitializer{}, providerTypeOllama: &ollamaProviderInitializer{}, } ) @@ -91,7 +93,7 @@ type ResponseBodyHandler interface { type ProviderConfig struct { // @Title zh-CN AI服务提供商 - // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"ollama" + // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"zhipuai"、"ollama" typ string `required:"true" yaml:"type" json:"type"` // @Title zh-CN API Tokens // @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。 @@ -180,6 +182,7 @@ func (c *ProviderConfig) Validate() error { if c.typ == "" { return errors.New("missing type in provider config") + } initializer, has := providerInitializers[c.typ] if !has { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go new file mode 100644 index 000000000..ae6434057 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ 
-0,0 +1,109 @@
+package provider
+
+import (
+	"fmt"
+
+	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+)
+
+// Endpoint of the ZhipuAI chat completion API that requests are rewritten to.
+const (
+	zhipuAiDomain             = "open.bigmodel.cn"
+	zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions"
+)
+
+// zhipuAiProviderInitializer validates provider configs and creates
+// zhipuAiProvider instances.
+type zhipuAiProviderInitializer struct{}
+
+// ValidateConfig always returns nil: no ZhipuAI-specific validation is done.
+func (m *zhipuAiProviderInitializer) ValidateConfig(config ProviderConfig) error {
+	return nil
+}
+
+// CreateProvider returns a zhipuAiProvider that keeps the given config together
+// with the context cache createContextCache builds from it. The returned error
+// is always nil.
+func (m *zhipuAiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
+	return &zhipuAiProvider{
+		config:       config,
+		contextCache: createContextCache(&config),
+	}, nil
+}
+
+// zhipuAiProvider rewrites chat completion requests so that they target the
+// ZhipuAI API.
+type zhipuAiProvider struct {
+	config       ProviderConfig
+	contextCache *contextCache // when nil, request bodies are not read or modified
+}
+
+// GetProviderType returns providerTypeZhipuAi.
+func (m *zhipuAiProvider) GetProviderType() string {
+	return providerTypeZhipuAi
+}
+
+// OnRequestHeaders points the request at the ZhipuAI chat completion endpoint
+// and sets the Authorization header to a bearer token taken from the config.
+// Any API other than chat completion is rejected with errUnsupportedApiName.
+func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+	if apiName != ApiNameChatCompletion {
+		return types.ActionContinue, errUnsupportedApiName
+	}
+	// Errors from the header rewrites below are ignored.
+	_ = util.OverwriteRequestPath(zhipuAiChatCompletionPath)
+	_ = util.OverwriteRequestHost(zhipuAiDomain)
+	_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken())
+
+	if m.contextCache == nil {
+		// OnRequestBody would pass the body through anyway, so skip reading it.
+		ctx.DontReadRequestBody()
+	} else {
+		// OnRequestBody replaces the body, so the original length is dropped.
+		_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+	}
+
+	return types.ActionContinue, nil
+}
+
+// OnRequestBody inserts the cached context content into a chat completion
+// request. Without a context cache the body is passed through unchanged.
+// Otherwise the body is decoded and ActionPause is returned while the content
+// is loaded; the callback rewrites the body and resumes the request. If the
+// content cannot be loaded, a 500 response is sent and the body is left
+// untouched.
+func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
+	if apiName != ApiNameChatCompletion {
+		return types.ActionContinue, errUnsupportedApiName
+	}
+	if m.contextCache == nil {
+		return types.ActionContinue, nil
+	}
+	request := &chatCompletionRequest{}
+	if err := decodeChatCompletionRequest(body, request); err != nil {
+		return types.ActionContinue, err
+	}
+	err := m.contextCache.GetContent(func(content string, err error) {
+		defer func() {
+			_ = proxywasm.ResumeHttpRequest()
+		}()
+		if err != nil {
+			log.Errorf("failed to load context file: %v", err)
+			_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
+			// The error response has been sent; do not go on to rewrite the
+			// body with the content of a failed load.
+			return
+		}
+		insertContextMessage(request, content)
+		if err := replaceJsonRequestBody(request, log); err != nil {
+			_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
+		}
+	}, log)
+	if err == nil {
+		return types.ActionPause, nil
+	}
+	return types.ActionContinue, err
+}