Files
higress/plugins/wasm-go/extensions/ai-search/main.go

761 lines
29 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bytes"
_ "embed"
"errors"
"fmt"
"net/http"
"strings"
"time"
"unicode"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/elasticsearch"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/google"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/quark"
)
// SearchRewrite holds the settings for the optional LLM call that turns the
// user's message into one or more search queries before searching.
type SearchRewrite struct {
// client is the HTTP client used to POST the rewrite request to the LLM service.
client wrapper.HttpClient
// url is the request URL passed to client.Post (config key "llmUrl").
url string
// apiKey is sent as "Authorization: Bearer <apiKey>" (config key "llmApiKey").
apiKey string
// modelName is placed in the "model" field of the rewrite request.
modelName string
// timeoutMillisecond is the timeout passed to the LLM call; parseConfig defaults it to 30000.
timeoutMillisecond uint32
// prompt is the rewrite prompt with {max_count} already substituted; {question} is filled per request.
prompt string
promptTemplate string // Original prompt template before replacing placeholders
// maxCount is the value substituted for {max_count}; parseConfig defaults it to 3.
maxCount int
}
// Config is the parsed plugin configuration produced by parseConfig.
type Config struct {
// engine lists the search engines built from the "searchFrom" array.
engine []engine.SearchEngine
// promptTemplate is the prompt that replaces the user message; it must contain
// {search_results} and {question}, and may contain {cur_date}.
promptTemplate string
// referenceFormat is a fmt format string containing %s that wraps the reference list.
referenceFormat string
// defaultLanguage is copied into every search context (config key "defaultLang").
defaultLanguage string
// needReference enables numbered webpages in the prompt and a reference list in the response.
needReference bool
referenceLocation string // "head" or "tail"
// searchRewrite is nil unless the "searchRewrite" section is configured.
searchRewrite *SearchRewrite
// defaultEnable controls whether searching runs when the request has no web_search_options.
defaultEnable bool
}
// DEFAULT_MAX_BODY_BYTES is the buffer limit (100 MiB) the plugin sets for
// request bodies and for buffered, non-streaming response bodies.
const DEFAULT_MAX_BODY_BYTES uint32 = 100 << 20
// The variables below embed the search-rewrite prompts at build time.
// parseConfig picks one of them according to which kinds of search engines
// are configured.

// fullSearchPrompts is used when both arxiv and private engines are configured.
//go:embed prompts/full.md
var fullSearchPrompts string
// arxivSearchPrompts is used when an arxiv engine is configured without a private one.
//go:embed prompts/arxiv.md
var arxivSearchPrompts string
// internetSearchPrompts is used when only internet engines are configured and
// at least one of them is not quark.
//go:embed prompts/internet.md
var internetSearchPrompts string
// privateSearchPrompts is used when a private engine is configured without arxiv.
//go:embed prompts/private.md
var privateSearchPrompts string
// chineseInternetSearchPrompts is used when quark is the only configured engine type.
//go:embed prompts/chinese-internet.md
var chineseInternetSearchPrompts string
// main registers the "ai-search" plugin with the wasm-go wrapper, wiring the
// config parser and the request/response handlers defined in this file.
func main() {
wrapper.SetCtx(
"ai-search",
// Parses the plugin configuration into Config.
wrapper.ParseConfigBy(parseConfig),
// Request path: header inspection, then body rewriting with search results.
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
// Response path: reference injection for streaming and buffered bodies.
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
)
}
// parseConfig fills config from the plugin's JSON configuration.
//
// Keys read from json:
//   - defaultEnable: whether searching runs without web_search_options (default true).
//   - needReference, referenceFormat, referenceLocation: reference list settings;
//     referenceFormat must contain %s and referenceLocation must be "head" or "tail".
//   - defaultLang: language copied into every search context.
//   - promptTemplate: must contain {search_results} and {question}; a built-in
//     template is used when empty.
//   - searchFrom: array of engines, each with a "type" of bing, google, arxiv,
//     elasticsearch or quark.
//   - searchRewrite: optional LLM rewrite settings (llmServiceName, llmServicePort,
//     llmApiKey, llmUrl, llmModelName, timeoutMillisecond, maxCount).
//
// It returns an error when a value is invalid, an engine fails to initialise,
// an engine type is unknown, or no engine is configured.
func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
	config.defaultEnable = true // Default to true if not specified
	if defaultEnable := json.Get("defaultEnable"); defaultEnable.Exists() {
		config.defaultEnable = defaultEnable.Bool()
	}
	config.needReference = json.Get("needReference").Bool()
	if config.needReference {
		config.referenceFormat = json.Get("referenceFormat").String()
		if config.referenceFormat == "" {
			config.referenceFormat = "**References:**\n%s"
		} else if !strings.Contains(config.referenceFormat, "%s") {
			return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
		}
		config.referenceLocation = json.Get("referenceLocation").String()
		if config.referenceLocation == "" {
			config.referenceLocation = "head" // Default to head if not specified
		} else if config.referenceLocation != "head" && config.referenceLocation != "tail" {
			return fmt.Errorf("invalid referenceLocation:%s, must be 'head' or 'tail'", config.referenceLocation)
		}
	}
	config.defaultLanguage = json.Get("defaultLang").String()
	config.promptTemplate = json.Get("promptTemplate").String()
	if config.promptTemplate == "" {
		// Built-in templates: the reference-aware one numbers each webpage so the
		// model can cite it as [X]; the other one asks the model not to cite sources.
		if config.needReference {
			config.promptTemplate = `# 以下内容是基于用户发送的消息的搜索结果:
{search_results}
在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。
在回答时,请注意以下几点:
- 今天是北京时间:{cur_date}。
- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
- 对于列举类的问题如列举所有航班信息尽量将答案控制在10个要点以内并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项如非必要不要主动告诉用户搜索结果未提供的内容。
- 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
- 如果回答很长请尽量结构化、分段落总结。如果需要分点作答尽量控制在5个点以内并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
# 用户消息为:
{question}`
		} else {
			config.promptTemplate = `# 以下内容是基于用户发送的消息的搜索结果:
{search_results}
在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。
在回答时,请注意以下几点:
- 今天是北京时间:{cur_date}。
- 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
- 对于列举类的问题如列举所有航班信息尽量将答案控制在10个要点以内。如非必要不要主动告诉用户搜索结果未提供的内容。
- 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。
- 如果回答很长请尽量结构化、分段落总结。如果需要分点作答尽量控制在5个点以内并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
# 用户消息为:
{question}`
		}
	}
	if !strings.Contains(config.promptTemplate, "{search_results}") ||
		!strings.Contains(config.promptTemplate, "{question}") {
		return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
	}

	// Build the engines and remember which kinds of sources exist; the kinds
	// decide which rewrite prompt is used below.
	var internetExists, privateExists, arxivExists bool
	onlyQuark := true
	for _, e := range json.Get("searchFrom").Array() {
		switch e.Get("type").String() {
		case "bing":
			searchEngine, err := bing.NewBingSearch(&e)
			if err != nil {
				return fmt.Errorf("bing search engine init failed:%w", err)
			}
			config.engine = append(config.engine, searchEngine)
			internetExists = true
			onlyQuark = false
		case "google":
			searchEngine, err := google.NewGoogleSearch(&e)
			if err != nil {
				return fmt.Errorf("google search engine init failed:%w", err)
			}
			config.engine = append(config.engine, searchEngine)
			internetExists = true
			onlyQuark = false
		case "arxiv":
			searchEngine, err := arxiv.NewArxivSearch(&e)
			if err != nil {
				return fmt.Errorf("arxiv search engine init failed:%w", err)
			}
			config.engine = append(config.engine, searchEngine)
			arxivExists = true
			onlyQuark = false
		case "elasticsearch":
			searchEngine, err := elasticsearch.NewElasticsearchSearch(&e, config.needReference)
			if err != nil {
				return fmt.Errorf("elasticsearch search engine init failed:%w", err)
			}
			config.engine = append(config.engine, searchEngine)
			privateExists = true
			onlyQuark = false
		case "quark":
			searchEngine, err := quark.NewQuarkSearch(&e)
			if err != nil {
				return fmt.Errorf("quark search engine init failed:%w", err)
			}
			config.engine = append(config.engine, searchEngine)
			internetExists = true
		default:
			return fmt.Errorf("unknown search engine:%s", e.Get("type").String())
		}
	}

	searchRewriteJson := json.Get("searchRewrite")
	if searchRewriteJson.Exists() {
		searchRewrite := &SearchRewrite{}
		llmServiceName := searchRewriteJson.Get("llmServiceName").String()
		if llmServiceName == "" {
			return errors.New("llmServiceName not found")
		}
		llmServicePort := searchRewriteJson.Get("llmServicePort").Int()
		if llmServicePort == 0 {
			return errors.New("llmServicePort not found")
		}
		searchRewrite.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
			FQDN: llmServiceName,
			Port: llmServicePort,
		})
		searchRewrite.apiKey = searchRewriteJson.Get("llmApiKey").String()
		llmUrl := searchRewriteJson.Get("llmUrl").String()
		if llmUrl == "" {
			return errors.New("llmUrl not found")
		}
		searchRewrite.url = llmUrl
		llmModelName := searchRewriteJson.Get("llmModelName").String()
		if llmModelName == "" {
			return errors.New("llmModelName not found")
		}
		searchRewrite.modelName = llmModelName
		llmTimeout := searchRewriteJson.Get("timeoutMillisecond").Uint()
		if llmTimeout == 0 {
			llmTimeout = 30000
		}
		searchRewrite.timeoutMillisecond = uint32(llmTimeout)
		maxCount := searchRewriteJson.Get("maxCount").Int()
		if maxCount <= 0 {
			// Missing, zero or negative values fall back to the default so the
			// prompt never asks for a non-positive number of queries.
			maxCount = 3
		}
		searchRewrite.maxCount = int(maxCount)
		// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
		switch {
		case arxivExists && privateExists:
			// private + internet + arxiv
			searchRewrite.prompt = fullSearchPrompts
		case arxivExists:
			// internet + arxiv
			searchRewrite.prompt = arxivSearchPrompts
		case privateExists:
			// private + internet
			searchRewrite.prompt = privateSearchPrompts
		case internetExists && onlyQuark:
			// When only quark is used, use chinese-internet.md
			searchRewrite.prompt = chineseInternetSearchPrompts
		case internetExists:
			// only internet
			searchRewrite.prompt = internetSearchPrompts
		}
		// Store the original prompt template before replacing placeholders
		searchRewrite.promptTemplate = searchRewrite.prompt
		// Replace {max_count} placeholder in the prompt with the configured value
		searchRewrite.prompt = strings.Replace(searchRewrite.prompt, "{max_count}", fmt.Sprintf("%d", searchRewrite.maxCount), -1)
		config.searchRewrite = searchRewrite
	}
	if len(config.engine) == 0 {
		return fmt.Errorf("no available search engine found")
	}
	log.Debugf("ai search enabled, config: %#v", config)
	return nil
}
// onHttpRequestHeaders decides from the content-type header whether the
// request body should be processed. It always returns types.ActionContinue.
//
//   - no content-type: nothing is changed;
//   - non-JSON content-type: a warning is logged and body reading is disabled;
//   - JSON: the body buffer limit is raised to DEFAULT_MAX_BODY_BYTES and the
//     Accept-Encoding request header is removed.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
	contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
	switch {
	case contentType == "":
		// The request does not have a body.
	case !strings.Contains(contentType, "application/json"):
		log.Warnf("content is not json, can't process: %s", contentType)
		ctx.DontReadRequestBody()
	default:
		ctx.SetRequestBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
		_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
	}
	return types.ActionContinue
}
// onHttpRequestBody finds the last user message in the chat request and
// starts the search flow for it.
//
// When defaultEnable is false the plugin only runs if the request carries
// web_search_options. Without a searchRewrite configuration the user message
// is searched directly. With one, the message is first sent to the rewrite
// LLM, whose reply is parsed into search contexts by parseRewrittenQueries.
//
// It returns types.ActionPause while an asynchronous call is in flight (the
// callbacks resume the request) and types.ActionContinue otherwise.
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
	// Check if plugin should be enabled based on config and request
	webSearchOptions := gjson.GetBytes(body, "web_search_options")
	if !config.defaultEnable {
		// When defaultEnable is false, we need to check if web_search_options exists in the request
		if !webSearchOptions.Exists() {
			log.Debugf("Plugin disabled by config and no web_search_options in request")
			return types.ActionContinue
		}
		log.Debugf("Plugin enabled by web_search_options in request")
	}

	// The query is the content of the last message whose role is "user".
	var queryIndex int
	var query string
	messages := gjson.GetBytes(body, "messages").Array()
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Get("role").String() == "user" {
			queryIndex = i
			query = messages[i].Get("content").String()
			break
		}
	}
	if query == "" {
		log.Errorf("not found user query in body:%s", body)
		return types.ActionContinue
	}

	searchRewrite := config.searchRewrite
	if searchRewrite == nil {
		// Execute search without rewrite
		return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
			Querys:   []string{query},
			Language: config.defaultLanguage,
		}}, log)
	}

	// config.searchRewrite is a pointer shared by every request, so the
	// per-request max count and prompt are kept in locals instead of being
	// written back into it.
	rewritePromptTemplate := searchRewrite.prompt
	if searchContextSize := webSearchOptions.Get("search_context_size").String(); searchContextSize != "" {
		maxCount := searchRewrite.maxCount
		switch searchContextSize {
		case "low":
			maxCount = 1
			log.Debugf("Setting maxCount to 1 based on search_context_size=low")
		case "medium":
			maxCount = 3
			log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
		case "high":
			maxCount = 5
			log.Debugf("Setting maxCount to 5 based on search_context_size=high")
		default:
			log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
				searchContextSize, searchRewrite.maxCount)
		}
		// If maxCount differs from the configured one, regenerate the prompt from the template
		if maxCount != searchRewrite.maxCount && searchRewrite.promptTemplate != "" {
			rewritePromptTemplate = strings.Replace(
				searchRewrite.promptTemplate,
				"{max_count}",
				fmt.Sprintf("%d", maxCount),
				-1)
		}
	}

	startTime := time.Now()
	rewritePrompt := strings.Replace(rewritePromptTemplate, "{question}", query, 1)
	// Both the model name and the prompt are set through sjson so they are
	// JSON-escaped; the key order of the skeleton is preserved.
	rewriteBody, buildErr := sjson.SetBytes(
		[]byte(`{"stream":false,"max_tokens":4096,"model":"","messages":[{"role":"user","content":""}]}`),
		"model", searchRewrite.modelName)
	if buildErr == nil {
		rewriteBody, buildErr = sjson.SetBytes(rewriteBody, "messages.0.content", rewritePrompt)
	}
	if buildErr != nil {
		log.Errorf("build search rewrite request failed:%s", buildErr)
		return types.ActionContinue
	}

	err := searchRewrite.client.Post(searchRewrite.url,
		[][2]string{
			{"Content-Type", "application/json"},
			{"Authorization", fmt.Sprintf("Bearer %s", searchRewrite.apiKey)},
		}, rewriteBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if statusCode != http.StatusOK {
				log.Errorf("search rewrite failed, status: %d", statusCode)
				// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
				proxywasm.ResumeHttpRequest()
				return
			}
			content := gjson.GetBytes(responseBody, "choices.0.message.content").String()
			log.Infof("LLM rewritten query response: %s (took %v), original search query:%s",
				strings.ReplaceAll(content, "\n", `\n`), time.Since(startTime), query)
			// The rewrite reply contains "none" when no search is needed.
			if strings.Contains(content, "none") {
				log.Debugf("no search required")
				proxywasm.ResumeHttpRequest()
				return
			}
			searchContexts := parseRewrittenQueries(content, config.defaultLanguage)
			if len(searchContexts) == 0 {
				log.Errorf("no valid search contexts found")
				proxywasm.ResumeHttpRequest()
				return
			}
			// executeSearch returns ActionContinue when no search call was
			// started; in that case nobody else will resume the request.
			if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
				proxywasm.ResumeHttpRequest()
			}
		}, searchRewrite.timeoutMillisecond)
	if err != nil {
		log.Errorf("search rewrite call llm service failed:%s", err)
		// After a rewrite failure, no further search is performed, thus quickly identifying the failure.
		return types.ActionContinue
	}
	return types.ActionPause
}

// parseRewrittenQueries converts the rewrite LLM's reply into search contexts.
//
// Each non-empty line is expected to look like "<type>: <queries>". The type
// "internet" yields a single query, "private" yields a comma-separated list,
// and any other type is treated as an arxiv category with a comma-separated
// list. For an arxiv category a second context without the category is added
// to increase recall. Lines without a colon or without any non-empty query
// are skipped.
func parseRewrittenQueries(content string, defaultLanguage string) []engine.SearchContext {
	var searchContexts []engine.SearchContext
	for _, line := range strings.Split(content, "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		parts := strings.SplitN(line, ":", 2)
		if len(parts) != 2 {
			continue
		}
		engineType := strings.TrimSpace(parts[0])
		queryStr := strings.TrimSpace(parts[1])
		var searchCtx engine.SearchContext
		searchCtx.Language = defaultLanguage
		switch engineType {
		case "internet":
			searchCtx.EngineType = engineType
			if queryStr != "" {
				searchCtx.Querys = []string{queryStr}
			}
		case "private":
			searchCtx.EngineType = engineType
			searchCtx.Querys = splitQueryList(queryStr)
		default:
			// Arxiv category
			searchCtx.EngineType = "arxiv"
			searchCtx.ArxivCategory = engineType
			searchCtx.Querys = splitQueryList(queryStr)
		}
		if len(searchCtx.Querys) == 0 {
			continue
		}
		searchContexts = append(searchContexts, searchCtx)
		if searchCtx.ArxivCategory != "" {
			// Conduct inquiries in all areas to increase recall.
			backupCtx := searchCtx
			backupCtx.ArxivCategory = ""
			searchContexts = append(searchContexts, backupCtx)
		}
	}
	return searchContexts
}

// splitQueryList splits a comma-separated query list, trims each entry and
// drops entries that are empty after trimming.
func splitQueryList(queryStr string) []string {
	var queries []string
	for _, q := range strings.Split(queryStr, ",") {
		if q = strings.TrimSpace(q); q != "" {
			queries = append(queries, q)
		}
	}
	return queries
}
// executeSearch issues one asynchronous call per (engine, search context)
// pair for which the engine reports NeedExectue, and rewrites the request
// body once every started call has completed.
//
// It returns types.ActionPause when at least one call was started (the last
// callback resumes the request) and types.ActionContinue when none was.
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log wrapper.Log) types.Action {
	// One result slot per engine so merged results keep the engine order.
	searchResultGroups := make([][]engine.SearchResult, len(config.engine))
	var finished int
	var searching int
	for i := 0; i < len(config.engine); i++ {
		index := i
		configEngine := config.engine[i]
		for _, sc := range searchContexts {
			// Copy the range variable: the callback runs after the loop has
			// moved on, and before Go 1.22 all iterations share one variable.
			searchCtx := sc
			if !configEngine.NeedExectue(searchCtx) {
				continue
			}
			args := configEngine.CallArgs(searchCtx)
			err := configEngine.Client().Call(args.Method, args.Url, args.Headers, args.Body,
				func(statusCode int, responseHeaders http.Header, responseBody []byte) {
					// Count this call as finished on every exit path; the last
					// one to finish applies the results and resumes the request.
					defer func() {
						finished++
						if finished == searching {
							injectSearchResults(ctx, config, queryIndex, body, searchContexts, searchResultGroups, log)
							proxywasm.ResumeHttpRequest()
						}
					}()
					if statusCode != http.StatusOK {
						log.Errorf("search call failed, status: %d, engine: %#v", statusCode, configEngine)
						return
					}
					// Append results to existing slice for this engine
					searchResultGroups[index] = append(searchResultGroups[index], configEngine.ParseResult(searchCtx, responseBody)...)
				}, args.TimeoutMillisecond)
			if err != nil {
				log.Errorf("search call failed, engine: %#v, err: %v", configEngine, err)
				continue
			}
			searching++
		}
	}
	if searching > 0 {
		return types.ActionPause
	}
	return types.ActionContinue
}

// injectSearchResults merges the per-engine results, fills the prompt
// template and replaces the user message at queryIndex in the request body.
// When references are enabled, the formatted reference list is stored in the
// "References" context value for the response handlers. It does not resume
// the request; the caller does.
func injectSearchResults(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, searchResultGroups [][]engine.SearchResult, log wrapper.Log) {
	// Merge search results from all engines, deduplicating by link. Results
	// without a link cannot be compared, so they are all kept.
	var mergedResults []engine.SearchResult
	seenLinks := make(map[string]bool)
	for _, results := range searchResultGroups {
		for _, result := range results {
			if result.Link != "" {
				if seenLinks[result.Link] {
					continue
				}
				seenLinks[result.Link] = true
			}
			mergedResults = append(mergedResults, result)
		}
	}
	if len(mergedResults) == 0 {
		log.Warnf("no search result found, searchContexts:%#v", searchContexts)
		return
	}

	// Format search results for prompt template
	var formattedResults []string
	var formattedReferences []string
	for j, result := range mergedResults {
		if config.needReference {
			formattedResults = append(formattedResults,
				fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", j+1, result.Content, j+1))
			formattedReferences = append(formattedReferences,
				fmt.Sprintf("[%d] [%s](%s)", j+1, result.Title, result.Link))
		} else {
			formattedResults = append(formattedResults,
				fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content))
		}
	}

	// Prepare template variables
	curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日")
	searchResults := strings.Join(formattedResults, "\n")
	log.Debugf("searchResults: %s", searchResults)

	// {question} is the user's original message, not a rewritten search query;
	// the first search query is only a fallback when the message is unreadable.
	question := gjson.GetBytes(body, fmt.Sprintf("messages.%d.content", queryIndex)).String()
	if question == "" && len(searchContexts) > 0 && len(searchContexts[0].Querys) > 0 {
		question = searchContexts[0].Querys[0]
	}

	// Fill the template in a single pass so placeholder-like text inside the
	// search results or the question is never expanded.
	prompt := strings.NewReplacer(
		"{search_results}", searchResults,
		"{question}", question,
		"{cur_date}", curDate,
	).Replace(config.promptTemplate)

	// Update request body with processed prompt
	modifiedBody, err := sjson.SetBytes(body, fmt.Sprintf("messages.%d.content", queryIndex), prompt)
	if err != nil {
		log.Errorf("modify request message content failed, err:%v, body:%s", err, body)
		return
	}
	log.Debugf("modifeid body:%s", modifiedBody)
	proxywasm.ReplaceHttpRequestBody(modifiedBody)
	if config.needReference {
		ctx.SetContext("References", strings.Join(formattedReferences, "\n\n"))
	}
}
// onHttpResponseHeaders prepares response handling for reference injection.
// It always returns types.ActionContinue.
//
// Without needReference the response body is not read at all. Otherwise the
// content-length header is removed, and the body is buffered (up to
// DEFAULT_MAX_BODY_BYTES) unless the Content-Type starts with
// "text/event-stream".
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
	if !config.needReference {
		ctx.DontReadResponseBody()
		return types.ActionContinue
	}
	proxywasm.RemoveHttpResponseHeader("content-length")
	contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
	if err != nil {
		log.Errorf("unable to load content-type header from response: %v", err)
	}
	isEventStream := err == nil && strings.HasPrefix(contentType, "text/event-stream")
	if !isEventStream {
		ctx.BufferResponseBody()
		ctx.SetResponseBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
	}
	return types.ActionContinue
}
// onHttpResponseBody injects the reference list into a buffered
// (non-streaming) chat completion response.
//
// The references come from the "References" context value; when it is empty
// the body is left unchanged. With referenceLocation "tail" the references
// are appended to the content. Otherwise they are placed right after a
// leading <think>...</think> section when one exists, or at the very start of
// the content. It always returns types.ActionContinue.
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
	references := ctx.GetStringContext("References", "")
	if references == "" {
		return types.ActionContinue
	}
	contentResult := gjson.GetBytes(body, "choices.0.message.content")
	if !contentResult.Exists() {
		// E.g. an error response: do not fabricate a choices array in it.
		log.Warnf("no message content found in response, skip adding references")
		return types.ActionContinue
	}
	content := contentResult.String()
	formattedReferences := fmt.Sprintf(config.referenceFormat, references)

	const thinkCloseTag = "</think>"
	var modifiedContent string
	if config.referenceLocation == "tail" {
		// Add references at the end
		modifiedContent = fmt.Sprintf("%s\n\n%s", content, formattedReferences)
	} else {
		thinkEnd := -1
		if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
			thinkEnd = strings.Index(content, thinkCloseTag)
		}
		if thinkEnd != -1 {
			// Add references after the </think> tag
			afterThink := thinkEnd + len(thinkCloseTag)
			modifiedContent = content[:afterThink] +
				fmt.Sprintf("\n%s\n\n%s", formattedReferences, content[afterThink:])
		} else {
			// Default: add references at the beginning
			modifiedContent = fmt.Sprintf("%s\n\n%s", formattedReferences, content)
		}
	}

	// Keep the original body untouched so the error log shows what failed.
	modifiedBody, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
	if err != nil {
		log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
		return types.ActionContinue
	}
	proxywasm.ReplaceHttpResponseBody(modifiedBody)
	return types.ActionContinue
}
// unifySSEChunk returns a copy of data in which every line ending is a single
// "\n": each "\r\n" pair and each remaining lone "\r" becomes "\n".
func unifySSEChunk(data []byte) []byte {
	var normalized bytes.Buffer
	normalized.Grow(len(data))
	for i := 0; i < len(data); i++ {
		if data[i] != '\r' {
			normalized.WriteByte(data[i])
			continue
		}
		normalized.WriteByte('\n')
		// Swallow the "\n" of a "\r\n" pair so it is not emitted twice.
		if i+1 < len(data) && data[i+1] == '\n' {
			i++
		}
	}
	return normalized.Bytes()
}
// Context keys and tuning values used by the streaming response handlers.
const (
// PARTIAL_MESSAGE_CONTEXT_KEY stores the unterminated tail of an SSE chunk
// until the next chunk arrives.
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
// BUFFER_CONTENT_CONTEXT_KEY stores delta content that has been held back
// and not yet emitted downstream.
BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
// BUFFER_SIZE is the number of content bytes collected before the handler
// inspects the buffered content.
BUFFER_SIZE = 30
)
// onStreamingResponseBody injects the reference list into a streamed (SSE)
// response, one chunk at a time, and returns the bytes to send downstream.
//
// The chunk is split into SSE messages on blank lines. Every complete message
// goes through processSSEMessage; an unterminated tail is kept in the
// PARTIAL_MESSAGE_CONTEXT_KEY context value and prepended to the next chunk.
// Once the "ReferenceAppended" context flag is set, chunks pass through
// unchanged apart from first flushing any tail that is still stored.
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
	if ctx.GetBoolContext("ReferenceAppended", false) {
		// A tail stored before the references were injected belongs in front
		// of this chunk; without it the next message would be truncated.
		pending, _ := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY).([]byte)
		if len(pending) == 0 {
			return chunk
		}
		ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil)
		flushed := make([]byte, 0, len(pending)+len(chunk))
		flushed = append(flushed, pending...)
		return append(flushed, chunk...)
	}
	references := ctx.GetStringContext("References", "")
	if references == "" {
		return chunk
	}
	chunk = unifySSEChunk(chunk)
	log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))

	partialMessage := chunk
	if pending, ok := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY).([]byte); ok {
		partialMessage = append(pending, chunk...)
	}

	// Render the reference block once per chunk rather than once per message.
	formattedReferences := fmt.Sprintf(config.referenceFormat, references)
	tailReference := config.referenceLocation == "tail"

	messages := strings.Split(string(partialMessage), "\n\n")
	complete := messages[:len(messages)-1]
	trailing := messages[len(messages)-1]
	if isLastChunk && trailing != "" {
		// No further chunk will complete the tail, so process it now instead
		// of dropping it.
		complete = messages
		trailing = ""
	}

	var newMessages []string
	for _, msg := range complete {
		if newMsg := processSSEMessage(ctx, msg, formattedReferences, tailReference, log); newMsg != "" {
			newMessages = append(newMessages, newMsg)
		}
	}

	if trailing != "" {
		ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(trailing))
	} else {
		ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil)
	}

	if len(newMessages) == 0 {
		return []byte("")
	}
	return []byte(fmt.Sprintf("%s\n\n", strings.Join(newMessages, "\n\n")))
}
// processSSEMessage handles one SSE message of a streamed chat completion
// and returns the message to emit, or "" when its content is held back.
//
// references is the already formatted reference block. With tailReference
// the block is appended to the message whose finish_reason is "stop".
// Otherwise (head mode) delta content is buffered until BUFFER_SIZE bytes are
// available; the block is then placed in front of the content, or right
// after </think> when the content starts with <think>. Held-back content is
// kept in the BUFFER_CONTENT_CONTEXT_KEY context value and is always included
// in the next emitted message. Once references have been injected the
// "ReferenceAppended" context flag is set and messages pass through as is.
//
// Rewritten messages consist of a single "data: <json>" line.
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
	const (
		thinkOpenTag  = "<think>"
		thinkCloseTag = "</think>"
	)
	log.Debugf("single sse message: %s", sseMessage)
	// Later messages of the chunk in which the references were injected must
	// not be buffered or rewritten any more.
	if ctx.GetBoolContext("ReferenceAppended", false) {
		return sseMessage
	}

	var message string
	for _, line := range strings.Split(sseMessage, "\n") {
		if strings.HasPrefix(line, "data:") {
			message = line
			break
		}
	}
	if len(message) < 6 {
		log.Errorf("[processSSEMessage] invalid message: %s", message)
		return sseMessage
	}
	// Skip the prefix "data:"
	bodyJson := message[5:]
	if strings.TrimSpace(bodyJson) == "[DONE]" {
		return sseMessage
	}
	bodyJson = strings.TrimPrefix(bodyJson, " ")

	deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
	finishReason := gjson.Get(bodyJson, "choices.0.finish_reason").String()

	if tailReference {
		// Only the last message in the stream (finish_reason is "stop") is changed.
		if finishReason != "stop" {
			return sseMessage
		}
		ctx.SetContext("ReferenceAppended", true)
		return rewriteSSEDeltaContent(bodyJson, deltaContent+fmt.Sprintf("\n\n%s", references), sseMessage, log)
	}

	// Head reference logic.
	streamFinished := finishReason != ""
	pending := ctx.GetStringContext(BUFFER_CONTENT_CONTEXT_KEY, "")
	// Skip messages without content (e.g. a separate reasoning_content field),
	// unless this is the final message and held-back content must be flushed.
	if deltaContent == "" && !(streamFinished && pending != "") {
		return sseMessage
	}
	bufferContent := pending + deltaContent
	if !streamFinished && len(bufferContent) < BUFFER_SIZE {
		ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent)
		return ""
	}

	if !ctx.GetBoolContext("FirstMessageChecked", false) {
		ctx.SetContext("FirstMessageChecked", true)
		// Same rule as the non-streaming handler: only content that starts
		// with <think> postpones the references until </think>.
		if !strings.HasPrefix(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), thinkOpenTag) {
			ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, "")
			ctx.SetContext("ReferenceAppended", true)
			return rewriteSSEDeltaContent(bodyJson, fmt.Sprintf("%s\n\n%s", references, bufferContent), sseMessage, log)
		}
	}

	// Content has <think> prefix
	// Check for complete </think> tag
	if thinkEnd := strings.Index(bufferContent, thinkCloseTag); thinkEnd != -1 {
		afterThink := thinkEnd + len(thinkCloseTag)
		modifiedContent := bufferContent[:afterThink] +
			fmt.Sprintf("\n%s\n\n%s", references, bufferContent[afterThink:])
		ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, "")
		ctx.SetContext("ReferenceAppended", true)
		return rewriteSSEDeltaContent(bodyJson, modifiedContent, sseMessage, log)
	}

	// Check for a partial </think> tag at the end of the buffer that could be
	// completed by the next message. Nothing follows a finished stream, so
	// nothing is held back then.
	if !streamFinished {
		for i := 1; i < len(thinkCloseTag); i++ {
			if strings.HasSuffix(bufferContent, thinkCloseTag[:i]) {
				// Store only the partial match for the next message
				ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent[len(bufferContent)-i:])
				// Return the content before the partial match
				return rewriteSSEDeltaContent(bodyJson, bufferContent[:len(bufferContent)-i], sseMessage, log)
			}
		}
	}

	ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, "")
	if pending == "" {
		// Nothing was held back, the original message is already complete.
		return sseMessage
	}
	// Emit the held-back content together with this delta so it is not lost.
	return rewriteSSEDeltaContent(bodyJson, bufferContent, sseMessage, log)
}

// rewriteSSEDeltaContent sets choices.0.delta.content of bodyJson to content
// and returns the result as a "data: " SSE line. When the JSON cannot be
// updated the error is logged and fallback is returned instead.
func rewriteSSEDeltaContent(bodyJson string, content string, fallback string, log wrapper.Log) string {
	modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", content)
	if err != nil {
		log.Errorf("update message failed:%s", err)
		return fallback
	}
	return fmt.Sprintf("data: %s", modifiedMessage)
}