bugfix: cannot parse content if one streaming body has multi chunks (#1606)

This commit is contained in:
rinfx
2024-12-19 16:21:57 +08:00
committed by GitHub
parent be27726721
commit d74d327b68
3 changed files with 84 additions and 83 deletions

View File

@@ -101,7 +101,10 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun
} }
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) {
subMessages := strings.Split(sseMessage, "\n") content := ""
for _, chunk := range strings.Split(sseMessage, "\n\n") {
log.Infof("chunk _ : %s", chunk)
subMessages := strings.Split(chunk, "\n")
var message string var message string
for _, msg := range subMessages { for _, msg := range subMessages {
if strings.HasPrefix(msg, "data:") { if strings.HasPrefix(msg, "data:") {
@@ -110,14 +113,14 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag
} }
} }
if len(message) < 6 { if len(message) < 6 {
return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message) return content, fmt.Errorf("[processSSEMessage] invalid message: %s", message)
} }
// skip the prefix "data:" // skip the prefix "data:"
bodyJson := message[5:] bodyJson := message[5:]
if strings.TrimSpace(bodyJson) == "[DONE]" { if strings.TrimSpace(bodyJson) == "[DONE]" {
return "", nil return content, nil
} }
// Extract values from JSON fields // Extract values from JSON fields
@@ -133,23 +136,23 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag
if !responseBody.Exists() { if !responseBody.Exists() {
if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
return "", nil return content, nil
} }
return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) return content, fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
} else { } else {
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
// If there is no content in the cache, initialize and set the content // If there is no content in the cache, initialize and set the content
if tempContentI == nil { if tempContentI == nil {
content := responseBody.String() content = responseBody.String()
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
return content, nil } else {
}
// Update the content in the cache // Update the content in the cache
appendMsg := responseBody.String() appendMsg := responseBody.String()
content := tempContentI.(string) + appendMsg content = tempContentI.(string) + appendMsg
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
}
}
}
return content, nil return content, nil
} }
}

View File

@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=

View File

@@ -194,6 +194,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
ctx.SetContext(StreamContextKey, struct{}{}) ctx.SetContext(StreamContextKey, struct{}{})
} }
identityKey := ctx.GetStringContext(IdentityKey, "") identityKey := ctx.GetStringContext(IdentityKey, "")
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
if question == "" {
log.Debug("parse question from request body failed")
return types.ActionContinue
}
ctx.SetContext(QuestionContextKey, question)
err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) { err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) {
if err := response.Error(); err != nil { if err := response.Error(); err != nil {
log.Errorf("redis get failed, err:%v", err) log.Errorf("redis get failed, err:%v", err)
@@ -230,13 +236,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
_ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1) _ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1)
return return
} }
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
if question == "" {
log.Debug("parse question from request body failed")
_ = proxywasm.ResumeHttpRequest()
return
}
ctx.SetContext(QuestionContextKey, question)
fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2 fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2
currJson := bodyJson.Get("messages").String() currJson := bodyJson.Get("messages").String()
var currMessage []ChatHistory var currMessage []ChatHistory
@@ -317,7 +316,9 @@ func getIntQueryParameter(name string, path string, defaultValue int) int {
} }
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
subMessages := strings.Split(sseMessage, "\n") content := ""
for _, chunk := range strings.Split(sseMessage, "\n\n") {
subMessages := strings.Split(chunk, "\n")
var message string var message string
for _, msg := range subMessages { for _, msg := range subMessages {
if strings.HasPrefix(msg, "data:") { if strings.HasPrefix(msg, "data:") {
@@ -327,28 +328,27 @@ func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage
} }
if len(message) < 6 { if len(message) < 6 {
log.Errorf("invalid message:%s", message) log.Errorf("invalid message:%s", message)
return "" return content
} }
// skip the prefix "data:" // skip the prefix "data:"
bodyJson := message[5:] bodyJson := message[5:]
if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
tempContentI := ctx.GetContext(AnswerContentContextKey) tempContentI := ctx.GetContext(AnswerContentContextKey)
if tempContentI == nil { if tempContentI == nil {
content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
ctx.SetContext(AnswerContentContextKey, content) ctx.SetContext(AnswerContentContextKey, content)
return content } else {
}
append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
content := tempContentI.(string) + append content = tempContentI.(string) + append
ctx.SetContext(AnswerContentContextKey, content) ctx.SetContext(AnswerContentContextKey, content)
return content }
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
// TODO: compatible with other providers // TODO: compatible with other providers
ctx.SetContext(ToolCallsContextKey, struct{}{}) ctx.SetContext(ToolCallsContextKey, struct{}{})
return ""
} }
log.Debugf("unknown message:%s", bodyJson) log.Debugf("unknown message:%s", bodyJson)
return "" }
return content
} }
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {