feat: Enhance the feature of ai-proxy plugin (#976)

This commit is contained in:
Kent Dong
2024-05-22 20:30:46 +08:00
committed by GitHub
parent fc6a6aad89
commit 76b5f2af79
52 changed files with 543 additions and 107 deletions

View File

@@ -24,12 +24,17 @@ const (
qwenTopPMin = 0.000001
qwenTopPMax = 0.999999
qwenDummySystemMessageContent = "You are a helpful assistant."
)
type qwenProviderInitializer struct {
}
func (m *qwenProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.qwenFileIds) != 0 && config.context != nil {
return errors.New("qwenFileIds and context cannot be configured at the same time")
}
return nil
}
@@ -50,10 +55,6 @@ func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
const (
forceStreaming = true
)
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -67,17 +68,11 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, nil
}
if forceStreaming {
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
return types.ActionContinue, nil
} else {
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -132,18 +127,20 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
if !forceStreaming {
if request.Stream {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
} else {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
} else {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
if m.config.context == nil {
qwenRequest := m.buildQwenTextGenerationRequest(request)
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
}
@@ -156,7 +153,10 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
qwenRequest := m.buildQwenTextGenerationRequest(request)
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
@@ -183,6 +183,11 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
receivedBody = append(bufferedStreamingBody, chunk...)
}
incrementalStreaming, err := ctx.GetContext(ctxKeyIncrementalStreaming).(bool)
if !err {
incrementalStreaming = false
}
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
defer func() {
@@ -194,7 +199,7 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
}
}()
// Sample event response:
// Sample Qwen event response:
//
// event:result
// :HTTP_STATUS/200
@@ -233,12 +238,11 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
if lineStartIndex != -1 {
value := string(receivedBody[valueStartIndex:i])
log.Debugf("key: %s value: %s", currentKey, value)
currentEvent.setValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
log.Debugf("processing event: %v", currentEvent)
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, log); err != nil {
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil {
return nil, err
}
// Reset event parsing state.
@@ -266,28 +270,52 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest) *qwenTextGenRequest {
return &qwenTextGenRequest{
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
messages := make([]qwenMessage, 0, len(origRequest.Messages))
for i := range origRequest.Messages {
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
}
request := &qwenTextGenRequest{
Model: origRequest.Model,
Input: qwenTextGenInput{
Messages: origRequest.Messages,
Messages: messages,
},
Parameters: qwenTextGenParameters{
ResultFormat: qwenResultFormatMessage,
MaxTokens: origRequest.MaxTokens,
N: origRequest.N,
Seed: origRequest.Seed,
Temperature: origRequest.Temperature,
TopP: math.Max(qwenTopPMin, math.Min(origRequest.TopP, qwenTopPMax)),
ResultFormat: qwenResultFormatMessage,
MaxTokens: origRequest.MaxTokens,
N: origRequest.N,
Seed: origRequest.Seed,
Temperature: origRequest.Temperature,
TopP: math.Max(qwenTopPMin, math.Min(origRequest.TopP, qwenTopPMax)),
IncrementalOutput: streaming && (origRequest.Tools == nil || len(origRequest.Tools) == 0),
EnableSearch: m.config.qwenEnableSearch,
Tools: origRequest.Tools,
},
}
if len(m.config.qwenFileIds) != 0 {
builder := strings.Builder{}
for _, fileId := range m.config.qwenFileIds {
if builder.Len() != 0 {
builder.WriteRune(',')
}
builder.WriteString("fileid://")
builder.WriteString(fileId)
}
contextMessageId := m.insertContextMessage(request, builder.String())
if contextMessageId == 0 {
// The context message cannot come first. We need to add another dummy system message before it.
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
}
}
return request
}
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
choices := make([]chatCompletionChoice, 0, len(qwenResponse.Output.Choices))
for _, qwenChoice := range qwenResponse.Output.Choices {
message := qwenMessageToChatMessage(qwenChoice.Message)
choices = append(choices, chatCompletionChoice{
Message: &qwenChoice.Message,
Message: &message,
FinishReason: qwenChoice.FinishReason,
})
}
@@ -306,7 +334,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
}
}
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) []*chatCompletionResponse {
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
baseMessage := chatCompletionResponse{
Id: qwenResponse.RequestId,
Created: time.Now().UnixMilli() / 1000,
@@ -320,17 +348,30 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
qwenChoice := qwenResponse.Output.Choices[0]
message := qwenChoice.Message
content := message.Content
if rawPushedContent := ctx.GetContext(ctxKeyPushedMessageContent); rawPushedContent != nil {
if pushedContent := rawPushedContent.(string); pushedContent != "" && strings.HasPrefix(content, pushedContent) {
content = content[len(pushedContent):]
deltaMessage := &chatMessage{Role: message.Role, Content: message.Content, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
if !incrementalStreaming {
if pushedMessage, ok := ctx.GetContext(ctxKeyPushedMessage).(qwenMessage); ok {
deltaMessage.Content = util.StripPrefix(deltaMessage.Content, pushedMessage.Content)
if len(deltaMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
for i, tc := range deltaMessage.ToolCalls {
if i >= len(pushedMessage.ToolCalls) {
break
}
pushedFunction := pushedMessage.ToolCalls[i].Function
tc.Function.Id = util.StripPrefix(tc.Function.Id, pushedFunction.Id)
tc.Function.Name = util.StripPrefix(tc.Function.Name, pushedFunction.Name)
tc.Function.Arguments = util.StripPrefix(tc.Function.Arguments, pushedFunction.Arguments)
deltaMessage.ToolCalls[i] = tc
}
}
}
ctx.SetContext(ctxKeyPushedMessage, message)
}
if content != "" {
if !deltaMessage.IsEmpty() {
deltaResponse := *&baseMessage
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: &chatMessage{Role: message.Role, Content: content}})
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: deltaMessage})
responses = append(responses, &deltaResponse)
ctx.SetContext(ctxKeyPushedMessageContent, message.Content)
}
// Yes, Qwen uses a string "null" as null.
@@ -343,7 +384,7 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
return responses
}
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error {
if event.Data == streamEndDataValue {
m.appendStreamEvent(responseBuilder, event)
return nil
@@ -361,7 +402,7 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return fmt.Errorf("unable to unmarshal Qwen response: %v", err)
}
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse)
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
for _, response := range responses {
responseBody, err := json.Marshal(response)
if err != nil {
@@ -376,8 +417,8 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return nil
}
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) {
fileMessage := chatMessage{
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) int {
fileMessage := qwenMessage{
Role: roleSystem,
Content: content,
}
@@ -392,23 +433,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
}
}
if firstNonSystemMessageIndex == -1 {
request.Input.Messages = append(request.Input.Messages, fileMessage)
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
return 0
} else {
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
return firstNonSystemMessageIndex
}
}
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
responseBuilder.WriteString(streamEventIdItemKey)
responseBuilder.WriteString(event.Id)
responseBuilder.WriteString("\n")
responseBuilder.WriteString(streamEventNameItemKey)
responseBuilder.WriteString(event.Event)
responseBuilder.WriteString("\n")
responseBuilder.WriteString(streamBuiltInItemKey)
responseBuilder.WriteString(streamHttpStatusValuePrefix)
responseBuilder.WriteString(event.HttpStatus)
responseBuilder.WriteString("\n")
responseBuilder.WriteString(streamDataItemKey)
responseBuilder.WriteString(event.Data)
responseBuilder.WriteString("\n\n")
@@ -421,7 +454,7 @@ type qwenTextGenRequest struct {
}
type qwenTextGenInput struct {
Messages []chatMessage `json:"messages"`
Messages []qwenMessage `json:"messages"`
}
type qwenTextGenParameters struct {
@@ -432,6 +465,9 @@ type qwenTextGenParameters struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
Tools []tool `json:"tools,omitempty"`
}
type qwenTextGenResponse struct {
@@ -447,7 +483,7 @@ type qwenTextGenOutput struct {
type qwenTextGenChoice struct {
FinishReason string `json:"finish_reason"`
Message chatMessage `json:"message"`
Message qwenMessage `json:"message"`
}
type qwenTextGenUsage struct {
@@ -455,3 +491,28 @@ type qwenTextGenUsage struct {
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type qwenMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
func qwenMessageToChatMessage(qwenMessage qwenMessage) chatMessage {
return chatMessage{
Name: qwenMessage.Name,
Role: qwenMessage.Role,
Content: qwenMessage.Content,
ToolCalls: qwenMessage.ToolCalls,
}
}
func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
return qwenMessage{
Name: chatMessage.Name,
Role: chatMessage.Role,
Content: chatMessage.Content,
ToolCalls: chatMessage.ToolCalls,
}
}