mirror of
https://github.com/alibaba/higress.git
synced 2026-06-08 20:27:31 +08:00
feat: Unify the SSE processing logic (#1800)
This commit is contained in:
@@ -102,7 +102,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
|
|
||||||
// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
|
// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
|
||||||
// allowing plugins to inspect or modify the response correctly
|
// allowing plugins to inspect or modify the response correctly
|
||||||
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
||||||
// Set the apiToken for the current request.
|
// Set the apiToken for the current request.
|
||||||
@@ -110,13 +110,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
|
|
||||||
err := handler.OnRequestHeaders(ctx, apiName, log)
|
err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
_ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
hasRequestBody := wrapper.HasRequestBody()
|
hasRequestBody := wrapper.HasRequestBody()
|
||||||
if hasRequestBody {
|
if hasRequestBody {
|
||||||
proxywasm.RemoveHttpRequestHeader("Content-Length")
|
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||||
// Delay the header processing to allow changing in OnRequestBody
|
// Delay the header processing to allow changing in OnRequestBody
|
||||||
return types.HeaderStopIteration
|
return types.HeaderStopIteration
|
||||||
@@ -143,7 +143,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
|||||||
|
|
||||||
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
||||||
if settingErr != nil {
|
if settingErr != nil {
|
||||||
util.ErrorHandler(
|
_ = util.ErrorHandler(
|
||||||
"ai-proxy.proc_req_body_failed",
|
"ai-proxy.proc_req_body_failed",
|
||||||
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
|
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
|
||||||
)
|
)
|
||||||
@@ -156,7 +156,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return action
|
return action
|
||||||
}
|
}
|
||||||
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
||||||
}
|
}
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
@@ -205,7 +205,11 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
|||||||
|
|
||||||
checkStream(ctx, log)
|
checkStream(ctx, log)
|
||||||
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
|
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
|
||||||
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
var needHandleStreamingBody bool
|
||||||
|
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
|
||||||
|
if !needHandleStreamingBody {
|
||||||
|
_, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler)
|
||||||
|
}
|
||||||
if !needHandleBody && !needHandleStreamingBody {
|
if !needHandleBody && !needHandleStreamingBody {
|
||||||
ctx.DontReadResponseBody()
|
ctx.DontReadResponseBody()
|
||||||
} else if !needHandleStreamingBody {
|
} else if !needHandleStreamingBody {
|
||||||
@@ -224,7 +228,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
|
log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
|
||||||
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
||||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
@@ -234,6 +238,38 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
|||||||
}
|
}
|
||||||
return chunk
|
return chunk
|
||||||
}
|
}
|
||||||
|
if handler, ok := activeProvider.(provider.StreamingEventHandler); ok {
|
||||||
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
|
events := provider.ExtractStreamingEvents(ctx, chunk, log)
|
||||||
|
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
|
||||||
|
if len(events) == 0 {
|
||||||
|
// No events are extracted, return the original chunk
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
var responseBuilder strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
log.Debugf("processing event: %v", event)
|
||||||
|
|
||||||
|
if event.IsEndData() {
|
||||||
|
responseBuilder.WriteString(event.ToHttpString())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event, log)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
if outputEvents == nil || len(outputEvents) == 0 {
|
||||||
|
responseBuilder.WriteString(event.ToHttpString())
|
||||||
|
} else {
|
||||||
|
for _, outputEvent := range outputEvents {
|
||||||
|
responseBuilder.WriteString(outputEvent.ToHttpString())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return []byte(responseBuilder.String())
|
||||||
|
}
|
||||||
return chunk
|
return chunk
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,11 +287,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
|||||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
||||||
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
|
|||||||
@@ -278,14 +278,18 @@ func (m *functionCall) IsEmpty() bool {
|
|||||||
return m.Name == "" && m.Arguments == ""
|
return m.Name == "" && m.Arguments == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamEvent struct {
|
type StreamEvent struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
Data string `json:"data"`
|
Data string `json:"data"`
|
||||||
HttpStatus string `json:"http_status"`
|
HttpStatus string `json:"http_status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *streamEvent) setValue(key, value string) {
|
func (e *StreamEvent) IsEndData() bool {
|
||||||
|
return e.Data == streamEndDataValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StreamEvent) SetValue(key, value string) {
|
||||||
switch key {
|
switch key {
|
||||||
case streamEventIdItemKey:
|
case streamEventIdItemKey:
|
||||||
e.Id = value
|
e.Id = value
|
||||||
@@ -300,6 +304,10 @@ func (e *streamEvent) setValue(key, value string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *StreamEvent) ToHttpString() string {
|
||||||
|
return fmt.Sprintf("%s %s\n\n", streamDataItemKey, e.Data)
|
||||||
|
}
|
||||||
|
|
||||||
// https://platform.openai.com/docs/guides/images
|
// https://platform.openai.com/docs/guides/images
|
||||||
type imageGenerationRequest struct {
|
type imageGenerationRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|||||||
@@ -102,12 +102,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
|
|||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to load context file: %v", err)
|
log.Errorf("failed to load context file: %v", err)
|
||||||
util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
|
_ = util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = m.performChatCompletion(ctx, content, request, log)
|
err = m.performChatCompletion(ctx, content, request, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
|
_ = util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
|
||||||
}
|
}
|
||||||
}, log)
|
}, log)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -161,79 +161,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
|
||||||
if name != ApiNameChatCompletion {
|
if name != ApiNameChatCompletion {
|
||||||
return chunk, nil
|
return nil, nil
|
||||||
}
|
|
||||||
receivedBody := chunk
|
|
||||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
|
||||||
receivedBody = append(bufferedStreamingBody, chunk...)
|
|
||||||
}
|
|
||||||
|
|
||||||
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
|
|
||||||
// Just in case the received chunk is not a complete event.
|
|
||||||
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
|
|
||||||
} else {
|
|
||||||
ctx.SetContext(ctxKeyStreamingBody, nil)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var responseBuilder strings.Builder
|
|
||||||
currentKey := ""
|
|
||||||
currentEvent := &streamEvent{}
|
|
||||||
i, length := 0, len(receivedBody)
|
|
||||||
for i = 0; i < length; i++ {
|
|
||||||
ch := receivedBody[i]
|
|
||||||
if ch != '\n' {
|
|
||||||
if lineStartIndex == -1 {
|
|
||||||
if eventStartIndex == -1 {
|
|
||||||
eventStartIndex = i
|
|
||||||
}
|
|
||||||
lineStartIndex = i
|
|
||||||
valueStartIndex = -1
|
|
||||||
}
|
|
||||||
if valueStartIndex == -1 {
|
|
||||||
if ch == ':' {
|
|
||||||
valueStartIndex = i + 1
|
|
||||||
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
|
|
||||||
}
|
|
||||||
} else if valueStartIndex == i && ch == ' ' {
|
|
||||||
// Skip leading spaces in data.
|
|
||||||
valueStartIndex = i + 1
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if lineStartIndex != -1 {
|
|
||||||
value := string(receivedBody[valueStartIndex:i])
|
|
||||||
currentEvent.setValue(currentKey, value)
|
|
||||||
} else {
|
|
||||||
// Extra new line. The current event is complete.
|
|
||||||
log.Debugf("processing event: %v", currentEvent)
|
|
||||||
m.convertStreamEvent(&responseBuilder, currentEvent, log)
|
|
||||||
// Reset event parsing state.
|
|
||||||
eventStartIndex = -1
|
|
||||||
currentEvent = &streamEvent{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset line parsing state.
|
|
||||||
lineStartIndex = -1
|
|
||||||
valueStartIndex = -1
|
|
||||||
currentKey = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
modifiedResponseChunk := responseBuilder.String()
|
|
||||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
|
||||||
return []byte(modifiedResponseChunk), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
|
|
||||||
if event.Data == streamEndDataValue {
|
|
||||||
m.appendStreamEvent(responseBuilder, event)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if gjson.Get(event.Data, "choices.0.usage").Exists() {
|
if gjson.Get(event.Data, "choices.0.usage").Exists() {
|
||||||
@@ -241,20 +171,19 @@ func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder,
|
|||||||
newData, err := sjson.Delete(event.Data, "choices.0.usage")
|
newData, err := sjson.Delete(event.Data, "choices.0.usage")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("convert usage event error: %v", err)
|
log.Errorf("convert usage event error: %v", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
newData, err = sjson.SetRaw(newData, "usage", usageStr)
|
newData, err = sjson.SetRaw(newData, "usage", usageStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("convert usage event error: %v", err)
|
log.Errorf("convert usage event error: %v", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
event.Data = newData
|
event.Data = newData
|
||||||
}
|
}
|
||||||
m.appendStreamEvent(responseBuilder, event)
|
return []StreamEvent{event}, nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
|
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) {
|
||||||
responseBuilder.WriteString(streamDataItemKey)
|
responseBuilder.WriteString(streamDataItemKey)
|
||||||
responseBuilder.WriteString(event.Data)
|
responseBuilder.WriteString(event.Data)
|
||||||
responseBuilder.WriteString("\n\n")
|
responseBuilder.WriteString("\n\n")
|
||||||
|
|||||||
@@ -149,6 +149,10 @@ type StreamingResponseBodyHandler interface {
|
|||||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamingEventHandler interface {
|
||||||
|
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ApiNameHandler interface {
|
type ApiNameHandler interface {
|
||||||
GetApiName(path string) ApiName
|
GetApiName(path string) ApiName
|
||||||
}
|
}
|
||||||
@@ -575,6 +579,81 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
|
||||||
|
body := chunk
|
||||||
|
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||||
|
body = append(bufferedStreamingBody, chunk...)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if eventStartIndex >= 0 && eventStartIndex < len(body) {
|
||||||
|
// Just in case the received chunk is not a complete event.
|
||||||
|
ctx.SetContext(ctxKeyStreamingBody, body[eventStartIndex:])
|
||||||
|
} else {
|
||||||
|
ctx.SetContext(ctxKeyStreamingBody, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Sample Qwen event response:
|
||||||
|
//
|
||||||
|
// event:result
|
||||||
|
// :HTTP_STATUS/200
|
||||||
|
// data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"}
|
||||||
|
//
|
||||||
|
// event:error
|
||||||
|
// :HTTP_STATUS/400
|
||||||
|
// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
|
||||||
|
//
|
||||||
|
|
||||||
|
var events []StreamEvent
|
||||||
|
|
||||||
|
currentKey := ""
|
||||||
|
currentEvent := &StreamEvent{}
|
||||||
|
i, length := 0, len(body)
|
||||||
|
for i = 0; i < length; i++ {
|
||||||
|
ch := body[i]
|
||||||
|
if ch != '\n' {
|
||||||
|
if lineStartIndex == -1 {
|
||||||
|
if eventStartIndex == -1 {
|
||||||
|
eventStartIndex = i
|
||||||
|
}
|
||||||
|
lineStartIndex = i
|
||||||
|
valueStartIndex = -1
|
||||||
|
}
|
||||||
|
if valueStartIndex == -1 {
|
||||||
|
if ch == ':' {
|
||||||
|
valueStartIndex = i + 1
|
||||||
|
currentKey = string(body[lineStartIndex:valueStartIndex])
|
||||||
|
}
|
||||||
|
} else if valueStartIndex == i && ch == ' ' {
|
||||||
|
// Skip leading spaces in data.
|
||||||
|
valueStartIndex = i + 1
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if lineStartIndex != -1 {
|
||||||
|
value := string(body[valueStartIndex:i])
|
||||||
|
currentEvent.SetValue(currentKey, value)
|
||||||
|
} else {
|
||||||
|
// Extra new line. The current event is complete.
|
||||||
|
events = append(events, *currentEvent)
|
||||||
|
// Reset event parsing state.
|
||||||
|
eventStartIndex = -1
|
||||||
|
currentEvent = &StreamEvent{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset line parsing state.
|
||||||
|
lineStartIndex = -1
|
||||||
|
valueStartIndex = -1
|
||||||
|
currentKey = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
|
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
|
||||||
_, exist := c.capabilities[string(apiName)]
|
_, exist := c.capabilities[string(apiName)]
|
||||||
return exist
|
return exist
|
||||||
|
|||||||
@@ -188,89 +188,32 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
|
|||||||
return json.Marshal(qwenRequest)
|
return json.Marshal(qwenRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
|
||||||
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
|
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
|
||||||
return chunk, nil
|
return nil, nil
|
||||||
}
|
|
||||||
|
|
||||||
receivedBody := chunk
|
|
||||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
|
||||||
receivedBody = append(bufferedStreamingBody, chunk...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)
|
incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)
|
||||||
|
|
||||||
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
|
qwenResponse := &qwenTextGenResponse{}
|
||||||
|
if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil {
|
||||||
defer func() {
|
log.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||||
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
|
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||||
// Just in case the received chunk is not a complete event.
|
|
||||||
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
|
|
||||||
} else {
|
|
||||||
ctx.SetContext(ctxKeyStreamingBody, nil)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Sample Qwen event response:
|
|
||||||
//
|
|
||||||
// event:result
|
|
||||||
// :HTTP_STATUS/200
|
|
||||||
// data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"}
|
|
||||||
//
|
|
||||||
// event:error
|
|
||||||
// :HTTP_STATUS/400
|
|
||||||
// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
|
|
||||||
//
|
|
||||||
|
|
||||||
var responseBuilder strings.Builder
|
|
||||||
currentKey := ""
|
|
||||||
currentEvent := &streamEvent{}
|
|
||||||
i, length := 0, len(receivedBody)
|
|
||||||
for i = 0; i < length; i++ {
|
|
||||||
ch := receivedBody[i]
|
|
||||||
if ch != '\n' {
|
|
||||||
if lineStartIndex == -1 {
|
|
||||||
if eventStartIndex == -1 {
|
|
||||||
eventStartIndex = i
|
|
||||||
}
|
|
||||||
lineStartIndex = i
|
|
||||||
valueStartIndex = -1
|
|
||||||
}
|
|
||||||
if valueStartIndex == -1 {
|
|
||||||
if ch == ':' {
|
|
||||||
valueStartIndex = i + 1
|
|
||||||
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
|
|
||||||
}
|
|
||||||
} else if valueStartIndex == i && ch == ' ' {
|
|
||||||
// Skip leading spaces in data.
|
|
||||||
valueStartIndex = i + 1
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if lineStartIndex != -1 {
|
|
||||||
value := string(receivedBody[valueStartIndex:i])
|
|
||||||
currentEvent.setValue(currentKey, value)
|
|
||||||
} else {
|
|
||||||
// Extra new line. The current event is complete.
|
|
||||||
log.Debugf("processing event: %v", currentEvent)
|
|
||||||
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Reset event parsing state.
|
|
||||||
eventStartIndex = -1
|
|
||||||
currentEvent = &streamEvent{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset line parsing state.
|
|
||||||
lineStartIndex = -1
|
|
||||||
valueStartIndex = -1
|
|
||||||
currentKey = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modifiedResponseChunk := responseBuilder.String()
|
var outputEvents []StreamEvent
|
||||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
|
||||||
return []byte(modifiedResponseChunk), nil
|
for _, response := range responses {
|
||||||
|
responseBody, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to marshal response: %v", err)
|
||||||
|
return nil, fmt.Errorf("unable to marshal response: %v", err)
|
||||||
|
}
|
||||||
|
modifiedEvent := event
|
||||||
|
modifiedEvent.Data = string(responseBody)
|
||||||
|
outputEvents = append(outputEvents, modifiedEvent)
|
||||||
|
}
|
||||||
|
return outputEvents, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||||
@@ -481,39 +424,6 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
|||||||
return responses
|
return responses
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error {
|
|
||||||
if event.Data == streamEndDataValue {
|
|
||||||
m.appendStreamEvent(responseBuilder, event)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if event.Event != eventResult || event.HttpStatus != httpStatus200 {
|
|
||||||
// Something goes wrong. Just pass through the event.
|
|
||||||
m.appendStreamEvent(responseBuilder, event)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
qwenResponse := &qwenTextGenResponse{}
|
|
||||||
if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil {
|
|
||||||
log.Errorf("unable to unmarshal Qwen response: %v", err)
|
|
||||||
return fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
|
|
||||||
for _, response := range responses {
|
|
||||||
responseBody, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to marshal response: %v", err)
|
|
||||||
return fmt.Errorf("unable to marshal response: %v", err)
|
|
||||||
}
|
|
||||||
modifiedEvent := &*event
|
|
||||||
modifiedEvent.Data = string(responseBody)
|
|
||||||
m.appendStreamEvent(responseBuilder, modifiedEvent)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
|
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
|
||||||
request := &qwenTextGenRequest{}
|
request := &qwenTextGenRequest{}
|
||||||
if err := json.Unmarshal(body, request); err != nil {
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
@@ -558,7 +468,7 @@ func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onl
|
|||||||
return json.Marshal(request)
|
return json.Marshal(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
|
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) {
|
||||||
responseBuilder.WriteString(streamDataItemKey)
|
responseBuilder.WriteString(streamDataItemKey)
|
||||||
responseBuilder.WriteString(event.Data)
|
responseBuilder.WriteString(event.Data)
|
||||||
responseBuilder.WriteString("\n\n")
|
responseBuilder.WriteString("\n\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user