aiproxy 360代理支持embedding模型 (#1247)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
YeHaitao
2024-08-26 20:04:18 +08:00
committed by GitHub
parent 5a87031c0e
commit 40a74e32ac
2 changed files with 75 additions and 2 deletions

View File

@@ -957,6 +957,7 @@ provider:
"gpt-4o": "360gpt-turbo-responsibility-8k"
"gpt-4": "360gpt2-pro"
"gpt-3.5": "360gpt-turbo"
"text-embedding-3-small": "embedding_s1_v1.2"
"*": "360gpt-pro"
```
@@ -1015,6 +1016,48 @@ provider:
}
```
**文本向量请求示例**
URL: http://your-domain/v1/embeddings
请求示例:
```json
{
"input":["你好"],
"model":"text-embedding-3-small"
}
```
响应示例:
```json
{
"data": [
{
"embedding": [
-0.011237,
-0.015433,
...,
-0.028946,
-0.052778,
0.003768,
-0.007917,
-0.042201
],
"index": 0,
"object": ""
}
],
"model": "embedding_s1_v1.2",
"object": "",
"usage": {
"prompt_tokens": 2,
"total_tokens": 2
}
}
```
### 使用 OpenAI 协议代理 Cloudflare Workers AI 服务
**配置信息**

View File

@@ -1,7 +1,9 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -41,7 +43,7 @@ func (m *ai360Provider) GetProviderType() string {
}
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(ai360Domain)
@@ -53,9 +55,19 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
}
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
@@ -72,3 +84,21 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}