Files
higress/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go

477 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package provider
import (
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// Constants used by the MiniMax provider: the upstream host, the two chat
// completion endpoint paths, the sender types used in MiniMax-format messages,
// and the defaults applied when a request does not name a bot or a sender.
const (
minimaxDomain = "api.minimax.chat"
// minimaxChatCompletionV2Path is the endpoint whose request and response formats are the same as OpenAI's.
// API docs: https://platform.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2"
// minimaxChatCompletionProPath is the endpoint whose request and response formats differ from OpenAI's.
// API docs: https://platform.minimaxi.com/document/guides/chat-model/pro/api?id=6569c85948bc7b684b30377e
minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro"
senderTypeUser string = "USER" // content sent by the user
senderTypeBot string = "BOT" // content generated by the model
// Default bot settings and default user sender name.
defaultBotName string = "MM智能助理"
defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司一直致力于进行大模型相关的研究。"
defaultSenderName string = "小明"
)
// chatCompletionProModels is the set of model names that are served through
// the ChatCompletion Pro endpoint. Any model not listed here is routed to the
// ChatCompletion v2 endpoint by OnRequestBody.
var chatCompletionProModels = map[string]struct{}{
"abab6.5-chat": {},
"abab6.5s-chat": {},
"abab5.5s-chat": {},
"abab5.5-chat": {},
}
// minimaxProviderInitializer validates MiniMax provider configuration and
// creates minimaxProvider instances. It carries no state.
type minimaxProviderInitializer struct {
}
// ValidateConfig checks the provider configuration for MiniMax.
//
// It returns an error when a model mapping targets a ChatCompletion Pro model
// while minimaxGroupId is not configured, or when no API token is configured.
// The group id check is performed first.
func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error {
	// The Pro endpoint is addressed with a GroupId query parameter, so every
	// mapping that resolves to a Pro model needs minimaxGroupId to be set.
	if config.minimaxGroupId == "" {
		for _, target := range config.modelMapping {
			if _, isPro := chatCompletionProModels[target]; isPro {
				return fmt.Errorf("missing minimaxGroupId in provider config when %s model is provided", target)
			}
		}
	}
	if len(config.apiTokens) == 0 {
		return errors.New("no apiToken found in provider config")
	}
	return nil
}
// CreateProvider builds a minimaxProvider from the given configuration and
// attaches the context cache created for that configuration. It never returns
// an error.
func (m *minimaxProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	provider := &minimaxProvider{config: config}
	provider.contextCache = createContextCache(&config)
	return provider, nil
}
// minimaxProvider is the provider for the MiniMax service.
type minimaxProvider struct {
// config is the provider configuration this instance was created from.
config ProviderConfig
// contextCache supplies the optional context content that is injected into requests.
contextCache *contextCache
}
// GetProviderType returns the provider type identifier for MiniMax.
func (m *minimaxProvider) GetProviderType() string {
return providerTypeMinimax
}
// OnRequestHeaders rewrites the request headers for the MiniMax upstream.
//
// Only the chat completion API is supported; any other apiName returns
// errUnsupportedApiName. For supported requests the host is pointed at
// minimaxDomain, the Authorization header is set to a bearer token taken from
// the provider configuration, and Content-Length is removed.
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
	if apiName != ApiNameChatCompletion {
		return types.ActionContinue, errUnsupportedApiName
	}

	_ = util.OverwriteRequestHost(minimaxDomain)
	authorization := "Bearer " + m.config.GetRandomToken()
	_ = util.OverwriteRequestAuthorization(authorization)
	// Presumably removed because the body may be re-serialized in OnRequestBody,
	// which would make the original length stale.
	_ = proxywasm.RemoveHttpRequestHeader("Content-Length")

	// Hold the headers back so that OnRequestBody can still change them
	// (for example the streaming mode) before they are sent upstream.
	return types.HeaderStopIteration, nil
}
// OnRequestBody resolves the requested model, records both the original and
// the mapped model on the HTTP context, and hands the body to the handler for
// the endpoint that serves the mapped model.
//
// It returns errUnsupportedApiName for anything other than chat completion,
// and an error when the model cannot be read from the body or maps to an
// empty name.
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
	if apiName != ApiNameChatCompletion {
		return types.ActionContinue, errUnsupportedApiName
	}

	// Read the requested model and remember it before any mapping is applied.
	originalModel, err := m.parseModel(body)
	if err != nil {
		return types.ActionContinue, err
	}
	ctx.SetContext(ctxKeyOriginalRequestModel, originalModel)

	finalModel := getMappedModel(originalModel, m.config.modelMapping, log)
	if finalModel == "" {
		return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
	}
	ctx.SetContext(ctxKeyFinalRequestModel, finalModel)

	// Pro models go to the ChatCompletion Pro endpoint, everything else to v2.
	if _, isPro := chatCompletionProModels[finalModel]; isPro {
		return m.handleRequestBodyByChatCompletionPro(body, log)
	}
	return m.handleRequestBodyByChatCompletionV2(body, log)
}
// handleRequestBodyByChatCompletionPro prepares the request for the
// ChatCompletion Pro endpoint and rewrites the request path to it, including
// the configured GroupId as a query parameter.
//
// With the original (MiniMax) protocol the body is already in MiniMax format
// and is only modified when a context is configured. With the OpenAI protocol
// the body is decoded, the model is mapped, and the request is converted to
// MiniMax format.
//
// When a context is configured the content is loaded through the context
// cache and ActionPause is returned; the callback resumes the request once
// the body has been replaced or an error response has been sent. An error is
// returned when the body cannot be decoded, the model is empty (original
// protocol), minimaxGroupId is not configured, or the context load cannot be
// started.
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
	// Original protocol: the client already speaks the MiniMax format.
	if m.config.protocol == protocolOriginal {
		request := &minimaxChatCompletionV2Request{}
		if err := json.Unmarshal(body, request); err != nil {
			return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
		}
		if request.Model == "" {
			return types.ActionContinue, errors.New("request model is empty")
		}
		// The Pro path is built with the group id, so it must be configured.
		if m.config.minimaxGroupId == "" {
			return types.ActionContinue, fmt.Errorf("missing minimaxGroupId in provider config when use %s model ", request.Model)
		}
		_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))

		if m.config.context == nil {
			// Nothing to inject; forward the body unchanged.
			return types.ActionContinue, nil
		}

		err := m.contextCache.GetContent(func(content string, err error) {
			defer func() {
				_ = proxywasm.ResumeHttpRequest()
			}()
			if err != nil {
				log.Errorf("failed to load context file: %v", err)
				_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
				// An error response has been sent; do not go on to rewrite the body.
				return
			}
			m.setBotSettings(request, content)
			if err := replaceJsonRequestBody(request, log); err != nil {
				_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
			}
		}, log)
		if err == nil {
			return types.ActionPause, nil
		}
		return types.ActionContinue, err
	}

	// OpenAI protocol: decode, map the model and convert to the MiniMax format.
	request := &chatCompletionRequest{}
	if err := decodeChatCompletionRequest(body, request); err != nil {
		return types.ActionContinue, err
	}
	request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
	// Same guard as the original-protocol branch: without a group id the
	// request path would carry an empty GroupId.
	if m.config.minimaxGroupId == "" {
		return types.ActionContinue, fmt.Errorf("missing minimaxGroupId in provider config when use %s model ", request.Model)
	}
	_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))

	if m.config.context == nil {
		minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, "")
		return types.ActionContinue, replaceJsonRequestBody(minimaxRequest, log)
	}

	err := m.contextCache.GetContent(func(content string, err error) {
		defer func() {
			_ = proxywasm.ResumeHttpRequest()
		}()
		if err != nil {
			log.Errorf("failed to load context file: %v", err)
			_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
			// An error response has been sent; do not go on to rewrite the body.
			return
		}
		minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
		if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
			_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
		}
	}, log)
	if err == nil {
		return types.ActionPause, nil
	}
	return types.ActionContinue, err
}
// handleRequestBodyByChatCompletionV2 prepares the request for the
// ChatCompletion v2 endpoint and rewrites the request path to it.
//
// The body is decoded as a chat completion request and the model is replaced
// by its mapped name. Without a context cache the request is re-serialized
// and forwarded right away. Otherwise the context content is loaded through
// the cache and ActionPause is returned; the callback inserts the context
// message, replaces the body and resumes the request. An error is returned
// when the body cannot be decoded or the context load cannot be started.
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) {
	request := &chatCompletionRequest{}
	if err := decodeChatCompletionRequest(body, request); err != nil {
		return types.ActionContinue, err
	}

	// Apply the model mapping and point the request at the v2 endpoint.
	request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
	_ = util.OverwriteRequestPath(minimaxChatCompletionV2Path)

	if m.contextCache == nil {
		return types.ActionContinue, replaceJsonRequestBody(request, log)
	}

	err := m.contextCache.GetContent(func(content string, err error) {
		defer func() {
			_ = proxywasm.ResumeHttpRequest()
		}()
		if err != nil {
			log.Errorf("failed to load context file: %v", err)
			_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
			// An error response has been sent; do not go on to rewrite the body.
			return
		}
		insertContextMessage(request, content)
		if err := replaceJsonRequestBody(request, log); err != nil {
			_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
		}
	}, log)
	if err == nil {
		return types.ActionPause, nil
	}
	return types.ActionContinue, err
}
// OnResponseHeaders decides whether the response body needs to be processed.
//
// The body is left untouched (ctx.DontReadResponseBody) when the original
// MiniMax protocol is configured, or when the final request model recorded on
// the context is not a ChatCompletion Pro model. In every other case the
// Content-Length response header is removed and body processing proceeds.
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
	// With the original protocol the response is passed through as-is.
	passThrough := m.config.protocol == protocolOriginal
	if !passThrough {
		// Models served by the v2 endpoint need no response conversion either.
		if finalModel := ctx.GetStringContext(ctxKeyFinalRequestModel, ""); finalModel != "" {
			_, isPro := chatCompletionProModels[finalModel]
			passThrough = !isPro
		}
	}

	if passThrough {
		ctx.DontReadResponseBody()
		return types.ActionContinue, nil
	}

	_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
	return types.ActionContinue, nil
}
// OnStreamingResponseBody converts one chunk of a streaming ChatCompletion Pro
// response into OpenAI-format stream events. It is only reached when the
// OpenAI protocol is in use and the model is served by the Pro endpoint.
//
// The chunk is split into lines; every line carrying a "data:" field is
// decoded as a MiniMax response, converted, and appended to the output as a
// new data event. Lines without a "data:" field and empty payloads are
// skipped. Payloads that fail to decode are logged and skipped. It returns
// nil, nil for the last chunk and for empty chunks, and an error if a
// converted response cannot be marshalled.
//
// NOTE(review): each data line is decoded on its own, so an event split
// across two chunks is not reassembled here.
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
	if isLastChunk || len(chunk) == 0 {
		return nil, nil
	}
	// sample event response:
	// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
	// sample end event response:
	// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}}
	const dataFieldPrefix = "data:"

	responseBuilder := &strings.Builder{}
	for _, line := range strings.Split(string(chunk), "\n") {
		line = strings.TrimSpace(line)
		// Only data fields carry a payload; blank lines and other fields are
		// skipped instead of being sliced at a fixed offset.
		if !strings.HasPrefix(line, dataFieldPrefix) {
			continue
		}
		// The space after the colon is optional, so trim rather than assume it.
		payload := strings.TrimSpace(strings.TrimPrefix(line, dataFieldPrefix))
		if payload == "" {
			continue
		}

		var minimaxResp minimaxChatCompletionV2Resp
		if err := json.Unmarshal([]byte(payload), &minimaxResp); err != nil {
			log.Errorf("unable to unmarshal minimax response: %v", err)
			continue
		}
		response := m.responseV2ToOpenAI(&minimaxResp)
		responseBody, err := json.Marshal(response)
		if err != nil {
			log.Errorf("unable to marshal response: %v", err)
			return nil, err
		}
		m.appendResponse(responseBuilder, string(responseBody))
	}

	modifiedResponseChunk := responseBuilder.String()
	log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
	return []byte(modifiedResponseChunk), nil
}
// OnResponseBody converts a complete MiniMax response body into an
// OpenAI-format chat completion response and replaces the response body with
// it. It is only reached when the OpenAI protocol is in use and the model is
// served by the ChatCompletion Pro endpoint.
//
// An error is returned when the body cannot be decoded, when base_resp
// carries a non-zero status code, or when the body cannot be replaced.
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
	var upstream minimaxChatCompletionV2Resp
	if err := json.Unmarshal(body, &upstream); err != nil {
		return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err)
	}

	// A non-zero status code in base_resp marks an upstream error.
	if base := upstream.BaseResp; base.StatusCode != 0 {
		return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", base.StatusCode, base.StatusMsg)
	}

	converted := m.responseV2ToOpenAI(&upstream)
	return types.ActionContinue, replaceJsonResponseBody(converted, log)
}
// minimaxChatCompletionV2Request is a chat completion request in MiniMax's own
// format. In this file it is the body decoded from, or built for, requests
// that are sent to the ChatCompletion Pro endpoint.
type minimaxChatCompletionV2Request struct {
Model string `json:"model"`
Stream bool `json:"stream,omitempty"`
TokensToGenerate int64 `json:"tokens_to_generate,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
MaskSensitiveInfo bool `json:"mask_sensitive_info"` // whether to mask private information; default is true
Messages []minimaxMessage `json:"messages"`
BotSettings []minimaxBotSetting `json:"bot_setting"`
ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"`
}
// minimaxMessage is a single message in a conversation.
type minimaxMessage struct {
// SenderType is set to senderTypeUser or senderTypeBot when this file builds a message.
SenderType string `json:"sender_type"`
SenderName string `json:"sender_name"`
Text string `json:"text"`
}
// minimaxBotSetting describes the settings of a bot: its name and the content
// that defines it.
type minimaxBotSetting struct {
BotName string `json:"bot_name"`
Content string `json:"content"`
}
// minimaxReplyConstraints describes the requirements for the model's reply:
// the sender type and sender name the reply is produced as.
type minimaxReplyConstraints struct {
SenderType string `json:"sender_type"`
SenderName string `json:"sender_name"`
}
// minimaxChatCompletionV2Resp is the MiniMax chat completion response
// structure. It is decoded from both complete response bodies and individual
// stream events before being converted to the OpenAI format.
type minimaxChatCompletionV2Resp struct {
Created int64 `json:"created"`
Model string `json:"model"`
Reply string `json:"reply"`
InputSensitive bool `json:"input_sensitive,omitempty"`
InputSensitiveType int64 `json:"input_sensitive_type,omitempty"`
OutputSensitive bool `json:"output_sensitive,omitempty"`
OutputSensitiveType int64 `json:"output_sensitive_type,omitempty"`
Choices []minimaxChoice `json:"choices,omitempty"`
Usage minimaxUsage `json:"usage,omitempty"`
Id string `json:"id"`
// BaseResp carries the status code and message; a non-zero code is treated as an error by OnResponseBody.
BaseResp minimaxBaseResp `json:"base_resp"`
}
// minimaxBaseResp contains the error status code and its details.
type minimaxBaseResp struct {
StatusCode int64 `json:"status_code"`
StatusMsg string `json:"status_msg"`
}
// minimaxChoice is one result option in a response; it may hold several messages.
type minimaxChoice struct {
Messages []minimaxMessage `json:"messages"`
Index int64 `json:"index"`
FinishReason string `json:"finish_reason"`
}
// minimaxUsage reports token usage for a response.
type minimaxUsage struct {
TotalTokens int64 `json:"total_tokens"`
}
// parseModel extracts the "model" field from a JSON request body.
//
// It returns the decoding error when the body is not a JSON object, and an
// error when the "model" field is absent or is not a string.
func (m *minimaxProvider) parseModel(body []byte) (string, error) {
	var fields map[string]interface{}
	if err := json.Unmarshal(body, &fields); err != nil {
		return "", err
	}
	if model, isString := fields["model"].(string); isString {
		return model, nil
	}
	return "", errors.New("missing model in chat completion request")
}
// setBotSettings makes sure the request carries bot settings and applies the
// given context content to them.
//
// When the request has no bot settings, a single setting named defaultBotName
// is created whose content is botSettingContent, or defaultBotSettingContent
// when botSettingContent is empty. When the request already has bot settings
// and botSettingContent is not empty, a new setting with the first bot's name
// and botSettingContent is placed in front of the existing ones. Otherwise
// the request is left unchanged.
func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionV2Request, botSettingContent string) {
	hasContext := botSettingContent != ""

	switch {
	case len(request.BotSettings) == 0:
		content := defaultBotSettingContent
		if hasContext {
			content = botSettingContent
		}
		request.BotSettings = []minimaxBotSetting{{BotName: defaultBotName, Content: content}}
	case hasContext:
		merged := make([]minimaxBotSetting, 0, len(request.BotSettings)+1)
		merged = append(merged, minimaxBotSetting{
			BotName: request.BotSettings[0].BotName,
			Content: botSettingContent,
		})
		request.BotSettings = append(merged, request.BotSettings...)
	}
}
// buildMinimaxChatCompletionV2Request converts an OpenAI-format chat
// completion request into the MiniMax request format.
//
// System messages become bot settings, assistant messages become BOT messages
// and user messages become USER messages; messages with any other role are
// not carried over. Missing names fall back to defaultBotName (system and
// assistant) or defaultSenderName (user). The reply is constrained to the bot
// named by the last system message, or defaultBotName when there is none.
// Sensitive-info masking is always enabled. Finally botSettingContent is
// applied through setBotSettings.
func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompletionRequest, botSettingContent string) *minimaxChatCompletionV2Request {
	// nameOr returns name unless it is empty, in which case it returns fallback.
	nameOr := func(name, fallback string) string {
		if name == "" {
			return fallback
		}
		return name
	}

	var (
		messages    []minimaxMessage
		botSettings []minimaxBotSetting
		lastBotName string
	)
	for _, msg := range request.Messages {
		switch msg.Role {
		case roleSystem:
			lastBotName = nameOr(msg.Name, defaultBotName)
			botSettings = append(botSettings, minimaxBotSetting{
				BotName: lastBotName,
				Content: msg.StringContent(),
			})
		case roleAssistant:
			messages = append(messages, minimaxMessage{
				SenderType: senderTypeBot,
				SenderName: nameOr(msg.Name, defaultBotName),
				Text:       msg.StringContent(),
			})
		case roleUser:
			messages = append(messages, minimaxMessage{
				SenderType: senderTypeUser,
				SenderName: nameOr(msg.Name, defaultSenderName),
				Text:       msg.StringContent(),
			})
		}
	}

	converted := &minimaxChatCompletionV2Request{
		Model:             request.Model,
		Stream:            request.Stream,
		TokensToGenerate:  int64(request.MaxTokens),
		Temperature:       request.Temperature,
		TopP:              request.TopP,
		MaskSensitiveInfo: true,
		Messages:          messages,
		BotSettings:       botSettings,
		ReplyConstraints: minimaxReplyConstraints{
			SenderType: senderTypeBot,
			SenderName: nameOr(lastBotName, defaultBotName),
		},
	}
	m.setBotSettings(converted, botSettingContent)
	return converted
}
// responseV2ToOpenAI converts a MiniMax response into an OpenAI-format chat
// completion response.
//
// Every message of every MiniMax choice becomes its own OpenAI choice with
// the assistant role, numbered consecutively from zero and carrying the
// finish reason of the choice it came from. Id, creation time, model and the
// total token count are copied over.
func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Resp) *chatCompletionResponse {
	var choices []chatCompletionChoice
	for _, choice := range response.Choices {
		for _, msg := range choice.Messages {
			// len(choices) is the running index across all flattened messages.
			choices = append(choices, chatCompletionChoice{
				FinishReason: choice.FinishReason,
				Index:        len(choices),
				Message: &chatMessage{
					Name:    msg.SenderName,
					Role:    roleAssistant,
					Content: msg.Text,
				},
			})
		}
	}

	converted := &chatCompletionResponse{
		Id:      response.Id,
		Object:  objectChatCompletion,
		Created: response.Created,
		Model:   response.Model,
		Choices: choices,
		Usage: usage{
			TotalTokens: int(response.Usage.TotalTokens),
		},
	}
	return converted
}
// appendResponse writes one stream event to responseBuilder: the stream data
// key, a space, the response body, and a blank line terminating the event.
func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
	fmt.Fprintf(responseBuilder, "%s %s\n\n", streamDataItemKey, responseBody)
}