diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 82e038bbc..46e7af794 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -5,7 +5,6 @@ package main import ( "fmt" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "net/url" "strings" @@ -13,6 +12,7 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" ) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index e47e1aec9..5c84adf8c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -1,5 +1,20 @@ package provider +import "strings" + +const ( + streamEventIdItemKey = "id:" + streamEventNameItemKey = "event:" + streamBuiltInItemKey = ":" + streamHttpStatusValuePrefix = "HTTP_STATUS/" + streamDataItemKey = "data:" + streamEndDataValue = "[DONE]" + + eventResult = "result" + + httpStatus200 = "200" +) + type chatCompletionRequest struct { Model string `json:"model"` Messages []chatMessage `json:"messages"` @@ -42,3 +57,25 @@ type chatMessage struct { Role string `json:"role,omitempty"` Content string `json:"content,omitempty"` } + +type streamEvent struct { + Id string `json:"id"` + Event string `json:"event"` + Data string `json:"data"` + HttpStatus string `json:"http_status"` +} + +func (e *streamEvent) setValue(key, value string) { + switch key { + case streamEventIdItemKey: + e.Id = value + case streamEventNameItemKey: + e.Event = value + case streamDataItemKey: + e.Data = value + case streamBuiltInItemKey: + if strings.HasPrefix(value, streamHttpStatusValuePrefix) { + e.HttpStatus = value[len(streamHttpStatusValuePrefix):] + } + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 0513b4415..da1518fb6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -28,12 +28,11 @@ const ( ctxKeyStreamingBody = "streamingBody" ctxKeyOriginalRequestModel = "originalRequestModel" ctxKeyFinalRequestModel = "finalRequestModel" + ctxKeyPushedMessageContent = "pushedMessageContent" objectChatCompletion = "chat.completion" objectChatCompletionChunk = "chat.completion.chunk" - finishReasonStop = "stop" - wildcard = "*" defaultTimeout = 2 * 60 * 1000 // ms diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index ce8d8b809..aa0526c2a 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -24,13 +24,6 @@ const ( qwenTopPMin = 0.000001 qwenTopPMax = 0.999999 - - ctxKeyPushedMessageContent = "pushedMessageContent" - - streamIdItemKey = "id:" - streamDataItemKey = "data:" - streamEndDataValue = "[DONE]" - streamEventHeader = "event: result\n:HTTP_STATUS/200\n" ) type qwenProviderInitializer struct { @@ -190,10 +183,10 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api receivedBody = append(bufferedStreamingBody, chunk...) } - eventStartIndex, lineStartIndex, valueStartIndex := 0, -1, -1 + eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 defer func() { - if eventStartIndex != -1 { + if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) { // Just in case the received chunk is not a complete event. ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:]) } else { @@ -202,14 +195,27 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api }() // Sample event response: + // + // event:result + // :HTTP_STATUS/200 + // data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"} + // + // event:error + // :HTTP_STATUS/400 + // data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"} + // var responseBuilder strings.Builder - currentEventId, currentKey := "", "" + currentKey := "" + currentEvent := &streamEvent{} i, length := 0, len(receivedBody) for i = 0; i < length; i++ { ch := receivedBody[i] if ch != '\n' { if lineStartIndex == -1 { + if eventStartIndex == -1 { + eventStartIndex = i + } lineStartIndex = i valueStartIndex = -1 } @@ -225,33 +231,25 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api continue } - if lineStartIndex == -1 { - // Extra new line, Should be an event separator. - eventStartIndex = i + 1 - continue + if lineStartIndex != -1 { + value := string(receivedBody[valueStartIndex:i]) + log.Debugf("key: %s value: %s", currentKey, value) + currentEvent.setValue(currentKey, value) + } else { + // Extra new line. The current event is complete. + log.Debugf("processing event: %v", currentEvent) + if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, log); err != nil { + return nil, err + } + // Reset event parsing state. + eventStartIndex = -1 + currentEvent = &streamEvent{} } - key := currentKey - value := receivedBody[valueStartIndex:i] - - // Reset message parsing state. - eventStartIndex = -1 + // Reset line parsing state. lineStartIndex = -1 valueStartIndex = -1 currentKey = "" - - switch key { - case streamIdItemKey: - currentEventId = string(value) - break - case streamDataItemKey: - if err := m.convertStreamEvent(ctx, &responseBuilder, currentEventId, value, log); err != nil { - return nil, err - } - break - default: - break - } } modifiedResponseChunk := responseBuilder.String() @@ -345,20 +343,20 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont return responses } -func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, eventId string, eventData []byte, log wrapper.Log) error { - if string(eventData) == streamEndDataValue { - responseBuilder.WriteString(streamIdItemKey) - responseBuilder.WriteString(eventId) - responseBuilder.WriteString("\n") - responseBuilder.WriteString(streamEventHeader) - responseBuilder.WriteString(streamDataItemKey) - responseBuilder.WriteString(streamEndDataValue) - responseBuilder.WriteString("\n\n") +func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error { + if event.Data == streamEndDataValue { + m.appendStreamEvent(responseBuilder, event) + return nil + } + + if event.Event != eventResult || event.HttpStatus != httpStatus200 { + // Something goes wrong. Just pass through the event. + m.appendStreamEvent(responseBuilder, event) return nil } qwenResponse := &qwenTextGenResponse{} - if err := json.Unmarshal(eventData, qwenResponse); err != nil { + if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil { log.Errorf("unable to unmarshal Qwen response: %v", err) return fmt.Errorf("unable to unmarshal Qwen response: %v", err) } @@ -370,13 +368,9 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild log.Errorf("unable to marshal response: %v", err) return fmt.Errorf("unable to marshal response: %v", err) } - responseBuilder.WriteString(streamIdItemKey) - responseBuilder.WriteString(eventId) - responseBuilder.WriteString("\n") - responseBuilder.WriteString(streamEventHeader) - responseBuilder.WriteString(streamDataItemKey) - responseBuilder.Write(responseBody) - responseBuilder.WriteString("\n\n") + modifiedEvent := &*event + modifiedEvent.Data = string(responseBody) + m.appendStreamEvent(responseBuilder, modifiedEvent) } return nil @@ -404,6 +398,22 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content } } +func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) { + responseBuilder.WriteString(streamEventIdItemKey) + responseBuilder.WriteString(event.Id) + responseBuilder.WriteString("\n") + responseBuilder.WriteString(streamEventNameItemKey) + responseBuilder.WriteString(event.Event) + responseBuilder.WriteString("\n") + responseBuilder.WriteString(streamBuiltInItemKey) + responseBuilder.WriteString(streamHttpStatusValuePrefix) + responseBuilder.WriteString(event.HttpStatus) + responseBuilder.WriteString("\n") + responseBuilder.WriteString(streamDataItemKey) + responseBuilder.WriteString(event.Data) + responseBuilder.WriteString("\n\n") +} + type qwenTextGenRequest struct { Model string `json:"model"` Input qwenTextGenInput `json:"input"`