diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index a21f16eff..b23e2cdcf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -19,14 +19,14 @@ description: AI 代理插件配置参考 `provider`的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------ | -| `type` | string | 必填 | - | AI 服务提供商名称 | -| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | -| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | -| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
可以使用 "*" 为键来配置通用兜底映射关系 | -| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | -| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| -------------- | --------------- | -------- | ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `type` | string | 必填 | - | AI 服务提供商名称 | +| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | +| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | +| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | +| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | +| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | `context`的配置字段说明如下: @@ -255,6 +255,7 @@ provider: 'gpt-3': "qwen-turbo" 'gpt-35-turbo': "qwen-plus" 'gpt-4-turbo': "qwen-max" + 'qwen-*': "" '*': "qwen-turbo" ``` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 580ec3294..ad7ac371e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -3,6 +3,7 @@ package provider import ( "errors" "math/rand" + "strings" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -253,16 +254,38 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { } func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { - if modelMapping == nil || len(modelMapping) == 0 { - return model - } - if v, ok := modelMapping[model]; ok && len(v) != 0 { - log.Debugf("model %s is mapped to %s explictly", model, v) - return v - } - if v, ok := modelMapping[wildcard]; ok { - log.Debugf("model %s is mapped to %s via wildcard", model, v) - return v + mappedModel := doGetMappedModel(model, modelMapping, log) + if len(mappedModel) != 0 { + return mappedModel } return model } + +func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { + if modelMapping == nil || len(modelMapping) == 0 { + return "" + } + + if v, ok := modelMapping[model]; ok { + log.Debugf("model [%s] is mapped to [%s] explictly", model, v) + return v + } + + for k, v := range modelMapping { + if k == wildcard || !strings.HasSuffix(k, wildcard) { + continue + } + k = strings.TrimSuffix(k, wildcard) + if strings.HasPrefix(model, k) { + log.Debugf("model [%s] is mapped to [%s] via prefix [%s]", model, v, k) + return v + } + } + + if v, ok := modelMapping[wildcard]; ok { + log.Debugf("model [%s] is mapped to [%s] via wildcard", model, v) + return v + } + + return "" +}