feat: Support configuring a global provider list in ai-proxy plugin (#1334)

This commit is contained in:
Kent Dong
2024-09-26 11:27:22 +08:00
committed by GitHub
parent 260772926c
commit 708e7af79a
6 changed files with 177 additions and 39 deletions

View File

@@ -607,6 +607,7 @@ provider:
```
### 使用original协议代理百炼智能体应用
**配置信息**
```yaml
@@ -832,6 +833,7 @@ provider:
}
}
```
### 使用 OpenAI 协议代理混元服务
**配置信息**
@@ -849,9 +851,10 @@ provider:
```
**请求示例**
请求脚本:
```sh
请求脚本:
```shell
curl --location 'http://<your higress domain>/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data '{

View File

@@ -139,9 +139,9 @@ For 360 Brain, the corresponding `type` is `ai360`. It has no unique configurati
For Mistral, the corresponding `type` is `mistral`. It has no unique configuration fields.
#### Minimax
#### MiniMax
For Minimax, the corresponding `type` is `minimax`. Its unique configuration field is:
For MiniMax, the corresponding `type` is `minimax`. Its unique configuration field is:
| Name | Data Type | Filling Requirements | Default Value | Description |
| ---------------- | -------- | --------------------- |---------------|------------------------------------------------------------------------------------------------------------|
@@ -593,6 +593,69 @@ provider:
"request_id": "187e99ba-5b64-9ffe-8f69-01dafbaf6ed7"
}
```
### Forwarding requests to Alibaba Cloud Bailian with the "original" protocol
**Configuration Information**
```yaml
activeProviderId: my-qwen
providers:
- id: my-qwen
type: qwen
apiTokens:
- "YOUR_DASHSCOPE_API_TOKEN"
protocol: original
```
**Example Request**
```json
{
"input": {
"prompt": "What is Dubbo?"
},
"parameters": {},
"debug": {}
}
```
**Example Response**
```json
{
"output": {
"finish_reason": "stop",
"session_id": "677e7e8fbb874e1b84792b65042e1599",
"text": "Apache Dubbo is a..."
},
"usage": {
"models": [
{
"output_tokens": 449,
"model_id": "qwen-max",
"input_tokens": 282
}
]
},
"request_id": "b59e45e3-5af4-91df-b7c6-9d746fd3297c"
}
```
### Using OpenAI Protocol Proxy for Doubao Service
```yaml
activeProviderId: my-doubao
providers:
- id: my-doubao
type: doubao
apiTokens:
- YOUR_DOUBAO_API_KEY
modelMapping:
'*': YOUR_DOUBAO_ENDPOINT
timeout: 1200000
```
### Utilizing Moonshot with its Native File Context
Upload files to Moonshot in advance and use its AI services based on file content.
@@ -782,8 +845,7 @@ provider:
Request script:
```sh
```shell
curl --location 'http://<your higress domain>/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
@@ -955,7 +1017,7 @@ provider:
provider:
type: ai360
apiTokens:
- "YOUR_MINIMAX_API_TOKEN"
- "YOUR_AI360_API_TOKEN"
modelMapping:
"gpt-4o": "360gpt-turbo-responsibility-8k"
"gpt-4": "360gpt2-pro"
@@ -1264,6 +1326,7 @@ Here, `model` denotes the service tier of DeepL and can only be either `Free` or
```
**Response Example**
```json
{
"choices": [

View File

@@ -25,32 +25,70 @@ import (
// PluginConfig holds the ai-proxy plugin configuration: the configured
// provider(s) and the provider instance built from the active one.
//
// NOTE(review): this view interleaves pre-change and post-change fields of the
// same commit (providerConfig/provider vs. providerConfigs/activeProviderConfig/
// activeProvider) — confirm which set is current before relying on a field.
type PluginConfig struct {
// The "@Title"/"@Description" annotations below are tagged zh-CN and are kept
// verbatim; in English: "AI service provider configuration" / "includes API
// endpoint, model and knowledge-base file information".
// @Title zh-CN AI服务提供商配置
// @Description zh-CN AI服务提供商配置包含API接口、模型和知识库文件等信息
// Single provider configuration, read from the "provider" key.
providerConfig provider.ProviderConfig `required:"true" yaml:"provider"`
// List of provider configurations, read from the "providers" key.
providerConfigs []provider.ProviderConfig `required:"true" yaml:"providers"`
// Provider instance; not part of the YAML representation.
provider provider.Provider `yaml:"-"`
// Configuration of the provider selected as active; nil when none is selected.
activeProviderConfig *provider.ProviderConfig `yaml:"-"`
// Provider instance created from activeProviderConfig; nil when none is selected.
activeProvider provider.Provider `yaml:"-"`
}
// FromJson populates the plugin configuration from the given JSON document.
// It reads the "providers" array, the legacy "provider" object and the
// "activeProviderId" string, and records which provider config is active.
func (c *PluginConfig) FromJson(json gjson.Result) {
// NOTE(review): this statement uses the single-provider field and looks like
// the pre-change line retained by the diff rendering — confirm before reuse.
c.providerConfig.FromJson(json.Get("provider"))
// New-style configuration: rebuild the provider list from the "providers"
// array. When the key is absent or not an array, the existing list is left
// untouched.
if providersJson := json.Get("providers"); providersJson.Exists() && providersJson.IsArray() {
c.providerConfigs = make([]provider.ProviderConfig, 0)
for _, providerJson := range providersJson.Array() {
providerConfig := provider.ProviderConfig{}
providerConfig.FromJson(providerJson)
c.providerConfigs = append(c.providerConfigs, providerConfig)
}
}
// Legacy configuration: a single "provider" object replaces the whole list
// and becomes the active provider, regardless of "activeProviderId".
if providerJson := json.Get("provider"); providerJson.Exists() && providerJson.IsObject() {
// TODO: For legacy config support. To be removed later.
providerConfig := provider.ProviderConfig{}
providerConfig.FromJson(providerJson)
c.providerConfigs = []provider.ProviderConfig{providerConfig}
// The pointer refers to the local value, not to the copy stored in the slice.
c.activeProviderConfig = &providerConfig
// Legacy configuration is used and the active provider is determined.
// We don't need to continue with the new configuration style.
return
}
// Reset the selection first so that a missing or unmatched id leaves no
// active provider.
c.activeProviderConfig = nil
activeProviderId := json.Get("activeProviderId").String()
if activeProviderId != "" {
// The first provider whose id equals activeProviderId wins.
for _, providerConfig := range c.providerConfigs {
if providerConfig.GetId() == activeProviderId {
// The pointer refers to the loop's copy of the element, not to the slice entry.
c.activeProviderConfig = &providerConfig
break
}
}
}
}
// Validate checks the active provider configuration and returns its validation
// error, if any. When no provider is active it returns nil.
func (c *PluginConfig) Validate() error {
// NOTE(review): the next line validates the single-provider field and looks
// like the pre-change line retained by the diff rendering (its brace is not
// closed in this view) — confirm before reuse.
if err := c.providerConfig.Validate(); err != nil {
// Nothing to validate when no provider has been selected.
if c.activeProviderConfig == nil {
return nil
}
if err := c.activeProviderConfig.Validate(); err != nil {
return err
}
return nil
}
// Complete builds the provider instance from the active provider configuration
// via provider.CreateProvider and returns the creation error, if any. When no
// provider is active it clears activeProvider and returns nil.
func (c *PluginConfig) Complete() error {
if c.activeProviderConfig == nil {
c.activeProvider = nil
return nil
}
var err error
// NOTE(review): the next line assigns the single-provider field and looks like
// the pre-change statement retained by the diff rendering — confirm before reuse.
c.provider, err = provider.CreateProvider(c.providerConfig)
// CreateProvider receives the configuration by value (a dereferenced copy).
c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)
return err
}
// GetProvider returns the provider instance held by the configuration.
// With the activeProvider field this is nil when no provider is active.
func (c *PluginConfig) GetProvider() provider.Provider {
// NOTE(review): the first return looks like the pre-change line retained by the
// diff rendering; the second one is presumably the current statement — confirm.
return c.provider
return c.activeProvider
}
// GetProviderConfig returns the provider configuration held by the plugin
// configuration.
//
// NOTE(review): two variants are interleaved here. The first returns the
// single-provider field by value; the second returns a pointer to the active
// provider configuration, which is nil when no provider is active. The first
// looks like the pre-change version retained by the diff rendering — confirm.
func (c *PluginConfig) GetProviderConfig() provider.ProviderConfig {
return c.providerConfig
func (c *PluginConfig) GetProviderConfig() *provider.ProviderConfig {
return c.activeProviderConfig
}

View File

@@ -56,21 +56,24 @@ static_resources:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{
"provider": {
"type": "moonshot",
"domain": "api.moonshot.cn",
"apiTokens": [
"****",
"****"
],
"timeout": 1200000,
"modelMapping": {
"gpt-3": "moonshot-v1-8k",
"gpt-35-turbo": "moonshot-v1-32k",
"gpt-4-turbo": "moonshot-v1-128k",
"*": "moonshot-v1-8k"
},
}
"activeProviderId": "moonshot",
"providers": [
{
"id": "moonshot",
"type": "moonshot",
"domain": "api.moonshot.cn",
"apiTokens": [
"****",
"****"
],
"timeout": 1200000,
"modelMapping": {
"gpt-3": "moonshot-v1-8k",
"gpt-35-turbo": "moonshot-v1-32k",
"gpt-4-turbo": "moonshot-v1-128k",
"*": "moonshot-v1-8k"
}
}
]
}
- name: envoy.filters.http.router
clusters:

View File

@@ -20,7 +20,7 @@ import (
// Plugin-wide constants.
const (
// Name passed to wrapper.SetCtx when the plugin is registered.
pluginName = "ai-proxy"
// NOTE(review): ctxKeyApiName appears twice; the "apiKey" value looks like the
// pre-change line retained by the diff rendering and "apiName" like its
// replacement — confirm. Presumably a context key for the API name.
ctxKeyApiName = "apiKey"
ctxKeyApiName = "apiName"
// 10 MiB (10 * 1024 * 1024 bytes); presumably the default request/response
// body size limit — verify against its users.
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
)
@@ -28,7 +28,7 @@ const (
func main() {
wrapper.SetCtx(
pluginName,
wrapper.ParseConfigBy(parseConfig),
wrapper.ParseOverrideConfigBy(parseGlobalConfig, parseOverrideRuleConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeader),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
@@ -37,8 +37,23 @@ func main() {
)
}
func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
// log.Debugf("loading config: %s", json.String())
// parseGlobalConfig loads the global plugin configuration from json into
// pluginConfig. The configuration is parsed first, then validated, and finally
// completed; the first error encountered is returned to the caller.
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
	//log.Debugf("loading global config: %s", json.String())
	pluginConfig.FromJson(json)

	if validateErr := pluginConfig.Validate(); validateErr != nil {
		return validateErr
	}
	// Complete's result is the function's result: nil on success, its error otherwise.
	return pluginConfig.Complete()
}
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig, log wrapper.Log) error {
//log.Debugf("loading override rule config: %s", json.String())
*pluginConfig = global
pluginConfig.FromJson(json)
if err := pluginConfig.Validate(); err != nil {

View File

@@ -126,7 +126,10 @@ type ResponseBodyHandler interface {
}
type ProviderConfig struct {
// @Title zh-CN AI服务提供商
// @Title zh-CN ID
// @Description zh-CN AI服务提供商标识
id string `required:"true" yaml:"id" json:"id"`
// @Title zh-CN 类型
// @Description zh-CN AI服务提供商类型
typ string `required:"true" yaml:"type" json:"type"`
// @Title zh-CN API Tokens
@@ -197,7 +200,20 @@ type ProviderConfig struct {
customSettings []CustomSetting
}
// GetId returns the provider identifier, i.e. the value read from the "id"
// key of the provider configuration.
func (c *ProviderConfig) GetId() string {
return c.id
}
// GetType returns the provider type, i.e. the value read from the "type" key
// of the provider configuration.
func (c *ProviderConfig) GetType() string {
return c.typ
}
// GetProtocol returns the protocol configured for this provider.
func (c *ProviderConfig) GetProtocol() string {
return c.protocol
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.id = json.Get("id").String()
c.typ = json.Get("type").String()
c.apiTokens = make([]string, 0)
for _, token := range json.Get("apiTokens").Array() {
@@ -322,6 +338,10 @@ func (c *ProviderConfig) IsOriginal() bool {
return c.protocol == protocolOriginal
}
// ReplaceByCustomSettings applies this provider's custom settings to body by
// delegating to the package-level ReplaceByCustomSettings, and returns that
// function's result unchanged.
func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
	settings := c.customSettings
	return ReplaceByCustomSettings(body, settings)
}
func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ]
if !has {
@@ -366,7 +386,3 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
// ReplaceByCustomSettings applies this provider's custom settings to body by
// delegating to the package-level ReplaceByCustomSettings.
//
// NOTE(review): this variant uses a value receiver, while the other methods of
// ProviderConfig visible here use pointer receivers. It looks like the copy the
// diff removes in favour of a pointer-receiver version — confirm.
func (c ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings)
}