diff --git a/plugins/wasm-go/extensions/ai-search/README.md b/plugins/wasm-go/extensions/ai-search/README.md index f7829d490..90a2aa955 100644 --- a/plugins/wasm-go/extensions/ai-search/README.md +++ b/plugins/wasm-go/extensions/ai-search/README.md @@ -78,6 +78,12 @@ description: higress 支持通过集成搜索引擎(Google/Bing/Arxiv/Elastics | linkField | string | 必填 | - | 结果链接字段名称 | | titleField | string | 必填 | - | 结果标题字段名称 | +## Quark 特定配置 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------|----------|----------|--------|------| +| contentMode | string | 选填 | "summary" | 内容模式:"summary"使用摘要(snippet),"full"使用正文(优先markdownText,为空则用mainText) | + ## 配置示例 ### 基础配置(单搜索引擎) @@ -115,6 +121,7 @@ searchFrom: serviceName: "quark-svc.dns" servicePort: 443 apiKey: "quark api key" + contentMode: "full" # 可选值:"summary"(默认)或"full" ``` ### 多搜索引擎配置 diff --git a/plugins/wasm-go/extensions/ai-search/README_EN.md b/plugins/wasm-go/extensions/ai-search/README_EN.md index 623c09a1d..c56604a8f 100644 --- a/plugins/wasm-go/extensions/ai-search/README_EN.md +++ b/plugins/wasm-go/extensions/ai-search/README_EN.md @@ -78,6 +78,12 @@ It is strongly recommended to enable this feature when using Arxiv or Elasticsea | linkField | string | Required | - | Result link field name | | titleField | string | Required | - | Result title field name | +## Quark Specific Configuration + +| Name | Data Type | Requirement | Default Value | Description | +|------|-----------|-------------|---------------|-------------| +| contentMode | string | Optional | "summary" | Content mode: "summary" uses snippet, "full" uses full text (markdownText first, then mainText if empty) | + ## Configuration Examples ### Basic Configuration (Single Search Engine) @@ -113,7 +119,8 @@ searchFrom: - type: quark serviceName: "quark-svc.dns" servicePort: 443 - apiKey: "aliyun accessKey" + apiKey: "quark api key" + contentMode: "full" # Optional values: "summary"(default) or "full" ``` ### Multiple Search Engines Configuration diff --git 
a/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go b/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go index 1df377fe4..850f66f02 100644 --- a/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go +++ b/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go @@ -24,6 +24,7 @@ type QuarkSearch struct { client wrapper.HttpClient count uint32 optionArgs map[string]string + contentMode string // "summary" or "full" } const ( @@ -112,6 +113,13 @@ func NewQuarkSearch(config *gjson.Result) (*QuarkSearch, error) { engine.optionArgs[key] = value.String() } } + engine.contentMode = config.Get("contentMode").String() + if engine.contentMode == "" { + engine.contentMode = "summary" + } + if engine.contentMode != "full" && engine.contentMode != "summary" { + return nil, fmt.Errorf("contentMode is not valid:%s", engine.contentMode) + } return engine, nil } @@ -148,10 +156,19 @@ func (g QuarkSearch) ParseResult(ctx engine.SearchContext, response []byte) []en jsonObj := gjson.ParseBytes(response) var results []engine.SearchResult for index, item := range jsonObj.Get("pageItems").Array() { + var content string + if g.contentMode == "full" { + content = item.Get("markdownText").String() + if content == "" { + content = item.Get("mainText").String() + } + } else if g.contentMode == "summary" { + content = item.Get("snippet").String() + } result := engine.SearchResult{ Title: item.Get("title").String(), Link: item.Get("link").String(), - Content: item.Get("mainText").String(), + Content: content, } if result.Valid() && index < int(g.count) { results = append(results, result) diff --git a/plugins/wasm-go/extensions/ai-search/main.go b/plugins/wasm-go/extensions/ai-search/main.go index 720e688cc..59446d0a3 100644 --- a/plugins/wasm-go/extensions/ai-search/main.go +++ b/plugins/wasm-go/extensions/ai-search/main.go @@ -15,12 +15,14 @@ package main import ( + "bytes" _ "embed" "errors" "fmt" "net/http" "strings" "time" + "unicode"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -492,8 +494,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log if references == "" { return types.ActionContinue } - content := gjson.GetBytes(body, "choices.0.message.content") - modifiedContent := fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content) + content := gjson.GetBytes(body, "choices.0.message.content").String() + var modifiedContent string + if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") { + thinkEnd := strings.Index(content, "</think>") + if thinkEnd != -1 { + modifiedContent = content[:thinkEnd+8] + + fmt.Sprintf("\n%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content[thinkEnd+8:]) + } + } + if modifiedContent == "" { + modifiedContent = fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content) + } body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent) if err != nil { log.Errorf("modify response message content failed, err:%v, body:%s", err, body) @@ -503,6 +515,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log return types.ActionContinue } +func unifySSEChunk(data []byte) []byte { + data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n")) + data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n")) + return data +} + +const ( + PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" + BUFFER_CONTENT_CONTEXT_KEY = "bufferContent" + BUFFER_SIZE = 30 +) + func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { if ctx.GetBoolContext("ReferenceAppended", false) { return chunk @@ -511,58 +535,110 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt if references == "" { return chunk } - modifiedChunk, responseReady := setReferencesToFirstMessage(ctx, chunk,
fmt.Sprintf(config.referenceFormat, references), log) - if responseReady { - ctx.SetContext("ReferenceAppended", true) - return modifiedChunk + chunk = unifySSEChunk(chunk) + var partialMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY)) + if partialMessageI != nil { + partialMessage = append(partialMessageI.([]byte), chunk...) + } else { + partialMessage = chunk + } + messages := strings.Split(string(partialMessage), "\n\n") + var newMessages []string + for i, msg := range messages { + if i < len(messages)-1 { + newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), log) + if newMsg != "" { + newMessages = append(newMessages, newMsg) + } + } + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + } + if len(newMessages) == 1 { + return []byte(fmt.Sprintf("%s\n\n", newMessages[0])) + } else if len(newMessages) > 1 { + return []byte(strings.Join(newMessages, "\n\n")) } else { return []byte("") } } -const PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" - -func setReferencesToFirstMessage(ctx wrapper.HttpContext, chunk []byte, references string, log wrapper.Log) ([]byte, bool) { - if len(chunk) == 0 { - log.Debugf("chunk is empty") - return nil, false - } - - var partialMessage []byte - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - if partialMessageI != nil { - if pMsg, ok := partialMessageI.([]byte); ok { - partialMessage = append(pMsg, chunk...) 
- } else { - log.Warnf("invalid partial message type: %T", partialMessageI) - partialMessage = chunk +func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, log wrapper.Log) string { + log.Debugf("single sse message: %s", sseMessage) + subMessages := strings.Split(sseMessage, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break } - } else { - partialMessage = chunk } - - if len(partialMessage) == 0 { - log.Debugf("partial message is empty") - return nil, false + if len(message) < 6 { + log.Errorf("[processSSEMessage] invalid message: %s", message) + return sseMessage } - messages := strings.Split(string(partialMessage), "\n\n") - if len(messages) > 1 { - firstMessage := messages[0] - log.Debugf("first message: %s", firstMessage) - firstMessage = strings.TrimPrefix(firstMessage, "data:") - firstMessage = strings.TrimPrefix(firstMessage, " ") - firstMessage = strings.TrimSuffix(firstMessage, "\n") - deltaContent := gjson.Get(firstMessage, "choices.0.delta.content") - modifiedMessage, err := sjson.Set(firstMessage, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, deltaContent)) + // Skip the prefix "data:" + bodyJson := message[5:] + if strings.TrimSpace(bodyJson) == "[DONE]" { + return sseMessage + } + bodyJson = strings.TrimPrefix(bodyJson, " ") + bodyJson = strings.TrimSuffix(bodyJson, "\n") + deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String() + // Skip the preceding content that might be empty due to the presence of a separate reasoning_content field. 
+ if deltaContent == "" { + return sseMessage + } + bufferContent := ctx.GetStringContext(BUFFER_CONTENT_CONTEXT_KEY, "") + deltaContent + if len(bufferContent) < BUFFER_SIZE { + ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent) + return "" + } + if !ctx.GetBoolContext("FirstMessageChecked", false) { + ctx.SetContext("FirstMessageChecked", true) + if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") { + modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent)) + if err != nil { + log.Errorf("update messsage failed:%s", err) + } + ctx.SetContext("ReferenceAppended", true) + return fmt.Sprintf("data: %s", modifiedMessage) + } + } + // Content has <think> prefix + // Check for complete </think> tag + thinkEnd := strings.Index(bufferContent, "</think>") + if thinkEnd != -1 { + modifiedContent := bufferContent[:thinkEnd+8] + + fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:]) + modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent) if err != nil { - log.Errorf("modify response delta content failed, err:%v", err) - return partialMessage, true + log.Errorf("update messsage failed:%s", err) } - modifiedMessage = fmt.Sprintf("data: %s", modifiedMessage) - log.Debugf("modified message: %s", firstMessage) - messages[0] = string(modifiedMessage) - return []byte(strings.Join(messages, "\n\n")), true + ctx.SetContext("ReferenceAppended", true) + return fmt.Sprintf("data: %s", modifiedMessage) } - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, partialMessage) - return nil, false + + // Check for partial </think> tag at end of buffer + // Look for any partial match that could be completed in next message + for i := 1; i < len("</think>"); i++ { + if strings.HasSuffix(bufferContent, "</think>"[:i]) { + // Store only the partial match for the next message + ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent[len(bufferContent)-i:]) + // Return the content before the partial match +
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i]) + if err != nil { + log.Errorf("update messsage failed:%s", err) + } + return fmt.Sprintf("data: %s", modifiedMessage) + } + } + + ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, "") + return sseMessage }