Feat: Add Spark llm support for plugins/ai-proxy (#1139)

This commit is contained in:
urlyy
2024-08-08 15:16:58 +08:00
committed by GitHub
parent dc0dcaaaee
commit c78ef7011d
4 changed files with 276 additions and 7 deletions

View File

@@ -0,0 +1,207 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// sparkProvider is the provider for SparkLLM AI service.
const (
sparkHost = "spark-api-open.xf-yun.com"
sparkChatCompletionPath = "/v1/chat/completions"
)
type sparkProviderInitializer struct {
}
type sparkProvider struct {
config ProviderConfig
contextCache *contextCache
}
type sparkRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
TopK int `json:"top_k,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
}
type sparkResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Choices []chatCompletionChoice `json:"choices"`
Usage usage `json:"usage,omitempty"`
}
type sparkStreamResponse struct {
sparkResponse
Id string `json:"id"`
Created int64 `json:"created"`
}
func (i *sparkProviderInitializer) ValidateConfig(config ProviderConfig) error {
return nil
}
func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &sparkProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
func (p *sparkProvider) GetProviderType() string {
return providerTypeSpark
}
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(sparkHost)
_ = util.OverwriteRequestPath(sparkChatCompletionPath)
_ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
}
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// 使用Spark协议
if p.config.protocol == protocolOriginal {
request := &sparkRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 目前星火在模型名称错误时也会调用generalv3这里还是按照输入的模型名称设置响应里的模型名称
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
// 使用openai协议
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, p.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
}
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
sparkResponse := &sparkResponse{}
if err := json.Unmarshal(body, sparkResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err)
}
if sparkResponse.Code != 0 {
return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
}
response := p.responseSpark2OpenAI(ctx, sparkResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
data = data[6:]
// The final response is `data: [DONE]`
if data == "[DONE]" {
continue
}
var sparkResponse sparkStreamResponse
if err := json.Unmarshal([]byte(data), &sparkResponse); err != nil {
log.Errorf("unable to unmarshal spark response: %v", err)
continue
}
response := p.streamResponseSpark2OpenAI(ctx, &sparkResponse)
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
p.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
}
func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkResponse) *chatCompletionResponse {
choices := make([]chatCompletionChoice, len(response.Choices))
for idx, c := range response.Choices {
choices[idx] = chatCompletionChoice{
Index: c.Index,
Message: &chatMessage{Role: c.Message.Role, Content: c.Message.Content},
}
}
return &chatCompletionResponse{
Id: response.Sid,
Created: time.Now().UnixMilli() / 1000,
Object: objectChatCompletion,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Choices: choices,
Usage: response.Usage,
}
}
func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkStreamResponse) *chatCompletionResponse {
choices := make([]chatCompletionChoice, len(response.Choices))
for idx, c := range response.Choices {
choices[idx] = chatCompletionChoice{
Index: c.Index,
Delta: &chatMessage{Role: c.Delta.Role, Content: c.Delta.Content},
}
}
return &chatCompletionResponse{
Id: response.Sid,
Created: response.Created,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Object: objectChatCompletion,
Choices: choices,
Usage: response.Usage,
}
}
func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}