From b9f5c4d1f2092bfed613089d71f95dec8aa44b5a Mon Sep 17 00:00:00 2001 From: nohup Date: Mon, 8 Jul 2024 19:27:11 +0800 Subject: [PATCH] feat: support Cloudflare Workers AI (#1068) Co-authored-by: Kent Dong --- plugins/wasm-go/extensions/ai-proxy/README.md | 60 ++++++++++ .../ai-proxy/provider/cloudflare.go | 108 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 66 ++++++----- 3 files changed, 204 insertions(+), 30 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 80fd0a082..a21f16eff 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -131,6 +131,15 @@ Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下: 阶跃星辰所对应的 `type` 为 `stepfun`。它并无特有的配置字段。 +#### Cloudflare Workers AI + +Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置字段如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|-------------------|--------|------|-----|----------------------------------------------------------------------------------------------------------------------------| +| `cloudflareAccountId` | string | 必填 | - | [Cloudflare Account ID](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-api-token-and-account-id) | + + ## 用法示例 ### 使用 OpenAI 协议代理 Azure OpenAI 服务 @@ -758,6 +767,57 @@ provider: } ``` +### 使用 OpenAI 协议代理 Cloudflare Workers AI 服务 + +**配置信息** + +```yaml +provider: + type: cloudflare + apiTokens: + - "YOUR_WORKERS_AI_API_TOKEN" + cloudflareAccountId: "YOUR_CLOUDFLARE_ACCOUNT_ID" + modelMapping: + "*": "@cf/meta/llama-3-8b-instruct" +``` + +**请求示例** + +```json +{ + "model": "gpt-3.5", + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": "Who are you?" 
+ } + ] +} +``` + +**响应示例** + +```json +{ + "id": "id-1720367803430", + "object": "chat.completion", + "created": 1720367803, + "model": "@cf/meta/llama-3-8b-instruct", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner. I'm not a human, but a computer program designed to simulate conversation and answer questions to the best of my knowledge. I can be used to generate text on a wide range of topics, from science and history to entertainment and culture.\n\nI'm a large language model, which means I've been trained on a massive dataset of text from the internet and can generate human-like responses. I can understand natural language and respond accordingly, making me suitable for tasks such as:\n\n* Answering questions on various topics\n* Generating text based on a given prompt\n* Translating text from one language to another\n* Summarizing long pieces of text\n* Creating chatbot dialogues\n\nI'm constantly learning and improving, so the more conversations I have with users like you, the better I'll become." 
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ]
+}
+```
+
 ## 完整配置示例
 
 ### Kubernetes 示例
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
new file mode 100644
index 000000000..560cca3a3
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
@@ -0,0 +1,133 @@
+package provider
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+)
+
+const (
+	// cloudflareDomain is the host that every proxied request is rewritten to.
+	cloudflareDomain = "api.cloudflare.com"
+	// cloudflareChatCompletionPath is the OpenAI-compatible chat completion path;
+	// "{account_id}" is replaced with the configured cloudflareAccountId per request.
+	// https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/
+	cloudflareChatCompletionPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions"
+)
+
+// cloudflareProviderInitializer validates the configuration of, and creates,
+// the Cloudflare Workers AI provider.
+type cloudflareProviderInitializer struct {
+}
+
+// ValidateConfig rejects a configuration without cloudflareAccountId, because the
+// account ID is a mandatory segment of the upstream request path.
+func (c *cloudflareProviderInitializer) ValidateConfig(config ProviderConfig) error {
+	if config.cloudflareAccountId == "" {
+		return errors.New("missing cloudflareAccountId in provider config")
+	}
+	return nil
+}
+
+// CreateProvider builds a cloudflareProvider from the given configuration.
+func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
+	return &cloudflareProvider{
+		config:       config,
+		contextCache: createContextCache(&config),
+	}, nil
+}
+
+// cloudflareProvider proxies chat completion requests to Cloudflare Workers AI.
+type cloudflareProvider struct {
+	config       ProviderConfig
+	contextCache *contextCache
+}
+
+// GetProviderType returns the provider type identifier, providerTypeCloudflare.
+func (c *cloudflareProvider) GetProviderType() string {
+	return providerTypeCloudflare
+}
+
+// OnRequestHeaders rewrites the request path, host and Authorization header so
+// that the request targets the Workers AI chat completion endpoint. Any API
+// other than ApiNameChatCompletion yields errUnsupportedApiName.
+func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
+	if apiName != ApiNameChatCompletion {
+		return types.ActionContinue, errUnsupportedApiName
+	}
+	_ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
+	_ = util.OverwriteRequestHost(cloudflareDomain)
+	_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+c.config.GetRandomToken())
+
+	// Skip reading the request body when no context is configured and the original protocol is used.
+	if c.config.context == nil && c.config.protocol == protocolOriginal {
+		ctx.DontReadRequestBody()
+	}
+	_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
+	// OnRequestBody may replace the body, so the original Content-Length can no longer be trusted.
+	_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
+
+	return types.ActionContinue, nil
+}
+
+// OnRequestBody applies the configured model mapping to the chat completion
+// request, inserts the configured context message when a context cache exists,
+// and replaces the request body with the result.
+func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
+	if apiName != ApiNameChatCompletion {
+		return types.ActionContinue, errUnsupportedApiName
+	}
+
+	request := &chatCompletionRequest{}
+	if err := decodeChatCompletionRequest(body, request); err != nil {
+		return types.ActionContinue, err
+	}
+	model := request.Model
+	if model == "" {
+		return types.ActionContinue, errors.New("missing model in chat completion request")
+	}
+	ctx.SetContext(ctxKeyOriginalRequestModel, model)
+	mappedModel := getMappedModel(model, c.config.modelMapping, log)
+	if mappedModel == "" {
+		return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
+	}
+	request.Model = mappedModel
+	ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
+
+	streaming := request.Stream
+	if streaming {
+		_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
+	}
+
+	if c.contextCache == nil {
+		if err := replaceJsonRequestBody(request, log); err != nil {
+			_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
+		}
+		return types.ActionContinue, nil
+	}
+	err := c.contextCache.GetContent(func(content string, err error) {
+		defer func() {
+			_ = proxywasm.ResumeHttpRequest()
+		}()
+		if err != nil {
+			log.Errorf("failed to load context file: %v", err)
+			_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
+			// An error response was already sent; do not go on to rewrite the request body.
+			return
+		}
+		insertContextMessage(request, content)
+		if err := replaceJsonRequestBody(request, log); err != nil {
+			_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
+		}
+	}, log)
+	if err == nil {
+		// The request stays paused until the callback above resumes it.
+		return types.ActionPause, nil
+	}
+	return types.ActionContinue, err
+}
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
index b56c52ca7..580ec3294 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -15,21 +15,22 @@ type Pointcut string
 const (
 	ApiNameChatCompletion ApiName = "chatCompletion"
 
-	providerTypeMoonshot = "moonshot"
-	providerTypeAzure    = "azure"
-	providerTypeQwen     = "qwen"
-	providerTypeOpenAI   = "openai"
-	providerTypeGroq     = "groq"
-	providerTypeBaichuan = "baichuan"
-	providerTypeYi       = "yi"
-	providerTypeDeepSeek = "deepseek"
-	providerTypeZhipuAi  = "zhipuai"
-	providerTypeOllama   = "ollama"
-	providerTypeClaude   = "claude"
-	providerTypeBaidu    = "baidu"
-	providerTypeHunyuan  = "hunyuan"
-	providerTypeStepfun  = "stepfun"
-	providerTypeMinimax  = "minimax"
+	providerTypeMoonshot   = "moonshot"
+	providerTypeAzure      = "azure"
+	providerTypeQwen       = "qwen"
+	providerTypeOpenAI     = "openai"
+	providerTypeGroq       = "groq"
+	providerTypeBaichuan   = "baichuan"
+	providerTypeYi         = "yi"
+	providerTypeDeepSeek   = "deepseek"
+	providerTypeZhipuAi    = "zhipuai"
+	providerTypeOllama     = "ollama"
+	providerTypeClaude     = "claude"
+	providerTypeBaidu      = "baidu"
+	providerTypeHunyuan    = "hunyuan"
+	providerTypeStepfun    = "stepfun"
+	providerTypeMinimax    = "minimax"
+	providerTypeCloudflare = "cloudflare"
 
 	protocolOpenAI   = "openai"
 	protocolOriginal = "original"
@@ -65,21 +66,22 @@ var (
 	errUnsupportedApiName = errors.New("unsupported API name")
 
 	providerInitializers = map[string]providerInitializer{
-		providerTypeMoonshot: &moonshotProviderInitializer{},
-		providerTypeAzure:    &azureProviderInitializer{},
-		providerTypeQwen:     &qwenProviderInitializer{},
-		providerTypeOpenAI:   &openaiProviderInitializer{},
-		providerTypeGroq:     &groqProviderInitializer{},
-		providerTypeBaichuan: &baichuanProviderInitializer{},
-		
providerTypeYi: &yiProviderInitializer{}, - providerTypeDeepSeek: &deepseekProviderInitializer{}, - providerTypeZhipuAi: &zhipuAiProviderInitializer{}, - providerTypeOllama: &ollamaProviderInitializer{}, - providerTypeClaude: &claudeProviderInitializer{}, - providerTypeBaidu: &baiduProviderInitializer{}, - providerTypeHunyuan: &hunyuanProviderInitializer{}, - providerTypeStepfun: &stepfunProviderInitializer{}, - providerTypeMinimax: &minimaxProviderInitializer{}, + providerTypeMoonshot: &moonshotProviderInitializer{}, + providerTypeAzure: &azureProviderInitializer{}, + providerTypeQwen: &qwenProviderInitializer{}, + providerTypeOpenAI: &openaiProviderInitializer{}, + providerTypeGroq: &groqProviderInitializer{}, + providerTypeBaichuan: &baichuanProviderInitializer{}, + providerTypeYi: &yiProviderInitializer{}, + providerTypeDeepSeek: &deepseekProviderInitializer{}, + providerTypeZhipuAi: &zhipuAiProviderInitializer{}, + providerTypeOllama: &ollamaProviderInitializer{}, + providerTypeClaude: &claudeProviderInitializer{}, + providerTypeBaidu: &baiduProviderInitializer{}, + providerTypeHunyuan: &hunyuanProviderInitializer{}, + providerTypeStepfun: &stepfunProviderInitializer{}, + providerTypeMinimax: &minimaxProviderInitializer{}, + providerTypeCloudflare: &cloudflareProviderInitializer{}, } ) @@ -156,6 +158,9 @@ type ProviderConfig struct { // @Title zh-CN 版本 // @Description zh-CN 请求AI服务的版本,目前仅适用于Claude AI服务 claudeVersion string `required:"false" yaml:"version" json:"version"` + // @Title zh-CN Cloudflare Account ID + // @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考:https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api + cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"` } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -194,6 +199,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.hunyuanAuthId = json.Get("hunyuanAuthId").String() c.hunyuanAuthKey = 
json.Get("hunyuanAuthKey").String() c.minimaxGroupId = json.Get("minimaxGroupId").String() + c.cloudflareAccountId = json.Get("cloudflareAccountId").String() } func (c *ProviderConfig) Validate() error {