mirror of
https://github.com/alibaba/higress.git
synced 2026-03-07 10:00:48 +08:00
feat: support Cloudflare Workers AI (#1068)
Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
@@ -131,6 +131,15 @@ Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:
|
||||
|
||||
阶跃星辰所对应的 `type` 为 `stepfun`。它并无特有的配置字段。
|
||||
|
||||
#### Cloudflare Workers AI
|
||||
|
||||
Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------------------------------------------------------------------------------------|
|
||||
| `cloudflareAccountId` | string | 必填 | - | [Cloudflare Account ID](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-api-token-and-account-id) |
|
||||
|
||||
|
||||
## 用法示例
|
||||
|
||||
### 使用 OpenAI 协议代理 Azure OpenAI 服务
|
||||
@@ -758,6 +767,57 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Cloudflare Workers AI 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: cloudflare
|
||||
apiTokens:
|
||||
- "YOUR_WORKERS_AI_API_TOKEN"
|
||||
cloudflareAccountId: "YOUR_CLOUDFLARE_ACCOUNT_ID"
|
||||
modelMapping:
|
||||
"*": "@cf/meta/llama-3-8b-instruct"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-3.5",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "id-1720367803430",
|
||||
"object": "chat.completion",
|
||||
"created": 1720367803,
|
||||
"model": "@cf/meta/llama-3-8b-instruct",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner. I'm not a human, but a computer program designed to simulate conversation and answer questions to the best of my knowledge. I can be used to generate text on a wide range of topics, from science and history to entertainment and culture.\n\nI'm a large language model, which means I've been trained on a massive dataset of text from the internet and can generate human-like responses. I can understand natural language and respond accordingly, making me suitable for tasks such as:\n\n* Answering questions on various topics\n* Generating text based on a given prompt\n* Translating text from one language to another\n* Summarizing long pieces of text\n* Creating chatbot dialogues\n\nI'm constantly learning and improving, so the more conversations I have with users like you, the better I'll become."
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 完整配置示例
|
||||
|
||||
### Kubernetes 示例
|
||||
|
||||
108
plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Normal file
108
plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Endpoint details for the Cloudflare Workers AI OpenAI-compatible API.
const (
	// cloudflareDomain is the host that proxied requests are rewritten to.
	cloudflareDomain = "api.cloudflare.com"

	// cloudflareChatCompletionPath is the chat completion path template. The
	// "{account_id}" placeholder is replaced with the configured account ID
	// before the request is forwarded.
	// See https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/
	cloudflareChatCompletionPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions"
)
|
||||
|
||||
// cloudflareProviderInitializer validates provider configuration for, and
// constructs, Cloudflare Workers AI providers. It carries no state.
type cloudflareProviderInitializer struct{}
|
||||
|
||||
func (c *cloudflareProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cloudflareProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &cloudflareProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type cloudflareProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) GetProviderType() string {
|
||||
return providerTypeCloudflare
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
_ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
||||
_ = util.OverwriteRequestHost(cloudflareDomain)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+c.config.GetRandomToken())
|
||||
|
||||
if c.config.context == nil && c.config.protocol == protocolOriginal {
|
||||
ctx.DontReadRequestBody()
|
||||
}
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
model := request.Model
|
||||
if model == "" {
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, c.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
streaming := request.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
if c.contextCache == nil {
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
err := c.contextCache.GetContent(func(content string, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
@@ -15,21 +15,22 @@ type Pointcut string
|
||||
const (
|
||||
ApiNameChatCompletion ApiName = "chatCompletion"
|
||||
|
||||
providerTypeMoonshot = "moonshot"
|
||||
providerTypeAzure = "azure"
|
||||
providerTypeQwen = "qwen"
|
||||
providerTypeOpenAI = "openai"
|
||||
providerTypeGroq = "groq"
|
||||
providerTypeBaichuan = "baichuan"
|
||||
providerTypeYi = "yi"
|
||||
providerTypeDeepSeek = "deepseek"
|
||||
providerTypeZhipuAi = "zhipuai"
|
||||
providerTypeOllama = "ollama"
|
||||
providerTypeClaude = "claude"
|
||||
providerTypeBaidu = "baidu"
|
||||
providerTypeHunyuan = "hunyuan"
|
||||
providerTypeStepfun = "stepfun"
|
||||
providerTypeMinimax = "minimax"
|
||||
providerTypeMoonshot = "moonshot"
|
||||
providerTypeAzure = "azure"
|
||||
providerTypeQwen = "qwen"
|
||||
providerTypeOpenAI = "openai"
|
||||
providerTypeGroq = "groq"
|
||||
providerTypeBaichuan = "baichuan"
|
||||
providerTypeYi = "yi"
|
||||
providerTypeDeepSeek = "deepseek"
|
||||
providerTypeZhipuAi = "zhipuai"
|
||||
providerTypeOllama = "ollama"
|
||||
providerTypeClaude = "claude"
|
||||
providerTypeBaidu = "baidu"
|
||||
providerTypeHunyuan = "hunyuan"
|
||||
providerTypeStepfun = "stepfun"
|
||||
providerTypeMinimax = "minimax"
|
||||
providerTypeCloudflare = "cloudflare"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
@@ -65,21 +66,22 @@ var (
|
||||
errUnsupportedApiName = errors.New("unsupported API name")
|
||||
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
providerTypeMoonshot: &moonshotProviderInitializer{},
|
||||
providerTypeAzure: &azureProviderInitializer{},
|
||||
providerTypeQwen: &qwenProviderInitializer{},
|
||||
providerTypeOpenAI: &openaiProviderInitializer{},
|
||||
providerTypeGroq: &groqProviderInitializer{},
|
||||
providerTypeBaichuan: &baichuanProviderInitializer{},
|
||||
providerTypeYi: &yiProviderInitializer{},
|
||||
providerTypeDeepSeek: &deepseekProviderInitializer{},
|
||||
providerTypeZhipuAi: &zhipuAiProviderInitializer{},
|
||||
providerTypeOllama: &ollamaProviderInitializer{},
|
||||
providerTypeClaude: &claudeProviderInitializer{},
|
||||
providerTypeBaidu: &baiduProviderInitializer{},
|
||||
providerTypeHunyuan: &hunyuanProviderInitializer{},
|
||||
providerTypeStepfun: &stepfunProviderInitializer{},
|
||||
providerTypeMinimax: &minimaxProviderInitializer{},
|
||||
providerTypeMoonshot: &moonshotProviderInitializer{},
|
||||
providerTypeAzure: &azureProviderInitializer{},
|
||||
providerTypeQwen: &qwenProviderInitializer{},
|
||||
providerTypeOpenAI: &openaiProviderInitializer{},
|
||||
providerTypeGroq: &groqProviderInitializer{},
|
||||
providerTypeBaichuan: &baichuanProviderInitializer{},
|
||||
providerTypeYi: &yiProviderInitializer{},
|
||||
providerTypeDeepSeek: &deepseekProviderInitializer{},
|
||||
providerTypeZhipuAi: &zhipuAiProviderInitializer{},
|
||||
providerTypeOllama: &ollamaProviderInitializer{},
|
||||
providerTypeClaude: &claudeProviderInitializer{},
|
||||
providerTypeBaidu: &baiduProviderInitializer{},
|
||||
providerTypeHunyuan: &hunyuanProviderInitializer{},
|
||||
providerTypeStepfun: &stepfunProviderInitializer{},
|
||||
providerTypeMinimax: &minimaxProviderInitializer{},
|
||||
providerTypeCloudflare: &cloudflareProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -156,6 +158,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 版本
|
||||
// @Description zh-CN 请求AI服务的版本,目前仅适用于Claude AI服务
|
||||
claudeVersion string `required:"false" yaml:"version" json:"version"`
|
||||
// @Title zh-CN Cloudflare Account ID
|
||||
// @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考:https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
|
||||
cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
@@ -194,6 +199,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
|
||||
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
|
||||
Reference in New Issue
Block a user