diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index d664d10e..1450297b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -957,6 +957,7 @@ provider: "gpt-4o": "360gpt-turbo-responsibility-8k" "gpt-4": "360gpt2-pro" "gpt-3.5": "360gpt-turbo" + "text-embedding-3-small": "embedding_s1_v1.2" "*": "360gpt-pro" ``` @@ -1015,6 +1016,48 @@ provider: } ``` +**文本向量请求示例** + +URL: http://your-domain/v1/embeddings + +请求示例: + +```json +{ + "input":["你好"], + "model":"text-embedding-3-small" +} +``` + +响应示例: + +```json +{ + "data": [ + { + "embedding": [ + -0.011237, + -0.015433, + ..., + -0.028946, + -0.052778, + 0.003768, + -0.007917, + -0.042201 + ], + "index": 0, + "object": "" + } + ], + "model": "embedding_s1_v1.2", + "object": "", + "usage": { + "prompt_tokens": 2, + "total_tokens": 2 + } +} +``` + ### 使用 OpenAI 协议代理 Cloudflare Workers AI 服务 **配置信息** diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index c9f68710..00443fcf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -1,7 +1,9 @@ package provider import ( + "encoding/json" "errors" + "fmt" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -41,7 +43,7 @@ func (m *ai360Provider) GetProviderType() string { } func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } _ = util.OverwriteRequestHost(ai360Domain) @@ -53,9 +55,19 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } + if apiName == ApiNameChatCompletion { + return m.onChatCompletionRequestBody(ctx, body, log) + } + if apiName == ApiNameEmbeddings { + return m.onEmbeddingsRequestBody(ctx, body, log) + } + return types.ActionContinue, errUnsupportedApiName +} + +func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err @@ -72,3 +84,21 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, request.Model = mappedModel return types.ActionContinue, replaceJsonRequestBody(request, log) } + +func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { + request := &embeddingsRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + } + if request.Model == "" { + return types.ActionContinue, errors.New("missing model in embeddings request") + } + // 映射模型 + mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) + if mappedModel == "" { + return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + } + ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) + request.Model = mappedModel + return types.ActionContinue, replaceJsonRequestBody(request, log) +}