feat: Enhance the features of the ai-proxy plugin (#976)

This commit is contained in:
Kent Dong
2024-05-22 20:30:46 +08:00
committed by GitHub
parent fc6a6aad89
commit 76b5f2af79
52 changed files with 543 additions and 107 deletions

View File

@@ -16,17 +16,40 @@ const (
)
type chatCompletionRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
User string `json:"user,omitempty"`
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
// streamOptions mirrors the OpenAI "stream_options" request object.
type streamOptions struct {
// IncludeUsage, when true, asks the service to append a final chunk
// containing token usage statistics to the streamed response.
IncludeUsage bool `json:"include_usage,omitempty"`
}
// tool declares one tool the model may call, following the OpenAI
// "tools" request schema.
type tool struct {
// Type is the tool kind; the chat completions schema currently defines "function".
Type string `json:"type"`
Function function `json:"function"`
}
// function describes a callable function declared inside a tool definition:
// its name, an optional description, and its parameters.
type function struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
// Parameters holds the function's parameter schema as a free-form JSON object
// (presumably JSON Schema, per the OpenAI convention — not enforced here).
Parameters map[string]interface{} `json:"parameters,omitempty"`
}
// toolChoice mirrors the OpenAI "tool_choice" request object, directing the
// model toward a specific named function.
type toolChoice struct {
Type string `json:"type"`
Function function `json:"function"`
}
type chatCompletionResponse struct {
@@ -53,9 +76,45 @@ type chatCompletionUsage struct {
}
type chatMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
// IsEmpty reports whether the message carries no payload at all: it has no
// content text, and none of its tool calls contains a non-empty function.
func (m *chatMessage) IsEmpty() bool {
	if m.Content != "" {
		return false
	}
	for _, tc := range m.ToolCalls {
		if !tc.Function.IsEmpty() {
			return false
		}
	}
	return true
}
// toolCall represents one function invocation requested by the model inside
// a chat message, per the OpenAI tool-calling schema.
type toolCall struct {
Id string `json:"id"`
Type string `json:"type"`
Function functionCall `json:"function"`
}
// functionCall carries the function-invocation payload of a tool call.
// The optional Id is carried for providers whose wire format includes one.
type functionCall struct {
	Id        string `json:"id,omitempty"`
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments"`
}

// IsEmpty reports whether the call carries neither a function name nor
// any arguments (the Id alone does not make a call non-empty).
func (m *functionCall) IsEmpty() bool {
	if m.Name != "" {
		return false
	}
	return m.Arguments == ""
}
type streamEvent struct {

View File

@@ -23,15 +23,16 @@ const (
providerTypeBaichuan = "baichuan"
providerTypeYi = "yi"
protocolOpenAI = "openai"
protocolOriginal = "original"
protocolOpenAI = "openai"
protocolOriginal = "original"
roleSystem = "system"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel"
ctxKeyPushedMessageContent = "pushedMessageContent"
ctxKeyPushedMessage = "pushedMessage"
objectChatCompletion = "chat.completion"
objectChatCompletionChunk = "chat.completion.chunk"
@@ -95,11 +96,17 @@ type ProviderConfig struct {
// @Description zh-CN 请求AI服务的超时时间单位为毫秒。默认值为120000即2分钟
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title zh-CN Moonshot File ID
// @Description zh-CN 仅适用于Moonshot AI服务。Moonshot AI服务的文件 ID其内容用于补充 AI 请求上下文
// @Description zh-CN 仅适用于Moonshot AI服务。Moonshot AI服务的文件ID其内容用于补充AI请求上下文
moonshotFileId string `required:"false" yaml:"moonshotFileId" json:"moonshotFileId"`
// @Title zh-CN Azure OpenAI Service URL
// @Description zh-CN 仅适用于Azure OpenAI服务。要请求的OpenAI服务的完整URL包含api-version等参数
azureServiceUrl string `required:"false" yaml:"azureServiceUrl" json:"azureServiceUrl"`
// @Title zh-CN 通义千问File ID
// @Description zh-CN 仅适用于通义千问服务。上传到Dashscope的文件ID其内容用于补充AI请求上下文。仅支持qwen-long模型。
qwenFileIds []string `required:"false" yaml:"qwenFileIds" json:"qwenFileIds"`
// @Title zh-CN 启用通义千问搜索服务
// @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。
qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
// @Title zh-CN 模型名称映射表
// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
modelMapping map[string]string `required:"false" yaml:"modelMapping" json:"modelMapping"`
@@ -123,6 +130,11 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
c.moonshotFileId = json.Get("moonshotFileId").String()
c.azureServiceUrl = json.Get("azureServiceUrl").String()
c.qwenFileIds = make([]string, 0)
for _, fileId := range json.Get("qwenFileIds").Array() {
c.qwenFileIds = append(c.qwenFileIds, fileId.String())
}
c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
c.modelMapping = make(map[string]string)
for k, v := range json.Get("modelMapping").Map() {
c.modelMapping[k] = v.String()

View File

@@ -24,12 +24,17 @@ const (
qwenTopPMin = 0.000001
qwenTopPMax = 0.999999
qwenDummySystemMessageContent = "You are a helpful assistant."
)
type qwenProviderInitializer struct {
}
// ValidateConfig rejects configurations that set both qwenFileIds and a
// plugin-level context, because the two context-injection mechanisms
// cannot be combined for the Qwen provider.
func (m *qwenProviderInitializer) ValidateConfig(config ProviderConfig) error {
	if config.context == nil || len(config.qwenFileIds) == 0 {
		return nil
	}
	return errors.New("qwenFileIds and context cannot be configured at the same time")
}
@@ -50,10 +55,6 @@ func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
const (
forceStreaming = true
)
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -67,17 +68,11 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, nil
}
if forceStreaming {
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
return types.ActionContinue, nil
} else {
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -132,18 +127,20 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
if !forceStreaming {
if request.Stream {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
} else {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
} else {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
if m.config.context == nil {
qwenRequest := m.buildQwenTextGenerationRequest(request)
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
}
@@ -156,7 +153,10 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
qwenRequest := m.buildQwenTextGenerationRequest(request)
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
@@ -183,6 +183,11 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
receivedBody = append(bufferedStreamingBody, chunk...)
}
incrementalStreaming, err := ctx.GetContext(ctxKeyIncrementalStreaming).(bool)
if !err {
incrementalStreaming = false
}
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
defer func() {
@@ -194,7 +199,7 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
}
}()
// Sample event response:
// Sample Qwen event response:
//
// event:result
// :HTTP_STATUS/200
@@ -233,12 +238,11 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
if lineStartIndex != -1 {
value := string(receivedBody[valueStartIndex:i])
log.Debugf("key: %s value: %s", currentKey, value)
currentEvent.setValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
log.Debugf("processing event: %v", currentEvent)
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, log); err != nil {
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil {
return nil, err
}
// Reset event parsing state.
@@ -266,28 +270,52 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest) *qwenTextGenRequest {
return &qwenTextGenRequest{
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
messages := make([]qwenMessage, 0, len(origRequest.Messages))
for i := range origRequest.Messages {
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
}
request := &qwenTextGenRequest{
Model: origRequest.Model,
Input: qwenTextGenInput{
Messages: origRequest.Messages,
Messages: messages,
},
Parameters: qwenTextGenParameters{
ResultFormat: qwenResultFormatMessage,
MaxTokens: origRequest.MaxTokens,
N: origRequest.N,
Seed: origRequest.Seed,
Temperature: origRequest.Temperature,
TopP: math.Max(qwenTopPMin, math.Min(origRequest.TopP, qwenTopPMax)),
ResultFormat: qwenResultFormatMessage,
MaxTokens: origRequest.MaxTokens,
N: origRequest.N,
Seed: origRequest.Seed,
Temperature: origRequest.Temperature,
TopP: math.Max(qwenTopPMin, math.Min(origRequest.TopP, qwenTopPMax)),
IncrementalOutput: streaming && (origRequest.Tools == nil || len(origRequest.Tools) == 0),
EnableSearch: m.config.qwenEnableSearch,
Tools: origRequest.Tools,
},
}
if len(m.config.qwenFileIds) != 0 {
builder := strings.Builder{}
for _, fileId := range m.config.qwenFileIds {
if builder.Len() != 0 {
builder.WriteRune(',')
}
builder.WriteString("fileid://")
builder.WriteString(fileId)
}
contextMessageId := m.insertContextMessage(request, builder.String())
if contextMessageId == 0 {
// The context message cannot come first. We need to add another dummy system message before it.
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
}
}
return request
}
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
choices := make([]chatCompletionChoice, 0, len(qwenResponse.Output.Choices))
for _, qwenChoice := range qwenResponse.Output.Choices {
message := qwenMessageToChatMessage(qwenChoice.Message)
choices = append(choices, chatCompletionChoice{
Message: &qwenChoice.Message,
Message: &message,
FinishReason: qwenChoice.FinishReason,
})
}
@@ -306,7 +334,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
}
}
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) []*chatCompletionResponse {
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
baseMessage := chatCompletionResponse{
Id: qwenResponse.RequestId,
Created: time.Now().UnixMilli() / 1000,
@@ -320,17 +348,30 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
qwenChoice := qwenResponse.Output.Choices[0]
message := qwenChoice.Message
content := message.Content
if rawPushedContent := ctx.GetContext(ctxKeyPushedMessageContent); rawPushedContent != nil {
if pushedContent := rawPushedContent.(string); pushedContent != "" && strings.HasPrefix(content, pushedContent) {
content = content[len(pushedContent):]
deltaMessage := &chatMessage{Role: message.Role, Content: message.Content, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
if !incrementalStreaming {
if pushedMessage, ok := ctx.GetContext(ctxKeyPushedMessage).(qwenMessage); ok {
deltaMessage.Content = util.StripPrefix(deltaMessage.Content, pushedMessage.Content)
if len(deltaMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
for i, tc := range deltaMessage.ToolCalls {
if i >= len(pushedMessage.ToolCalls) {
break
}
pushedFunction := pushedMessage.ToolCalls[i].Function
tc.Function.Id = util.StripPrefix(tc.Function.Id, pushedFunction.Id)
tc.Function.Name = util.StripPrefix(tc.Function.Name, pushedFunction.Name)
tc.Function.Arguments = util.StripPrefix(tc.Function.Arguments, pushedFunction.Arguments)
deltaMessage.ToolCalls[i] = tc
}
}
}
ctx.SetContext(ctxKeyPushedMessage, message)
}
if content != "" {
if !deltaMessage.IsEmpty() {
deltaResponse := *&baseMessage
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: &chatMessage{Role: message.Role, Content: content}})
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: deltaMessage})
responses = append(responses, &deltaResponse)
ctx.SetContext(ctxKeyPushedMessageContent, message.Content)
}
// Yes, Qwen uses a string "null" as null.
@@ -343,7 +384,7 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
return responses
}
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error {
if event.Data == streamEndDataValue {
m.appendStreamEvent(responseBuilder, event)
return nil
@@ -361,7 +402,7 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return fmt.Errorf("unable to unmarshal Qwen response: %v", err)
}
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse)
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
for _, response := range responses {
responseBody, err := json.Marshal(response)
if err != nil {
@@ -376,8 +417,8 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return nil
}
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) {
fileMessage := chatMessage{
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) int {
fileMessage := qwenMessage{
Role: roleSystem,
Content: content,
}
@@ -392,23 +433,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
}
}
if firstNonSystemMessageIndex == -1 {
request.Input.Messages = append(request.Input.Messages, fileMessage)
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
return 0
} else {
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
return firstNonSystemMessageIndex
}
}
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
responseBuilder.WriteString(streamEventIdItemKey)
responseBuilder.WriteString(event.Id)
responseBuilder.WriteString("\n")
responseBuilder.WriteString(streamEventNameItemKey)
responseBuilder.WriteString(event.Event)
responseBuilder.WriteString("\n")
responseBuilder.WriteString(streamBuiltInItemKey)
responseBuilder.WriteString(streamHttpStatusValuePrefix)
responseBuilder.WriteString(event.HttpStatus)
responseBuilder.WriteString("\n")
responseBuilder.WriteString(streamDataItemKey)
responseBuilder.WriteString(event.Data)
responseBuilder.WriteString("\n\n")
@@ -421,7 +454,7 @@ type qwenTextGenRequest struct {
}
type qwenTextGenInput struct {
Messages []chatMessage `json:"messages"`
Messages []qwenMessage `json:"messages"`
}
type qwenTextGenParameters struct {
@@ -432,6 +465,9 @@ type qwenTextGenParameters struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
Tools []tool `json:"tools,omitempty"`
}
type qwenTextGenResponse struct {
@@ -447,7 +483,7 @@ type qwenTextGenOutput struct {
type qwenTextGenChoice struct {
FinishReason string `json:"finish_reason"`
Message chatMessage `json:"message"`
Message qwenMessage `json:"message"`
}
type qwenTextGenUsage struct {
@@ -455,3 +491,28 @@ type qwenTextGenUsage struct {
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
// qwenMessage is the message object of the Qwen (DashScope) text generation
// API. It matches chatMessage field-for-field, except that Role and Content
// are always serialized even when empty.
type qwenMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
// qwenMessageToChatMessage converts a Qwen wire-format message into the
// OpenAI-style chatMessage used in responses returned to the client.
// The tool-call slice is carried over as-is (not cloned).
func qwenMessageToChatMessage(msg qwenMessage) chatMessage {
	var out chatMessage
	out.Name = msg.Name
	out.Role = msg.Role
	out.Content = msg.Content
	out.ToolCalls = msg.ToolCalls
	return out
}
// chatMessage2QwenMessage converts an OpenAI-style chatMessage from the
// client request into the Qwen (DashScope) wire format.
// The tool-call slice is carried over as-is (not cloned).
func chatMessage2QwenMessage(msg chatMessage) qwenMessage {
	var out qwenMessage
	out.Name = msg.Name
	out.Role = msg.Role
	out.Content = msg.Content
	out.ToolCalls = msg.ToolCalls
	return out
}

View File

@@ -45,7 +45,7 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
}
}
if firstNonSystemMessageIndex == -1 {
request.Messages = append(request.Messages, fileMessage)
request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
} else {
request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
}