From c78ef7011d408f92597cfd2dcd0f495120f6c167 Mon Sep 17 00:00:00 2001 From: urlyy Date: Thu, 8 Aug 2024 15:16:58 +0800 Subject: [PATCH] Feat: Add Spark llm support for plugins/ai-proxy (#1139) --- plugins/wasm-go/extensions/ai-proxy/README.md | 67 +++++- .../extensions/ai-proxy/provider/provider.go | 2 + .../extensions/ai-proxy/provider/spark.go | 207 ++++++++++++++++++ plugins/wasm-go/go.sum | 7 +- 4 files changed, 276 insertions(+), 7 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/spark.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index fb5f6f28d..d3573b618 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -144,6 +144,12 @@ Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置 |-------------------|--------|------|-----|----------------------------------------------------------------------------------------------------------------------------| | `cloudflareAccountId` | string | 必填 | - | [Cloudflare Account ID](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-api-token-and-account-id) | +#### 星火 (Spark) + +星火所对应的 `type` 为 `spark`。它并无特有的配置字段。 + +讯飞星火认知大模型的`apiTokens`字段值为`APIKey:APISecret`。即填入自己的APIKey与APISecret,并以`:`分隔。 + ## 用法示例 @@ -870,6 +876,65 @@ provider: } ``` +### 使用 OpenAI 协议代理Spark服务 + +**配置信息** + +```yaml +provider: + type: spark + apiTokens: + - "APIKey:APISecret" + modelMapping: + "gpt-4o": "generalv3.5" + "gpt-4": "generalv3" + "*": "general" +``` + +**请求示例** + +```json +{ + "model": "gpt-4o", + "messages": [ + { + "role": "system", + "content": "你是一名专业的开发人员!" + }, + { + "role": "user", + "content": "你好,你是谁?" + } + ], + "stream": false +} +``` + +**响应示例** + +```json +{ + "id": "cha000c23c6@dx190ef0b4b96b8f2532", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "你好!我是一名专业的开发人员,擅长编程和解决技术问题。有什么我可以帮助你的吗?" + } + } + ], + "created": 1721997415, + "model": "generalv3.5", + "object": "chat.completion", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 19, + "total_tokens": 29 + } +} +``` + ## 完整配置示例 ### Kubernetes 示例 @@ -1071,4 +1136,4 @@ curl "http://localhost:10000/v1/chat/completions" -H "Content-Type: application } ] }' -``` +``` \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index ec97b826f..ec3d2caf5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -33,6 +33,7 @@ const ( providerTypeStepfun = "stepfun" providerTypeMinimax = "minimax" providerTypeCloudflare = "cloudflare" + providerTypeSpark = "spark" protocolOpenAI = "openai" protocolOriginal = "original" @@ -84,6 +85,7 @@ var ( providerTypeStepfun: &stepfunProviderInitializer{}, providerTypeMinimax: &minimaxProviderInitializer{}, providerTypeCloudflare: &cloudflareProviderInitializer{}, + providerTypeSpark: &sparkProviderInitializer{}, } ) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go new file mode 100644 index 000000000..fc266dfba --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -0,0 +1,207 @@ +package provider + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" +) + +// sparkProvider is the provider for SparkLLM AI service. +const ( + sparkHost = "spark-api-open.xf-yun.com" + sparkChatCompletionPath = "/v1/chat/completions" +) + +type sparkProviderInitializer struct { +} + +type sparkProvider struct { + config ProviderConfig + contextCache *contextCache +} + +type sparkRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + TopK int `json:"top_k,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` +} + +type sparkResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Choices []chatCompletionChoice `json:"choices"` + Usage usage `json:"usage,omitempty"` +} + +type sparkStreamResponse struct { + sparkResponse + Id string `json:"id"` + Created int64 `json:"created"` +} + +func (i *sparkProviderInitializer) ValidateConfig(config ProviderConfig) error { + return nil +} + +func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &sparkProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +func (p *sparkProvider) GetProviderType() string { + return providerTypeSpark +} + +func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + _ = util.OverwriteRequestHost(sparkHost) + _ = util.OverwriteRequestPath(sparkChatCompletionPath) + _ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken()) + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + return types.ActionContinue, nil +} + +func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + // 使用Spark协议 + if p.config.protocol == protocolOriginal { + request := &sparkRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + } + if request.Model == "" { + return types.ActionContinue, errors.New("request model is empty") + } + // 目前星火在模型名称错误时,也会调用generalv3,这里还是按照输入的模型名称设置响应里的模型名称 + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + return types.ActionContinue, replaceJsonRequestBody(request, log) + } else { + // 使用openai协议 + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return types.ActionContinue, err + } + if request.Model == "" { + return types.ActionContinue, errors.New("missing model in chat completion request") + } + // 映射模型 + mappedModel := getMappedModel(request.Model, p.config.modelMapping, log) + if mappedModel == "" { + return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + } + ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) + request.Model = mappedModel + return types.ActionContinue, replaceJsonRequestBody(request, log) + } +} + +func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + _ = proxywasm.RemoveHttpResponseHeader("Content-Length") + return types.ActionContinue, nil +} + +func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + sparkResponse := &sparkResponse{} + if err := json.Unmarshal(body, sparkResponse); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err) + } + if sparkResponse.Code != 0 { + return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message) + } + response := p.responseSpark2OpenAI(ctx, sparkResponse) + return types.ActionContinue, replaceJsonResponseBody(response, log) +} + +func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + if isLastChunk || len(chunk) == 0 { + return nil, nil + } + responseBuilder := &strings.Builder{} + lines := strings.Split(string(chunk), "\n") + for _, data := range lines { + if len(data) < 6 { + // ignore blank line or wrong format + continue + } + data = data[6:] + // The final response is `data: [DONE]` + if data == "[DONE]" { + continue + } + var sparkResponse sparkStreamResponse + if err := json.Unmarshal([]byte(data), &sparkResponse); err != nil { + log.Errorf("unable to unmarshal spark response: %v", err) + continue + } + response := p.streamResponseSpark2OpenAI(ctx, &sparkResponse) + responseBody, err := json.Marshal(response) + if err != nil { + log.Errorf("unable to marshal response: %v", err) + return nil, err + } + p.appendResponse(responseBuilder, string(responseBody)) + } + modifiedResponseChunk := responseBuilder.String() + log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) + return []byte(modifiedResponseChunk), nil +} + +func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkResponse) *chatCompletionResponse { + choices := make([]chatCompletionChoice, len(response.Choices)) + for idx, c := range response.Choices { + choices[idx] = chatCompletionChoice{ + Index: c.Index, + Message: &chatMessage{Role: c.Message.Role, Content: c.Message.Content}, + } + } + return &chatCompletionResponse{ + Id: response.Sid, + Created: time.Now().UnixMilli() / 1000, + Object: objectChatCompletion, + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), + Choices: choices, + Usage: response.Usage, + } +} + +func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkStreamResponse) *chatCompletionResponse { + choices := make([]chatCompletionChoice, len(response.Choices)) + for idx, c := range response.Choices { + choices[idx] = chatCompletionChoice{ + Index: c.Index, + Delta: &chatMessage{Role: c.Delta.Role, Content: c.Delta.Content}, + } + } + return &chatCompletionResponse{ + Id: response.Sid, + Created: response.Created, + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), + Object: objectChatCompletion, + Choices: choices, + Usage: response.Usage, + } +} + +func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { + responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) +} diff --git a/plugins/wasm-go/go.sum b/plugins/wasm-go/go.sum index 5b23dc2c4..e726b100a 100644 --- a/plugins/wasm-go/go.sum +++ b/plugins/wasm-go/go.sum @@ -4,12 +4,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43 h1:dCw7F/9ciw4NZN7w68wQRaygZ2zGOWMTIEoRvP1tlWs= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=