diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md
index a21f16eff..b23e2cdcf 100644
--- a/plugins/wasm-go/extensions/ai-proxy/README.md
+++ b/plugins/wasm-go/extensions/ai-proxy/README.md
@@ -19,14 +19,14 @@ description: AI 代理插件配置参考
`provider`的配置字段说明如下:
-| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
-| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------ |
-| `type` | string | 必填 | - | AI 服务提供商名称 |
-| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
-| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
-| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
可以使用 "*" 为键来配置通用兜底映射关系 |
-| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
-| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+| -------------- | --------------- | -------- | ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| `type` | string | 必填 | - | AI 服务提供商名称 |
+| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
+| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
+| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
+| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
+| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
`context`的配置字段说明如下:
@@ -255,6 +255,7 @@ provider:
'gpt-3': "qwen-turbo"
'gpt-35-turbo': "qwen-plus"
'gpt-4-turbo': "qwen-max"
+ 'qwen-*': ""
'*': "qwen-turbo"
```
diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
index 580ec3294..ad7ac371e 100644
--- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
+++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -3,6 +3,7 @@ package provider
import (
"errors"
"math/rand"
+ "strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -253,16 +254,38 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
}
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
- if modelMapping == nil || len(modelMapping) == 0 {
- return model
- }
- if v, ok := modelMapping[model]; ok && len(v) != 0 {
- log.Debugf("model %s is mapped to %s explictly", model, v)
- return v
- }
- if v, ok := modelMapping[wildcard]; ok {
- log.Debugf("model %s is mapped to %s via wildcard", model, v)
- return v
+ mappedModel := doGetMappedModel(model, modelMapping, log)
+ if len(mappedModel) != 0 {
+ return mappedModel
}
return model
}
+
+func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
+ if modelMapping == nil || len(modelMapping) == 0 {
+ return ""
+ }
+
+ if v, ok := modelMapping[model]; ok {
+ log.Debugf("model [%s] is mapped to [%s] explictly", model, v)
+ return v
+ }
+
+ for k, v := range modelMapping {
+ if k == wildcard || !strings.HasSuffix(k, wildcard) {
+ continue
+ }
+ k = strings.TrimSuffix(k, wildcard)
+ if strings.HasPrefix(model, k) {
+ log.Debugf("model [%s] is mapped to [%s] via prefix [%s]", model, v, k)
+ return v
+ }
+ }
+
+ if v, ok := modelMapping[wildcard]; ok {
+ log.Debugf("model [%s] is mapped to [%s] via wildcard", model, v)
+ return v
+ }
+
+ return ""
+}