diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 73ff7eafa..4a2d1fb98 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -102,7 +102,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf // Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses, // allowing plugins to inspect or modify the response correctly - proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { // Set the apiToken for the current request. @@ -110,13 +110,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf err := handler.OnRequestHeaders(ctx, apiName, log) if err != nil { - util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) + _ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) return types.ActionContinue } hasRequestBody := wrapper.HasRequestBody() if hasRequestBody { - proxywasm.RemoveHttpRequestHeader("Content-Length") + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) // Delay the header processing to allow changing in OnRequestBody return types.HeaderStopIteration @@ -143,7 +143,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body) if settingErr != nil { - util.ErrorHandler( + _ = util.ErrorHandler( "ai-proxy.proc_req_body_failed", fmt.Errorf("failed to replace request body by custom settings: %v", settingErr), ) @@ -156,7 +156,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig if err == nil { return action } - util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) + _ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue } @@ -205,7 +205,11 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo checkStream(ctx, log) _, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler) - _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) + var needHandleStreamingBody bool + _, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler) + if !needHandleStreamingBody { + _, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler) + } if !needHandleBody && !needHandleStreamingBody { ctx.DontReadResponseBody() } else if !needHandleStreamingBody { @@ -224,7 +228,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin } log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType()) - log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) + log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) @@ -234,6 +238,38 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin } return chunk } + if handler, ok := activeProvider.(provider.StreamingEventHandler); ok { + apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) + events := provider.ExtractStreamingEvents(ctx, chunk, log) + log.Debugf("[onStreamingResponseBody] %d events received", len(events)) + if len(events) == 0 { + // No events are extracted, return the original chunk + return chunk + } + var responseBuilder strings.Builder + for _, event := range events { + log.Debugf("processing event: %v", event) + + if event.IsEndData() { + responseBuilder.WriteString(event.ToHttpString()) + continue + } + + outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event, log) + if err != nil { + log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk) + return chunk + } + if outputEvents == nil || len(outputEvents) == 0 { + responseBuilder.WriteString(event.ToHttpString()) + } else { + for _, outputEvent := range outputEvents { + responseBuilder.WriteString(outputEvent.ToHttpString()) + } + } + } + return []byte(responseBuilder.String()) + } return chunk } @@ -251,11 +287,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) body, err := handler.TransformResponseBody(ctx, apiName, body, log) if err != nil { - util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err)) + _ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err)) return types.ActionContinue } if err = provider.ReplaceResponseBody(body, log); err != nil { - util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err)) + _ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err)) } } return types.ActionContinue diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index b38b4fde8..7de9cfe2f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -278,14 +278,18 @@ func (m *functionCall) IsEmpty() bool { return m.Name == "" && m.Arguments == "" } -type streamEvent struct { +type StreamEvent struct { Id string `json:"id"` Event string `json:"event"` Data string `json:"data"` HttpStatus string `json:"http_status"` } -func (e *streamEvent) setValue(key, value string) { +func (e *StreamEvent) IsEndData() bool { + return e.Data == streamEndDataValue +} + +func (e *StreamEvent) SetValue(key, value string) { switch key { case streamEventIdItemKey: e.Id = value @@ -300,6 +304,10 @@ func (e *streamEvent) setValue(key, value string) { } } +func (e *StreamEvent) ToHttpString() string { + return fmt.Sprintf("%s %s\n\n", streamDataItemKey, e.Data) +} + // https://platform.openai.com/docs/guides/images type imageGenerationRequest struct { Model string `json:"model"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index f0f63cf79..46fa68c73 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -102,12 +102,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam }() if err != nil { log.Errorf("failed to load context file: %v", err) - util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err)) + _ = util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err)) return } err = m.performChatCompletion(ctx, content, request, log) if err != nil { - util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err)) + _ = util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err)) } }, log) if err == nil { @@ -161,79 +161,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba } } -func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { +func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) { if name != ApiNameChatCompletion { - return chunk, nil - } - receivedBody := chunk - if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { - receivedBody = append(bufferedStreamingBody, chunk...) - } - - eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 - - defer func() { - if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) { - // Just in case the received chunk is not a complete event. - ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:]) - } else { - ctx.SetContext(ctxKeyStreamingBody, nil) - } - }() - - var responseBuilder strings.Builder - currentKey := "" - currentEvent := &streamEvent{} - i, length := 0, len(receivedBody) - for i = 0; i < length; i++ { - ch := receivedBody[i] - if ch != '\n' { - if lineStartIndex == -1 { - if eventStartIndex == -1 { - eventStartIndex = i - } - lineStartIndex = i - valueStartIndex = -1 - } - if valueStartIndex == -1 { - if ch == ':' { - valueStartIndex = i + 1 - currentKey = string(receivedBody[lineStartIndex:valueStartIndex]) - } - } else if valueStartIndex == i && ch == ' ' { - // Skip leading spaces in data. - valueStartIndex = i + 1 - } - continue - } - - if lineStartIndex != -1 { - value := string(receivedBody[valueStartIndex:i]) - currentEvent.setValue(currentKey, value) - } else { - // Extra new line. The current event is complete. - log.Debugf("processing event: %v", currentEvent) - m.convertStreamEvent(&responseBuilder, currentEvent, log) - // Reset event parsing state. - eventStartIndex = -1 - currentEvent = &streamEvent{} - } - - // Reset line parsing state. - lineStartIndex = -1 - valueStartIndex = -1 - currentKey = "" - } - - modifiedResponseChunk := responseBuilder.String() - log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) - return []byte(modifiedResponseChunk), nil -} - -func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error { - if event.Data == streamEndDataValue { - m.appendStreamEvent(responseBuilder, event) - return nil + return nil, nil } if gjson.Get(event.Data, "choices.0.usage").Exists() { @@ -241,20 +171,19 @@ func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, newData, err := sjson.Delete(event.Data, "choices.0.usage") if err != nil { log.Errorf("convert usage event error: %v", err) - return err + return nil, err } newData, err = sjson.SetRaw(newData, "usage", usageStr) if err != nil { log.Errorf("convert usage event error: %v", err) - return err + return nil, err } event.Data = newData } - m.appendStreamEvent(responseBuilder, event) - return nil + return []StreamEvent{event}, nil } -func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) { +func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) { responseBuilder.WriteString(streamDataItemKey) responseBuilder.WriteString(event.Data) responseBuilder.WriteString("\n\n") diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 46e6b4ed7..e3c54bfce 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -149,6 +149,10 @@ type StreamingResponseBodyHandler interface { OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) } +type StreamingEventHandler interface { + OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) +} + type ApiNameHandler interface { GetApiName(path string) ApiName } @@ -575,6 +579,81 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. return "" } +func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent { + body := chunk + if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { + body = append(bufferedStreamingBody, chunk...) + } + + eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 + + defer func() { + if eventStartIndex >= 0 && eventStartIndex < len(body) { + // Just in case the received chunk is not a complete event. + ctx.SetContext(ctxKeyStreamingBody, body[eventStartIndex:]) + } else { + ctx.SetContext(ctxKeyStreamingBody, nil) + } + }() + + // Sample Qwen event response: + // + // event:result + // :HTTP_STATUS/200 + // data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"} + // + // event:error + // :HTTP_STATUS/400 + // data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"} + // + + var events []StreamEvent + + currentKey := "" + currentEvent := &StreamEvent{} + i, length := 0, len(body) + for i = 0; i < length; i++ { + ch := body[i] + if ch != '\n' { + if lineStartIndex == -1 { + if eventStartIndex == -1 { + eventStartIndex = i + } + lineStartIndex = i + valueStartIndex = -1 + } + if valueStartIndex == -1 { + if ch == ':' { + valueStartIndex = i + 1 + currentKey = string(body[lineStartIndex:valueStartIndex]) + } + } else if valueStartIndex == i && ch == ' ' { + // Skip leading spaces in data. + valueStartIndex = i + 1 + } + continue + } + + if lineStartIndex != -1 { + value := string(body[valueStartIndex:i]) + currentEvent.SetValue(currentKey, value) + } else { + // Extra new line. The current event is complete. + events = append(events, *currentEvent) + // Reset event parsing state. + eventStartIndex = -1 + currentEvent = &StreamEvent{} + } + + // Reset line parsing state. + lineStartIndex = -1 + valueStartIndex = -1 + currentKey = "" + } + + return events +} + func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool { _, exist := c.capabilities[string(apiName)] return exist diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index fd55eee22..4bb39c121 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -188,89 +188,32 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b return json.Marshal(qwenRequest) } -func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { +func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) { if m.config.qwenEnableCompatible || name != ApiNameChatCompletion { - return chunk, nil - } - - receivedBody := chunk - if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has { - receivedBody = append(bufferedStreamingBody, chunk...) + return nil, nil } incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false) - eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 - - defer func() { - if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) { - // Just in case the received chunk is not a complete event. - ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:]) - } else { - ctx.SetContext(ctxKeyStreamingBody, nil) - } - }() - - // Sample Qwen event response: - // - // event:result - // :HTTP_STATUS/200 - // data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"} - // - // event:error - // :HTTP_STATUS/400 - // data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"} - // - - var responseBuilder strings.Builder - currentKey := "" - currentEvent := &streamEvent{} - i, length := 0, len(receivedBody) - for i = 0; i < length; i++ { - ch := receivedBody[i] - if ch != '\n' { - if lineStartIndex == -1 { - if eventStartIndex == -1 { - eventStartIndex = i - } - lineStartIndex = i - valueStartIndex = -1 - } - if valueStartIndex == -1 { - if ch == ':' { - valueStartIndex = i + 1 - currentKey = string(receivedBody[lineStartIndex:valueStartIndex]) - } - } else if valueStartIndex == i && ch == ' ' { - // Skip leading spaces in data. - valueStartIndex = i + 1 - } - continue - } - - if lineStartIndex != -1 { - value := string(receivedBody[valueStartIndex:i]) - currentEvent.setValue(currentKey, value) - } else { - // Extra new line. The current event is complete. - log.Debugf("processing event: %v", currentEvent) - if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil { - return nil, err - } - // Reset event parsing state. - eventStartIndex = -1 - currentEvent = &streamEvent{} - } - - // Reset line parsing state. - lineStartIndex = -1 - valueStartIndex = -1 - currentKey = "" + qwenResponse := &qwenTextGenResponse{} + if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil { + log.Errorf("unable to unmarshal Qwen response: %v", err) + return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err) } - modifiedResponseChunk := responseBuilder.String() - log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) - return []byte(modifiedResponseChunk), nil + var outputEvents []StreamEvent + responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log) + for _, response := range responses { + responseBody, err := json.Marshal(response) + if err != nil { + log.Errorf("unable to marshal response: %v", err) + return nil, fmt.Errorf("unable to marshal response: %v", err) + } + modifiedEvent := event + modifiedEvent.Data = string(responseBody) + outputEvents = append(outputEvents, modifiedEvent) + } + return outputEvents, nil } func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { @@ -481,39 +424,6 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont return responses } -func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error { - if event.Data == streamEndDataValue { - m.appendStreamEvent(responseBuilder, event) - return nil - } - - if event.Event != eventResult || event.HttpStatus != httpStatus200 { - // Something goes wrong. Just pass through the event. - m.appendStreamEvent(responseBuilder, event) - return nil - } - - qwenResponse := &qwenTextGenResponse{} - if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil { - log.Errorf("unable to unmarshal Qwen response: %v", err) - return fmt.Errorf("unable to unmarshal Qwen response: %v", err) - } - - responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log) - for _, response := range responses { - responseBody, err := json.Marshal(response) - if err != nil { - log.Errorf("unable to marshal response: %v", err) - return fmt.Errorf("unable to marshal response: %v", err) - } - modifiedEvent := &*event - modifiedEvent.Data = string(responseBody) - m.appendStreamEvent(responseBuilder, modifiedEvent) - } - - return nil -} - func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) { request := &qwenTextGenRequest{} if err := json.Unmarshal(body, request); err != nil { @@ -558,7 +468,7 @@ func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onl return json.Marshal(request) } -func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) { +func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) { responseBuilder.WriteString(streamDataItemKey) responseBuilder.WriteString(event.Data) responseBuilder.WriteString("\n\n")