feat: AI Proxy Wasm Plugin Integration with GitHub Models #1304 (#1362)

This commit is contained in:
YeHaitao
2024-10-06 17:12:48 +08:00
committed by GitHub
parent 71aae9ddf6
commit 14b11dcb05
3 changed files with 222 additions and 2 deletions

View File

@@ -143,6 +143,10 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
360智脑所对应的 `type``ai360`。它并无特有的配置字段。
#### GitHub模型
GitHub模型所对应的 `type``github`。它并无特有的配置字段。
#### Mistral
Mistral 所对应的 `type``mistral`。它并无特有的配置字段。
@@ -1018,6 +1022,107 @@ provider:
}
```
### 使用 OpenAI 协议代理 GitHub 模型服务
**配置信息**
```yaml
provider:
type: github
apiTokens:
- "YOUR_GITHUB_ACCESS_TOKEN"
modelMapping:
"gpt-4o": "gpt-4o"
"gpt-4": "Phi-3.5-MoE-instruct"
"gpt-3.5": "cohere-command-r-08-2024"
"text-embedding-3-large": "text-embedding-3-large"
```
**请求示例**
```json
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is the capital of France?"
}
],
"stream": true,
"temperature": 1.0,
"top_p": 1.0,
"max_tokens": 1000,
"model": "gpt-4o"
}
```
**响应示例**
```json
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The capital of France is Paris.",
"role": "assistant"
}
}
],
"created": 1728131051,
"id": "chatcmpl-AEy7PU2JImdsD1W6Jw8GigZSEnM2u",
"model": "gpt-4o-2024-08-06",
"object": "chat.completion",
"system_fingerprint": "fp_67802d9a6d",
"usage": {
"completion_tokens": 7,
"prompt_tokens": 24,
"total_tokens": 31
}
}
```
**文本向量请求示例**
```json
{
"input": ["first phrase", "second phrase", "third phrase"],
"model": "text-embedding-3-large"
}
```
响应示例:
```json
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
-0.0012583479,
0.0020349282,
...
0.012051377,
-0.0053306012,
0.0060688322
]
}
],
"model": "text-embedding-3-large",
"usage": {
"prompt_tokens": 6,
"total_tokens": 6
}
}
```
### 使用 OpenAI 协议代理360智脑服务
**配置信息**
@@ -1026,7 +1131,7 @@ provider:
provider:
type: ai360
apiTokens:
- "YOUR_MINIMAX_API_TOKEN"
- "YOUR_360_API_TOKEN"
modelMapping:
"gpt-4o": "360gpt-turbo-responsibility-8k"
"gpt-4": "360gpt2-pro"

View File

@@ -0,0 +1,112 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
// githubProvider is the provider for GitHub OpenAI service.
const (
githubDomain = "models.inference.ai.azure.com"
githubCompletionPath = "/chat/completions"
githubEmbeddingPath = "/embeddings"
)
type githubProviderInitializer struct {
}
type githubProvider struct {
config ProviderConfig
contextCache *contextCache
}
func (m *githubProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (m *githubProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &githubProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
func (m *githubProvider) GetProviderType() string {
return providerTypeGithub
}
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(githubDomain)
if apiName == ApiNameChatCompletion {
_ = util.OverwriteRequestPath(githubCompletionPath)
}
if apiName == ApiNameEmbeddings {
_ = util.OverwriteRequestPath(githubEmbeddingPath)
}
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}

View File

@@ -5,9 +5,10 @@ import (
"math/rand"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type ApiName string
@@ -20,6 +21,7 @@ const (
providerTypeMoonshot = "moonshot"
providerTypeAzure = "azure"
providerTypeAi360 = "ai360"
providerTypeGithub = "github"
providerTypeQwen = "qwen"
providerTypeOpenAI = "openai"
providerTypeGroq = "groq"
@@ -78,6 +80,7 @@ var (
providerTypeMoonshot: &moonshotProviderInitializer{},
providerTypeAzure: &azureProviderInitializer{},
providerTypeAi360: &ai360ProviderInitializer{},
providerTypeGithub: &githubProviderInitializer{},
providerTypeQwen: &qwenProviderInitializer{},
providerTypeOpenAI: &openaiProviderInitializer{},
providerTypeGroq: &groqProviderInitializer{},