mirror of
https://github.com/alibaba/higress.git
synced 2026-05-08 04:17:27 +08:00
feat: Enhance the feature of ai-proxy plugin (#976)
This commit is contained in:
@@ -24,12 +24,17 @@ const (
|
||||
|
||||
qwenTopPMin = 0.000001
|
||||
qwenTopPMax = 0.999999
|
||||
|
||||
qwenDummySystemMessageContent = "You are a helpful assistant."
|
||||
)
|
||||
|
||||
// qwenProviderInitializer carries no state of its own; it exists so that the
// Qwen-specific configuration checks can be attached to it as methods.
type qwenProviderInitializer struct{}
|
||||
|
||||
func (m *qwenProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if len(config.qwenFileIds) != 0 && config.context != nil {
|
||||
return errors.New("qwenFileIds and context cannot be configured at the same time")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,10 +55,6 @@ func (m *qwenProvider) GetProviderType() string {
|
||||
return providerTypeQwen
|
||||
}
|
||||
|
||||
const (
|
||||
forceStreaming = true
|
||||
)
|
||||
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
@@ -67,17 +68,11 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
if forceStreaming {
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
|
||||
return types.ActionContinue, nil
|
||||
} else {
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -132,18 +127,20 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
|
||||
request.Model = mappedModel
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
|
||||
|
||||
if !forceStreaming {
|
||||
if request.Stream {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
|
||||
} else {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
|
||||
}
|
||||
streaming := request.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
|
||||
} else {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
|
||||
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
|
||||
}
|
||||
|
||||
if m.config.context == nil {
|
||||
qwenRequest := m.buildQwenTextGenerationRequest(request)
|
||||
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
|
||||
if streaming {
|
||||
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
|
||||
}
|
||||
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
|
||||
}
|
||||
|
||||
@@ -156,7 +153,10 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
insertContextMessage(request, content)
|
||||
qwenRequest := m.buildQwenTextGenerationRequest(request)
|
||||
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
|
||||
if streaming {
|
||||
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
|
||||
}
|
||||
if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
|
||||
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
@@ -183,6 +183,11 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
||||
receivedBody = append(bufferedStreamingBody, chunk...)
|
||||
}
|
||||
|
||||
incrementalStreaming, err := ctx.GetContext(ctxKeyIncrementalStreaming).(bool)
|
||||
if !err {
|
||||
incrementalStreaming = false
|
||||
}
|
||||
|
||||
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
|
||||
|
||||
defer func() {
|
||||
@@ -194,7 +199,7 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
||||
}
|
||||
}()
|
||||
|
||||
// Sample event response:
|
||||
// Sample Qwen event response:
|
||||
//
|
||||
// event:result
|
||||
// :HTTP_STATUS/200
|
||||
@@ -233,12 +238,11 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
||||
|
||||
if lineStartIndex != -1 {
|
||||
value := string(receivedBody[valueStartIndex:i])
|
||||
log.Debugf("key: %s value: %s", currentKey, value)
|
||||
currentEvent.setValue(currentKey, value)
|
||||
} else {
|
||||
// Extra new line. The current event is complete.
|
||||
log.Debugf("processing event: %v", currentEvent)
|
||||
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, log); err != nil {
|
||||
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reset event parsing state.
|
||||
@@ -266,28 +270,52 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest) *qwenTextGenRequest {
|
||||
return &qwenTextGenRequest{
|
||||
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
|
||||
messages := make([]qwenMessage, 0, len(origRequest.Messages))
|
||||
for i := range origRequest.Messages {
|
||||
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
|
||||
}
|
||||
request := &qwenTextGenRequest{
|
||||
Model: origRequest.Model,
|
||||
Input: qwenTextGenInput{
|
||||
Messages: origRequest.Messages,
|
||||
Messages: messages,
|
||||
},
|
||||
Parameters: qwenTextGenParameters{
|
||||
ResultFormat: qwenResultFormatMessage,
|
||||
MaxTokens: origRequest.MaxTokens,
|
||||
N: origRequest.N,
|
||||
Seed: origRequest.Seed,
|
||||
Temperature: origRequest.Temperature,
|
||||
TopP: math.Max(qwenTopPMin, math.Min(origRequest.TopP, qwenTopPMax)),
|
||||
ResultFormat: qwenResultFormatMessage,
|
||||
MaxTokens: origRequest.MaxTokens,
|
||||
N: origRequest.N,
|
||||
Seed: origRequest.Seed,
|
||||
Temperature: origRequest.Temperature,
|
||||
TopP: math.Max(qwenTopPMin, math.Min(origRequest.TopP, qwenTopPMax)),
|
||||
IncrementalOutput: streaming && (origRequest.Tools == nil || len(origRequest.Tools) == 0),
|
||||
EnableSearch: m.config.qwenEnableSearch,
|
||||
Tools: origRequest.Tools,
|
||||
},
|
||||
}
|
||||
if len(m.config.qwenFileIds) != 0 {
|
||||
builder := strings.Builder{}
|
||||
for _, fileId := range m.config.qwenFileIds {
|
||||
if builder.Len() != 0 {
|
||||
builder.WriteRune(',')
|
||||
}
|
||||
builder.WriteString("fileid://")
|
||||
builder.WriteString(fileId)
|
||||
}
|
||||
contextMessageId := m.insertContextMessage(request, builder.String())
|
||||
if contextMessageId == 0 {
|
||||
// The context message cannot come first. We need to add another dummy system message before it.
|
||||
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
|
||||
}
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
|
||||
choices := make([]chatCompletionChoice, 0, len(qwenResponse.Output.Choices))
|
||||
for _, qwenChoice := range qwenResponse.Output.Choices {
|
||||
message := qwenMessageToChatMessage(qwenChoice.Message)
|
||||
choices = append(choices, chatCompletionChoice{
|
||||
Message: &qwenChoice.Message,
|
||||
Message: &message,
|
||||
FinishReason: qwenChoice.FinishReason,
|
||||
})
|
||||
}
|
||||
@@ -306,7 +334,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) []*chatCompletionResponse {
|
||||
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
|
||||
baseMessage := chatCompletionResponse{
|
||||
Id: qwenResponse.RequestId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
@@ -320,17 +348,30 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
qwenChoice := qwenResponse.Output.Choices[0]
|
||||
message := qwenChoice.Message
|
||||
|
||||
content := message.Content
|
||||
if rawPushedContent := ctx.GetContext(ctxKeyPushedMessageContent); rawPushedContent != nil {
|
||||
if pushedContent := rawPushedContent.(string); pushedContent != "" && strings.HasPrefix(content, pushedContent) {
|
||||
content = content[len(pushedContent):]
|
||||
deltaMessage := &chatMessage{Role: message.Role, Content: message.Content, ToolCalls: append([]toolCall{}, message.ToolCalls...)}
|
||||
if !incrementalStreaming {
|
||||
if pushedMessage, ok := ctx.GetContext(ctxKeyPushedMessage).(qwenMessage); ok {
|
||||
deltaMessage.Content = util.StripPrefix(deltaMessage.Content, pushedMessage.Content)
|
||||
if len(deltaMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
|
||||
for i, tc := range deltaMessage.ToolCalls {
|
||||
if i >= len(pushedMessage.ToolCalls) {
|
||||
break
|
||||
}
|
||||
pushedFunction := pushedMessage.ToolCalls[i].Function
|
||||
tc.Function.Id = util.StripPrefix(tc.Function.Id, pushedFunction.Id)
|
||||
tc.Function.Name = util.StripPrefix(tc.Function.Name, pushedFunction.Name)
|
||||
tc.Function.Arguments = util.StripPrefix(tc.Function.Arguments, pushedFunction.Arguments)
|
||||
deltaMessage.ToolCalls[i] = tc
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx.SetContext(ctxKeyPushedMessage, message)
|
||||
}
|
||||
if content != "" {
|
||||
|
||||
if !deltaMessage.IsEmpty() {
|
||||
deltaResponse := *&baseMessage
|
||||
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: &chatMessage{Role: message.Role, Content: content}})
|
||||
deltaResponse.Choices = append(deltaResponse.Choices, chatCompletionChoice{Delta: deltaMessage})
|
||||
responses = append(responses, &deltaResponse)
|
||||
ctx.SetContext(ctxKeyPushedMessageContent, message.Content)
|
||||
}
|
||||
|
||||
// Yes, Qwen uses a string "null" as null.
|
||||
@@ -343,7 +384,7 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
return responses
|
||||
}
|
||||
|
||||
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
|
||||
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error {
|
||||
if event.Data == streamEndDataValue {
|
||||
m.appendStreamEvent(responseBuilder, event)
|
||||
return nil
|
||||
@@ -361,7 +402,7 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
|
||||
return fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
}
|
||||
|
||||
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse)
|
||||
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
|
||||
for _, response := range responses {
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
@@ -376,8 +417,8 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) {
|
||||
fileMessage := chatMessage{
|
||||
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string) int {
|
||||
fileMessage := qwenMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
}
|
||||
@@ -392,23 +433,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
|
||||
}
|
||||
}
|
||||
if firstNonSystemMessageIndex == -1 {
|
||||
request.Input.Messages = append(request.Input.Messages, fileMessage)
|
||||
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
|
||||
return 0
|
||||
} else {
|
||||
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
return firstNonSystemMessageIndex
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
|
||||
responseBuilder.WriteString(streamEventIdItemKey)
|
||||
responseBuilder.WriteString(event.Id)
|
||||
responseBuilder.WriteString("\n")
|
||||
responseBuilder.WriteString(streamEventNameItemKey)
|
||||
responseBuilder.WriteString(event.Event)
|
||||
responseBuilder.WriteString("\n")
|
||||
responseBuilder.WriteString(streamBuiltInItemKey)
|
||||
responseBuilder.WriteString(streamHttpStatusValuePrefix)
|
||||
responseBuilder.WriteString(event.HttpStatus)
|
||||
responseBuilder.WriteString("\n")
|
||||
responseBuilder.WriteString(streamDataItemKey)
|
||||
responseBuilder.WriteString(event.Data)
|
||||
responseBuilder.WriteString("\n\n")
|
||||
@@ -421,7 +454,7 @@ type qwenTextGenRequest struct {
|
||||
}
|
||||
|
||||
type qwenTextGenInput struct {
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Messages []qwenMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type qwenTextGenParameters struct {
|
||||
@@ -432,6 +465,9 @@ type qwenTextGenParameters struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type qwenTextGenResponse struct {
|
||||
@@ -447,7 +483,7 @@ type qwenTextGenOutput struct {
|
||||
|
||||
type qwenTextGenChoice struct {
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message chatMessage `json:"message"`
|
||||
Message qwenMessage `json:"message"`
|
||||
}
|
||||
|
||||
type qwenTextGenUsage struct {
|
||||
@@ -455,3 +491,28 @@ type qwenTextGenUsage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type qwenMessage struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
func qwenMessageToChatMessage(qwenMessage qwenMessage) chatMessage {
|
||||
return chatMessage{
|
||||
Name: qwenMessage.Name,
|
||||
Role: qwenMessage.Role,
|
||||
Content: qwenMessage.Content,
|
||||
ToolCalls: qwenMessage.ToolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
|
||||
return qwenMessage{
|
||||
Name: chatMessage.Name,
|
||||
Role: chatMessage.Role,
|
||||
Content: chatMessage.Content,
|
||||
ToolCalls: chatMessage.ToolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user