diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 78566c9de..9d221a3e5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -64,6 +64,10 @@ Azure OpenAI 所对应的 `type` 为 `azure`。它特有的配置字段如下: 通义千问所对应的 `type` 为 `qwen`。它并无特有的配置字段。 +#### 百川智能 (Baichuan AI) + +百川智能所对应的 `type` 为 `baichuan` 。它并无特有的配置字段。 + #### 零一万物(Yi) 零一万物所对应的 `type` 为 `yi`。它并无特有的配置字段。 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go new file mode 100644 index 000000000..aecea6bc7 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -0,0 +1,87 @@ +package provider + +import ( + "fmt" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" +) + +// baichuanProvider is the provider for baichuan Ai service. + +const ( + baichuanDomain = "api.baichuan-ai.com" + baichuanChatCompletionPath = "/v1/chat/completions" +) + +type baichuanProviderInitializer struct { +} + +func (m *baichuanProviderInitializer) ValidateConfig(config ProviderConfig) error { + return nil +} + +func (m *baichuanProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &baichuanProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +type baichuanProvider struct { + config ProviderConfig + contextCache *contextCache +} + +func (m *baichuanProvider) GetProviderType() string { + return providerTypeBaichuan +} + +func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + _ = util.OverwriteRequestPath(baichuanChatCompletionPath) + _ = util.OverwriteRequestHost(baichuanDomain) + _ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken()) + + if m.contextCache == nil { + ctx.DontReadRequestBody() + } else { + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + } + + return types.ActionContinue, nil +} + +func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + if m.contextCache == nil { + return types.ActionContinue, nil + } + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return types.ActionContinue, err + } + err := m.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + insertContextMessage(request, content) + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index d8996541c..a408ace15 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -20,10 +20,11 @@ const ( providerTypeQwen = "qwen" providerTypeOpenAI = "openai" providerTypeGroq = "groq" + providerTypeBaichuan = "baichuan" providerTypeYi = "yi" - protocolOpenAI = "openai" - protocolOriginal = "original" + protocolOpenAI = "openai" + protocolOriginal = "original" roleSystem = "system" @@ -54,6 +55,7 @@ var ( providerTypeQwen: &qwenProviderInitializer{}, providerTypeOpenAI: &openaiProviderInitializer{}, providerTypeGroq: &groqProviderInitializer{}, + providerTypeBaichuan: &baichuanProviderInitializer{}, providerTypeYi: &yiProviderInitializer{}, } ) @@ -84,7 +86,7 @@ type ResponseBodyHandler interface { type ProviderConfig struct { // @Title zh-CN AI服务提供商 - // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure" + // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi" typ string `required:"true" yaml:"type" json:"type"` // @Title zh-CN API Tokens // @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。