From 22790aa14904e46a45d4266968c153bfdc7b8600 Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Thu, 5 Dec 2024 11:35:25 +0800 Subject: [PATCH] fix moonshot usage compatible problem (#1568) --- .../extensions/ai-proxy/provider/moonshot.go | 101 +++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index cb914d8c8..117f8dfb4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -3,12 +3,15 @@ package provider import ( "errors" "fmt" + "net/http" + "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - "net/http" + "github.com/tidwall/sjson" ) // moonshotProvider is the provider for Moonshot AI service. @@ -149,3 +152,99 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba return errors.New("unsupported method: " + method) } } + +func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { + receivedBody := chunk + if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { + receivedBody = append(bufferedStreamingBody, chunk...) + } + + eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 + + defer func() { + if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) { + // Just in case the received chunk is not a complete event. + ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:]) + } else { + ctx.SetContext(ctxKeyStreamingBody, nil) + } + }() + + var responseBuilder strings.Builder + currentKey := "" + currentEvent := &streamEvent{} + i, length := 0, len(receivedBody) + for i = 0; i < length; i++ { + ch := receivedBody[i] + if ch != '\n' { + if lineStartIndex == -1 { + if eventStartIndex == -1 { + eventStartIndex = i + } + lineStartIndex = i + valueStartIndex = -1 + } + if valueStartIndex == -1 { + if ch == ':' { + valueStartIndex = i + 1 + currentKey = string(receivedBody[lineStartIndex:valueStartIndex]) + } + } else if valueStartIndex == i && ch == ' ' { + // Skip leading spaces in data. + valueStartIndex = i + 1 + } + continue + } + + if lineStartIndex != -1 { + value := string(receivedBody[valueStartIndex:i]) + currentEvent.setValue(currentKey, value) + } else { + // Extra new line. The current event is complete. + log.Debugf("processing event: %v", currentEvent) + m.convertStreamEvent(&responseBuilder, currentEvent, log) + // Reset event parsing state. + eventStartIndex = -1 + currentEvent = &streamEvent{} + } + + // Reset line parsing state. + lineStartIndex = -1 + valueStartIndex = -1 + currentKey = "" + } + + modifiedResponseChunk := responseBuilder.String() + log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) + return []byte(modifiedResponseChunk), nil +} + +func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error { + if event.Data == streamEndDataValue { + m.appendStreamEvent(responseBuilder, event) + return nil + } + + if gjson.Get(event.Data, "choices.0.usage").Exists() { + usageStr := gjson.Get(event.Data, "choices.0.usage").Raw + newData, err := sjson.Delete(event.Data, "choices.0.usage") + if err != nil { + log.Errorf("convert usage event error: %v", err) + return err + } + newData, err = sjson.SetRaw(newData, "usage", usageStr) + if err != nil { + log.Errorf("convert usage event error: %v", err) + return err + } + event.Data = newData + } + m.appendStreamEvent(responseBuilder, event) + return nil +} + +func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) { + responseBuilder.WriteString(streamDataItemKey) + responseBuilder.WriteString(event.Data) + responseBuilder.WriteString("\n\n") +}