feat: Unify the SSE processing logic (#1800)

This commit is contained in:
Kent Dong
2025-02-25 11:00:18 +08:00
committed by GitHub
parent 53a015d8fe
commit 49617c7a98
5 changed files with 162 additions and 200 deletions

View File

@@ -278,14 +278,18 @@ func (m *functionCall) IsEmpty() bool {
return m.Name == "" && m.Arguments == ""
}
type streamEvent struct {
type StreamEvent struct {
Id string `json:"id"`
Event string `json:"event"`
Data string `json:"data"`
HttpStatus string `json:"http_status"`
}
func (e *streamEvent) setValue(key, value string) {
func (e *StreamEvent) IsEndData() bool {
return e.Data == streamEndDataValue
}
func (e *StreamEvent) SetValue(key, value string) {
switch key {
case streamEventIdItemKey:
e.Id = value
@@ -300,6 +304,10 @@ func (e *streamEvent) setValue(key, value string) {
}
}
func (e *StreamEvent) ToHttpString() string {
return fmt.Sprintf("%s %s\n\n", streamDataItemKey, e.Data)
}
// https://platform.openai.com/docs/guides/images
type imageGenerationRequest struct {
Model string `json:"model"`

View File

@@ -102,12 +102,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
_ = util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
return
}
err = m.performChatCompletion(ctx, content, request, log)
if err != nil {
util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
_ = util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
}
}, log)
if err == nil {
@@ -161,79 +161,9 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
}
}
func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
if name != ApiNameChatCompletion {
return chunk, nil
}
receivedBody := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
receivedBody = append(bufferedStreamingBody, chunk...)
}
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
defer func() {
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
// Just in case the received chunk is not a complete event.
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
} else {
ctx.SetContext(ctxKeyStreamingBody, nil)
}
}()
var responseBuilder strings.Builder
currentKey := ""
currentEvent := &streamEvent{}
i, length := 0, len(receivedBody)
for i = 0; i < length; i++ {
ch := receivedBody[i]
if ch != '\n' {
if lineStartIndex == -1 {
if eventStartIndex == -1 {
eventStartIndex = i
}
lineStartIndex = i
valueStartIndex = -1
}
if valueStartIndex == -1 {
if ch == ':' {
valueStartIndex = i + 1
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
}
} else if valueStartIndex == i && ch == ' ' {
// Skip leading spaces in data.
valueStartIndex = i + 1
}
continue
}
if lineStartIndex != -1 {
value := string(receivedBody[valueStartIndex:i])
currentEvent.setValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
log.Debugf("processing event: %v", currentEvent)
m.convertStreamEvent(&responseBuilder, currentEvent, log)
// Reset event parsing state.
eventStartIndex = -1
currentEvent = &streamEvent{}
}
// Reset line parsing state.
lineStartIndex = -1
valueStartIndex = -1
currentKey = ""
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
}
func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
if event.Data == streamEndDataValue {
m.appendStreamEvent(responseBuilder, event)
return nil
return nil, nil
}
if gjson.Get(event.Data, "choices.0.usage").Exists() {
@@ -241,20 +171,19 @@ func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder,
newData, err := sjson.Delete(event.Data, "choices.0.usage")
if err != nil {
log.Errorf("convert usage event error: %v", err)
return err
return nil, err
}
newData, err = sjson.SetRaw(newData, "usage", usageStr)
if err != nil {
log.Errorf("convert usage event error: %v", err)
return err
return nil, err
}
event.Data = newData
}
m.appendStreamEvent(responseBuilder, event)
return nil
return []StreamEvent{event}, nil
}
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) {
responseBuilder.WriteString(streamDataItemKey)
responseBuilder.WriteString(event.Data)
responseBuilder.WriteString("\n\n")

View File

@@ -149,6 +149,10 @@ type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
type StreamingEventHandler interface {
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
@@ -575,6 +579,81 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
body := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
body = append(bufferedStreamingBody, chunk...)
}
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
defer func() {
if eventStartIndex >= 0 && eventStartIndex < len(body) {
// Just in case the received chunk is not a complete event.
ctx.SetContext(ctxKeyStreamingBody, body[eventStartIndex:])
} else {
ctx.SetContext(ctxKeyStreamingBody, nil)
}
}()
// Sample Qwen event response:
//
// event:result
// :HTTP_STATUS/200
// data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"}
//
// event:error
// :HTTP_STATUS/400
// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
//
var events []StreamEvent
currentKey := ""
currentEvent := &StreamEvent{}
i, length := 0, len(body)
for i = 0; i < length; i++ {
ch := body[i]
if ch != '\n' {
if lineStartIndex == -1 {
if eventStartIndex == -1 {
eventStartIndex = i
}
lineStartIndex = i
valueStartIndex = -1
}
if valueStartIndex == -1 {
if ch == ':' {
valueStartIndex = i + 1
currentKey = string(body[lineStartIndex:valueStartIndex])
}
} else if valueStartIndex == i && ch == ' ' {
// Skip leading spaces in data.
valueStartIndex = i + 1
}
continue
}
if lineStartIndex != -1 {
value := string(body[valueStartIndex:i])
currentEvent.SetValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
events = append(events, *currentEvent)
// Reset event parsing state.
eventStartIndex = -1
currentEvent = &StreamEvent{}
}
// Reset line parsing state.
lineStartIndex = -1
valueStartIndex = -1
currentKey = ""
}
return events
}
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
_, exist := c.capabilities[string(apiName)]
return exist

View File

@@ -188,89 +188,32 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
return chunk, nil
}
receivedBody := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
receivedBody = append(bufferedStreamingBody, chunk...)
return nil, nil
}
incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false)
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
defer func() {
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
// Just in case the received chunk is not a complete event.
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
} else {
ctx.SetContext(ctxKeyStreamingBody, nil)
}
}()
// Sample Qwen event response:
//
// event:result
// :HTTP_STATUS/200
// data:{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"}
//
// event:error
// :HTTP_STATUS/400
// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
//
var responseBuilder strings.Builder
currentKey := ""
currentEvent := &streamEvent{}
i, length := 0, len(receivedBody)
for i = 0; i < length; i++ {
ch := receivedBody[i]
if ch != '\n' {
if lineStartIndex == -1 {
if eventStartIndex == -1 {
eventStartIndex = i
}
lineStartIndex = i
valueStartIndex = -1
}
if valueStartIndex == -1 {
if ch == ':' {
valueStartIndex = i + 1
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
}
} else if valueStartIndex == i && ch == ' ' {
// Skip leading spaces in data.
valueStartIndex = i + 1
}
continue
}
if lineStartIndex != -1 {
value := string(receivedBody[valueStartIndex:i])
currentEvent.setValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
log.Debugf("processing event: %v", currentEvent)
if err := m.convertStreamEvent(ctx, &responseBuilder, currentEvent, incrementalStreaming, log); err != nil {
return nil, err
}
// Reset event parsing state.
eventStartIndex = -1
currentEvent = &streamEvent{}
}
// Reset line parsing state.
lineStartIndex = -1
valueStartIndex = -1
currentKey = ""
qwenResponse := &qwenTextGenResponse{}
if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil {
log.Errorf("unable to unmarshal Qwen response: %v", err)
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
var outputEvents []StreamEvent
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
for _, response := range responses {
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, fmt.Errorf("unable to marshal response: %v", err)
}
modifiedEvent := event
modifiedEvent.Data = string(responseBody)
outputEvents = append(outputEvents, modifiedEvent)
}
return outputEvents, nil
}
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
@@ -481,39 +424,6 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
return responses
}
func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuilder *strings.Builder, event *streamEvent, incrementalStreaming bool, log wrapper.Log) error {
if event.Data == streamEndDataValue {
m.appendStreamEvent(responseBuilder, event)
return nil
}
if event.Event != eventResult || event.HttpStatus != httpStatus200 {
// Something goes wrong. Just pass through the event.
m.appendStreamEvent(responseBuilder, event)
return nil
}
qwenResponse := &qwenTextGenResponse{}
if err := json.Unmarshal([]byte(event.Data), qwenResponse); err != nil {
log.Errorf("unable to unmarshal Qwen response: %v", err)
return fmt.Errorf("unable to unmarshal Qwen response: %v", err)
}
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
for _, response := range responses {
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return fmt.Errorf("unable to marshal response: %v", err)
}
modifiedEvent := &*event
modifiedEvent.Data = string(responseBody)
m.appendStreamEvent(responseBuilder, modifiedEvent)
}
return nil
}
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
request := &qwenTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
@@ -558,7 +468,7 @@ func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onl
return json.Marshal(request)
}
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *StreamEvent) {
responseBuilder.WriteString(streamDataItemKey)
responseBuilder.WriteString(event.Data)
responseBuilder.WriteString("\n\n")