diff --git a/plugins/wasm-go/extensions/ai-search/README.md b/plugins/wasm-go/extensions/ai-search/README.md
index f7829d490..90a2aa955 100644
--- a/plugins/wasm-go/extensions/ai-search/README.md
+++ b/plugins/wasm-go/extensions/ai-search/README.md
@@ -78,6 +78,12 @@ description: higress 支持通过集成搜索引擎(Google/Bing/Arxiv/Elastics
| linkField | string | 必填 | - | 结果链接字段名称 |
| titleField | string | 必填 | - | 结果标题字段名称 |
+## Quark 特定配置
+
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+|------|----------|----------|--------|------|
+| contentMode | string | 选填 | "summary" | 内容模式:"summary"使用摘要(snippet),"full"使用正文(优先markdownText,为空则用mainText) |
+
## 配置示例
### 基础配置(单搜索引擎)
@@ -115,6 +121,7 @@ searchFrom:
serviceName: "quark-svc.dns"
servicePort: 443
apiKey: "quark api key"
+ contentMode: "full" # 可选值:"summary"(默认)或"full"
```
### 多搜索引擎配置
diff --git a/plugins/wasm-go/extensions/ai-search/README_EN.md b/plugins/wasm-go/extensions/ai-search/README_EN.md
index 623c09a1d..c56604a8f 100644
--- a/plugins/wasm-go/extensions/ai-search/README_EN.md
+++ b/plugins/wasm-go/extensions/ai-search/README_EN.md
@@ -78,6 +78,12 @@ It is strongly recommended to enable this feature when using Arxiv or Elasticsea
| linkField | string | Required | - | Result link field name |
| titleField | string | Required | - | Result title field name |
+## Quark Specific Configuration
+
+| Name | Data Type | Requirement | Default Value | Description |
+|------|-----------|-------------|---------------|-------------|
+| contentMode | string | Optional | "summary" | Content mode: "summary" uses snippet, "full" uses full text (markdownText first, then mainText if empty) |
+
## Configuration Examples
### Basic Configuration (Single Search Engine)
@@ -113,7 +119,8 @@ searchFrom:
- type: quark
serviceName: "quark-svc.dns"
servicePort: 443
- apiKey: "aliyun accessKey"
+ apiKey: "quark api key"
+ contentMode: "full" # Optional values: "summary"(default) or "full"
```
### Multiple Search Engines Configuration
diff --git a/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go b/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go
index 1df377fe4..850f66f02 100644
--- a/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go
+++ b/plugins/wasm-go/extensions/ai-search/engine/quark/quark.go
@@ -24,6 +24,7 @@ type QuarkSearch struct {
client wrapper.HttpClient
count uint32
optionArgs map[string]string
+ contentMode string // "summary" or "full"
}
const (
@@ -112,6 +113,13 @@ func NewQuarkSearch(config *gjson.Result) (*QuarkSearch, error) {
engine.optionArgs[key] = value.String()
}
}
+ engine.contentMode = config.Get("contentMode").String()
+ if engine.contentMode == "" {
+ engine.contentMode = "summary"
+ }
+ if engine.contentMode != "full" && engine.contentMode != "summary" {
+ return nil, fmt.Errorf("contentMode is not valid:%s", engine.contentMode)
+ }
return engine, nil
}
@@ -148,10 +156,19 @@ func (g QuarkSearch) ParseResult(ctx engine.SearchContext, response []byte) []en
jsonObj := gjson.ParseBytes(response)
var results []engine.SearchResult
for index, item := range jsonObj.Get("pageItems").Array() {
+ var content string
+ if g.contentMode == "full" {
+ content = item.Get("markdownText").String()
+ if content == "" {
+ content = item.Get("mainText").String()
+ }
+ } else if g.contentMode == "summary" {
+ content = item.Get("snippet").String()
+ }
result := engine.SearchResult{
Title: item.Get("title").String(),
Link: item.Get("link").String(),
- Content: item.Get("mainText").String(),
+ Content: content,
}
if result.Valid() && index < int(g.count) {
results = append(results, result)
diff --git a/plugins/wasm-go/extensions/ai-search/main.go b/plugins/wasm-go/extensions/ai-search/main.go
index 720e688cc..59446d0a3 100644
--- a/plugins/wasm-go/extensions/ai-search/main.go
+++ b/plugins/wasm-go/extensions/ai-search/main.go
@@ -15,12 +15,14 @@
package main
import (
+ "bytes"
_ "embed"
"errors"
"fmt"
"net/http"
"strings"
"time"
+ "unicode"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -492,8 +494,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
if references == "" {
return types.ActionContinue
}
- content := gjson.GetBytes(body, "choices.0.message.content")
- modifiedContent := fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content)
+ content := gjson.GetBytes(body, "choices.0.message.content").String()
+ var modifiedContent string
+ if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "") {
+ thinkEnd := strings.Index(content, "")
+ if thinkEnd != -1 {
+ modifiedContent = content[:thinkEnd+8] +
+ fmt.Sprintf("\n%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content[thinkEnd+8:])
+ }
+ }
+ if modifiedContent == "" {
+ modifiedContent = fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content)
+ }
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
if err != nil {
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
@@ -503,6 +515,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
return types.ActionContinue
}
+func unifySSEChunk(data []byte) []byte {
+ data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
+ data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
+ return data
+}
+
+const (
+ PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
+ BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
+ BUFFER_SIZE = 30
+)
+
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
if ctx.GetBoolContext("ReferenceAppended", false) {
return chunk
@@ -511,58 +535,110 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
if references == "" {
return chunk
}
- modifiedChunk, responseReady := setReferencesToFirstMessage(ctx, chunk, fmt.Sprintf(config.referenceFormat, references), log)
- if responseReady {
- ctx.SetContext("ReferenceAppended", true)
- return modifiedChunk
+ chunk = unifySSEChunk(chunk)
+ var partialMessage []byte
+ partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
+ log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))
+ if partialMessageI != nil {
+ partialMessage = append(partialMessageI.([]byte), chunk...)
+ } else {
+ partialMessage = chunk
+ }
+ messages := strings.Split(string(partialMessage), "\n\n")
+ var newMessages []string
+ for i, msg := range messages {
+ if i < len(messages)-1 {
+ newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), log)
+ if newMsg != "" {
+ newMessages = append(newMessages, newMsg)
+ }
+ }
+ }
+ if !strings.HasSuffix(string(partialMessage), "\n\n") {
+ ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1]))
+ } else {
+ ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil)
+ }
+ if len(newMessages) == 1 {
+ return []byte(fmt.Sprintf("%s\n\n", newMessages[0]))
+ } else if len(newMessages) > 1 {
+ return []byte(strings.Join(newMessages, "\n\n"))
} else {
return []byte("")
}
}
-const PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
-
-func setReferencesToFirstMessage(ctx wrapper.HttpContext, chunk []byte, references string, log wrapper.Log) ([]byte, bool) {
- if len(chunk) == 0 {
- log.Debugf("chunk is empty")
- return nil, false
- }
-
- var partialMessage []byte
- partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
- if partialMessageI != nil {
- if pMsg, ok := partialMessageI.([]byte); ok {
- partialMessage = append(pMsg, chunk...)
- } else {
- log.Warnf("invalid partial message type: %T", partialMessageI)
- partialMessage = chunk
+func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, log wrapper.Log) string {
+ log.Debugf("single sse message: %s", sseMessage)
+ subMessages := strings.Split(sseMessage, "\n")
+ var message string
+ for _, msg := range subMessages {
+ if strings.HasPrefix(msg, "data:") {
+ message = msg
+ break
}
- } else {
- partialMessage = chunk
}
-
- if len(partialMessage) == 0 {
- log.Debugf("partial message is empty")
- return nil, false
+ if len(message) < 6 {
+ log.Errorf("[processSSEMessage] invalid message: %s", message)
+ return sseMessage
}
- messages := strings.Split(string(partialMessage), "\n\n")
- if len(messages) > 1 {
- firstMessage := messages[0]
- log.Debugf("first message: %s", firstMessage)
- firstMessage = strings.TrimPrefix(firstMessage, "data:")
- firstMessage = strings.TrimPrefix(firstMessage, " ")
- firstMessage = strings.TrimSuffix(firstMessage, "\n")
- deltaContent := gjson.Get(firstMessage, "choices.0.delta.content")
- modifiedMessage, err := sjson.Set(firstMessage, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, deltaContent))
+ // Skip the prefix "data:"
+ bodyJson := message[5:]
+ if strings.TrimSpace(bodyJson) == "[DONE]" {
+ return sseMessage
+ }
+ bodyJson = strings.TrimPrefix(bodyJson, " ")
+ bodyJson = strings.TrimSuffix(bodyJson, "\n")
+ deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
+ // Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
+ if deltaContent == "" {
+ return sseMessage
+ }
+ bufferContent := ctx.GetStringContext(BUFFER_CONTENT_CONTEXT_KEY, "") + deltaContent
+ if len(bufferContent) < BUFFER_SIZE {
+ ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent)
+ return ""
+ }
+ if !ctx.GetBoolContext("FirstMessageChecked", false) {
+ ctx.SetContext("FirstMessageChecked", true)
+ if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "") {
+ modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
+ if err != nil {
+ log.Errorf("update messsage failed:%s", err)
+ }
+ ctx.SetContext("ReferenceAppended", true)
+ return fmt.Sprintf("data: %s", modifiedMessage)
+ }
+ }
+ // Content has prefix
+ // Check for complete tag
+ thinkEnd := strings.Index(bufferContent, "")
+ if thinkEnd != -1 {
+ modifiedContent := bufferContent[:thinkEnd+8] +
+ fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
+ modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
if err != nil {
- log.Errorf("modify response delta content failed, err:%v", err)
- return partialMessage, true
+ log.Errorf("update messsage failed:%s", err)
}
- modifiedMessage = fmt.Sprintf("data: %s", modifiedMessage)
- log.Debugf("modified message: %s", firstMessage)
- messages[0] = string(modifiedMessage)
- return []byte(strings.Join(messages, "\n\n")), true
+ ctx.SetContext("ReferenceAppended", true)
+ return fmt.Sprintf("data: %s", modifiedMessage)
}
- ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, partialMessage)
- return nil, false
+
+ // Check for partial tag at end of buffer
+ // Look for any partial match that could be completed in next message
+ for i := 1; i < len(""); i++ {
+ if strings.HasSuffix(bufferContent, ""[:i]) {
+ // Store only the partial match for the next message
+ ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, bufferContent[len(bufferContent)-i:])
+ // Return the content before the partial match
+ modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
+ if err != nil {
+ log.Errorf("update messsage failed:%s", err)
+ }
+ return fmt.Sprintf("data: %s", modifiedMessage)
+ }
+ }
+
+ ctx.SetContext(BUFFER_CONTENT_CONTEXT_KEY, "")
+ return sseMessage
}