Further optimize the ai-search plugin (#1896)

澄潭
2025-03-14 23:24:22 +08:00
committed by GitHub
parent f09e029a6b
commit 34b3fc3114
9 changed files with 229 additions and 22 deletions


@@ -20,6 +20,7 @@ description: higress 支持通过集成搜索引擎Google/Bing/Arxiv/Elastics
| defaultEnable | bool | 选填 | true | 插件功能默认是否开启。设置为false时仅当请求中包含web_search_options字段时才启用插件功能 |
| needReference | bool | 选填 | false | 是否在回答中添加引用来源 |
| referenceFormat | string | 选填 | `"**References:**\n%s"` | 引用内容格式,必须包含%s占位符 |
| referenceLocation | string | 选填 | "head" | 引用位置:"head"在回答开头,"tail"在回答结尾 |
| defaultLang | string | 选填 | - | 默认搜索语言代码如zh-CN/en-US |
| promptTemplate | string | 选填 | 内置模板 | 提示模板,必须包含`{search_results}``{question}`占位符 |
| searchFrom | array of object | 必填 | - | 参考下面搜索引擎配置,至少配置一个引擎 |
@@ -45,6 +46,7 @@ description: higress 支持通过集成搜索引擎Google/Bing/Arxiv/Elastics
| llmUrl | string | 必填 | - | LLM服务API地址 |
| llmModelName | string | 必填 | - | LLM模型名称 |
| timeoutMillisecond | number | 选填 | 30000 | API调用超时时间毫秒 |
| maxCount | number | 选填 | 3 | 搜索重写生成的最大查询次数 |
## 搜索引擎通用配置
@@ -225,6 +227,18 @@ searchFrom:
servicePort: 8080
```
### 自定义引用位置
```yaml
needReference: true
referenceLocation: "tail" # 在回答结尾添加引用,而不是开头
searchFrom:
- type: bing
apiKey: "your-bing-key"
serviceName: "search-service.dns"
servicePort: 8080
```
### 搜索重写配置
```yaml
@@ -259,6 +273,25 @@ searchFrom:
这种配置适用于支持web_search选项的模型例如OpenAI的gpt-4o-search-preview模型。当请求中包含`web_search_options`字段时,即使是空对象(`"web_search_options": {}`),插件也会被激活。
### 搜索上下文大小配置
通过在请求中的`web_search_options`字段中添加`search_context_size`参数,可以动态调整搜索查询次数:
```json
{
"web_search_options": {
"search_context_size": "medium"
}
}
```
`search_context_size`支持三个级别:
- `low`: 生成1个搜索查询适合简单问题
- `medium`: 生成3个搜索查询默认值
- `high`: 生成5个搜索查询适合复杂问题
这个设置会覆盖配置中的`maxCount`值,允许客户端根据问题复杂度动态调整搜索深度。
## 注意事项
1. 提示词模版必须包含`{search_results}``{question}`占位符,可选使用`{cur_date}`插入当前日期格式2006年1月2日


@@ -20,6 +20,7 @@ Plugin execution priority: `440`
| defaultEnable | bool | Optional | true | Whether the plugin functionality is enabled by default. When set to false, the plugin will only be activated when the request contains a web_search_options field |
| needReference | bool | Optional | false | Whether to add reference sources in the response |
| referenceFormat | string | Optional | `"**References:**\n%s"` | Reference content format, must include %s placeholder |
| referenceLocation | string | Optional | "head" | Reference position: "head" at the beginning of the response, "tail" at the end of the response |
| defaultLang | string | Optional | - | Default search language code (e.g. zh-CN/en-US) |
| promptTemplate | string | Optional | Built-in template | Prompt template, must include `{search_results}` and `{question}` placeholders |
| searchFrom | array of object | Required | - | Refer to search engine configuration below, at least one engine must be configured |
@@ -45,6 +46,7 @@ It is strongly recommended to enable this feature when using Arxiv or Elasticsea
| llmUrl | string | Required | - | LLM service API URL |
| llmModelName | string | Required | - | LLM model name |
| timeoutMillisecond | number | Optional | 30000 | API call timeout (milliseconds) |
| maxCount | number | Optional | 3 | Maximum number of search queries generated by the search rewrite |
## Search Engine Common Configuration
@@ -224,6 +226,18 @@ searchFrom:
servicePort: 8080
```
### Custom Reference Location
```yaml
needReference: true
referenceLocation: "tail" # Add references at the end of the response instead of the beginning
searchFrom:
- type: bing
apiKey: "your-bing-key"
serviceName: "search-service.dns"
servicePort: 8080
```
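A minimal Go sketch of how the two placements might behave for a non-streaming response, assuming the handling of a leading `<think>...</think>` block seen in the plugin source: with `"head"` the references go right after the reasoning block (or at the very start if there is none), and with `"tail"` they go at the end. `placeReferences` is a hypothetical helper name used for illustration only, not a function the plugin exposes.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// placeReferences is a hypothetical helper (not part of the plugin API).
// It inserts the already formatted references into the answer text
// according to location, which is expected to be "head" or "tail".
func placeReferences(content, references, location string) string {
	if location == "tail" {
		// Tail: append the references after the whole answer.
		return content + "\n\n" + references
	}
	const closeTag = "</think>"
	// Head: if the answer starts with a reasoning block, keep that block
	// first and put the references immediately after it.
	if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
		if end := strings.Index(content, closeTag); end != -1 {
			cut := end + len(closeTag)
			return content[:cut] + "\n" + references + "\n\n" + content[cut:]
		}
	}
	// Head without a reasoning block: references come first.
	return references + "\n\n" + content
}

func main() {
	// The format string is the documented default of referenceFormat;
	// the reference entry itself is made up for the example.
	refs := fmt.Sprintf("**References:**\n%s", "[1] example source")
	fmt.Println(placeReferences("The answer.", refs, "head"))
	fmt.Println("---")
	fmt.Println(placeReferences("The answer.", refs, "tail"))
}
```

For streaming responses the source takes a different route for `"tail"`: it waits for the chunk whose `finish_reason` is `"stop"` and appends the references to that chunk's delta content.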
### Search Rewrite Configuration
```yaml
@@ -258,6 +272,25 @@ searchFrom:
This configuration is suitable for models that support web search options, such as OpenAI's gpt-4o-search-preview model. When the request contains a `web_search_options` field, even if it's an empty object (`"web_search_options": {}`), the plugin will be activated.
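The activation rule can be sketched in Go as follows. This is an illustration under stated assumptions, not the plugin's code: the plugin itself reads the body with `gjson`, while this sketch uses only the standard library, and `pluginEnabled` is a hypothetical name. Treating an unparsable body as "not enabled" is also an assumption of the sketch.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// pluginEnabled is a hypothetical helper: it reports whether the search
// logic should run for a request body, given the defaultEnable setting.
func pluginEnabled(defaultEnable bool, body []byte) bool {
	if defaultEnable {
		return true
	}
	// Decode only the top-level keys; the values stay as raw JSON.
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		// Assumption: a body that is not a JSON object never enables the plugin.
		return false
	}
	// Presence of the key is enough, even for an empty object.
	_, ok := fields["web_search_options"]
	return ok
}

func main() {
	fmt.Println(pluginEnabled(false, []byte(`{"model": "m", "web_search_options": {}}`)))
	fmt.Println(pluginEnabled(false, []byte(`{"model": "m"}`)))
}
```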
### Search Context Size Configuration
You can dynamically adjust the number of search queries by adding a `search_context_size` parameter in the `web_search_options` field of the request:
```json
{
"web_search_options": {
"search_context_size": "medium"
}
}
```
`search_context_size` supports three levels:
- `low`: Generates 1 search query (suitable for simple questions)
- `medium`: Generates 3 search queries (the same as the default `maxCount`)
- `high`: Generates 5 search queries (suitable for complex questions)
This setting overrides the `maxCount` value in the configuration, allowing clients to dynamically adjust search depth based on question complexity. It only takes effect when search rewrite is configured. An unrecognized value is logged as a warning and the configured `maxCount` is used instead.
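A small Go sketch of the mapping and of how the `{max_count}` placeholder in a rewrite prompt could be filled from it. The names `effectiveMaxCount` and `renderPrompt` are hypothetical, and the template string is made up for the example; only the level-to-count mapping and the placeholder name come from the plugin.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// effectiveMaxCount is a hypothetical helper that mirrors the documented
// mapping. Any other value (including an empty string) keeps the
// configured maxCount.
func effectiveMaxCount(configured int, searchContextSize string) int {
	switch searchContextSize {
	case "low":
		return 1
	case "medium":
		return 3
	case "high":
		return 5
	default:
		return configured
	}
}

// renderPrompt fills every {max_count} placeholder in a prompt template.
func renderPrompt(template string, maxCount int) string {
	return strings.ReplaceAll(template, "{max_count}", strconv.Itoa(maxCount))
}

func main() {
	// Made-up template; the real prompts live in the plugin's prompts/ directory.
	template := "Keep the total number of queries within {max_count}."
	for _, size := range []string{"", "low", "medium", "high", "huge"} {
		n := effectiveMaxCount(2, size)
		fmt.Printf("%q -> %s\n", size, renderPrompt(template, n))
	}
}
```

The sketch computes the value per request and returns it instead of storing it, so one request's `search_context_size` cannot affect later requests. That is a design choice of this example.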
## Notes
1. The prompt template must include `{search_results}` and `{question}` placeholders, optionally use `{cur_date}` to insert current date (format: January 2, 2006)


@@ -46,16 +46,19 @@ type SearchRewrite struct {
modelName string
timeoutMillisecond uint32
prompt string
promptTemplate string // Original prompt template before replacing placeholders
maxCount int
}
type Config struct {
engine []engine.SearchEngine
promptTemplate string
referenceFormat string
defaultLanguage string
needReference bool
searchRewrite *SearchRewrite
defaultEnable bool
engine []engine.SearchEngine
promptTemplate string
referenceFormat string
defaultLanguage string
needReference bool
referenceLocation string // "head" or "tail"
searchRewrite *SearchRewrite
defaultEnable bool
}
const (
@@ -74,6 +77,9 @@ var internetSearchPrompts string
//go:embed prompts/private.md
var privateSearchPrompts string
//go:embed prompts/chinese-internet.md
var chineseInternetSearchPrompts string
func main() {
wrapper.SetCtx(
"ai-search",
@@ -99,6 +105,13 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
} else if !strings.Contains(config.referenceFormat, "%s") {
return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
}
config.referenceLocation = json.Get("referenceLocation").String()
if config.referenceLocation == "" {
config.referenceLocation = "head" // Default to head if not specified
} else if config.referenceLocation != "head" && config.referenceLocation != "tail" {
return fmt.Errorf("invalid referenceLocation:%s, must be 'head' or 'tail'", config.referenceLocation)
}
}
config.defaultLanguage = json.Get("defaultLang").String()
config.promptTemplate = json.Get("promptTemplate").String()
@@ -144,6 +157,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
}
var internetExists, privateExists, arxivExists bool
var onlyQuark bool = true
for _, e := range json.Get("searchFrom").Array() {
switch e.Get("type").String() {
case "bing":
@@ -153,6 +167,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
internetExists = true
onlyQuark = false
case "google":
searchEngine, err := google.NewGoogleSearch(&e)
if err != nil {
@@ -160,6 +175,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
internetExists = true
onlyQuark = false
case "arxiv":
searchEngine, err := arxiv.NewArxivSearch(&e)
if err != nil {
@@ -167,6 +183,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
arxivExists = true
onlyQuark = false
case "elasticsearch":
searchEngine, err := elasticsearch.NewElasticsearchSearch(&e)
if err != nil {
@@ -174,6 +191,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
privateExists = true
onlyQuark = false
case "quark":
searchEngine, err := quark.NewQuarkSearch(&e)
if err != nil {
@@ -217,6 +235,12 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
llmTimeout = 30000
}
searchRewrite.timeoutMillisecond = uint32(llmTimeout)
maxCount := searchRewriteJson.Get("maxCount").Int()
if maxCount == 0 {
maxCount = 3 // Default value
}
searchRewrite.maxCount = int(maxCount)
// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
if arxivExists {
if privateExists {
@@ -231,8 +255,18 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
searchRewrite.prompt = privateSearchPrompts
} else if internetExists {
// only internet
searchRewrite.prompt = internetSearchPrompts
if onlyQuark {
// When only quark is used, use chinese-internet.md
searchRewrite.prompt = chineseInternetSearchPrompts
} else {
searchRewrite.prompt = internetSearchPrompts
}
}
// Store the original prompt template before replacing placeholders
searchRewrite.promptTemplate = searchRewrite.prompt
// Replace {max_count} placeholder in the prompt with the configured value
searchRewrite.prompt = strings.Replace(searchRewrite.prompt, "{max_count}", fmt.Sprintf("%d", searchRewrite.maxCount), -1)
config.searchRewrite = searchRewrite
}
if len(config.engine) == 0 {
@@ -260,9 +294,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Lo
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
// Check if plugin should be enabled based on config and request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !config.defaultEnable {
// When defaultEnable is false, we need to check if web_search_options exists in the request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !webSearchOptions.Exists() {
log.Debugf("Plugin disabled by config and no web_search_options in request")
return types.ActionContinue
@@ -286,6 +320,36 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
}
searchRewrite := config.searchRewrite
if searchRewrite != nil {
// Check if web_search_options.search_context_size exists and adjust maxCount accordingly
if webSearchOptions.Exists() {
searchContextSize := webSearchOptions.Get("search_context_size").String()
if searchContextSize != "" {
originalMaxCount := searchRewrite.maxCount
switch searchContextSize {
case "low":
searchRewrite.maxCount = 1
log.Debugf("Setting maxCount to 1 based on search_context_size=low")
case "medium":
searchRewrite.maxCount = 3
log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
case "high":
searchRewrite.maxCount = 5
log.Debugf("Setting maxCount to 5 based on search_context_size=high")
default:
log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
searchContextSize, searchRewrite.maxCount)
}
// If maxCount changed, regenerate the prompt from the template
if originalMaxCount != searchRewrite.maxCount && searchRewrite.promptTemplate != "" {
searchRewrite.prompt = strings.Replace(
searchRewrite.promptTemplate,
"{max_count}",
fmt.Sprintf("%d", searchRewrite.maxCount),
-1)
}
}
}
startTime := time.Now()
rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf(
@@ -509,16 +573,32 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
}
content := gjson.GetBytes(body, "choices.0.message.content").String()
var modifiedContent string
formattedReferences := fmt.Sprintf(config.referenceFormat, references)
if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
thinkEnd := strings.Index(content, "</think>")
if thinkEnd != -1 {
modifiedContent = content[:thinkEnd+8] +
fmt.Sprintf("\n%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content[thinkEnd+8:])
if config.referenceLocation == "tail" {
// Add references at the end
modifiedContent = content + fmt.Sprintf("\n\n%s", formattedReferences)
} else {
// Default: add references after </think> tag
modifiedContent = content[:thinkEnd+8] +
fmt.Sprintf("\n%s\n\n%s", formattedReferences, content[thinkEnd+8:])
}
}
}
if modifiedContent == "" {
modifiedContent = fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content)
if config.referenceLocation == "tail" {
// Add references at the end
modifiedContent = fmt.Sprintf("%s\n\n%s", content, formattedReferences)
} else {
// Default: add references at the beginning
modifiedContent = fmt.Sprintf("%s\n\n%s", formattedReferences, content)
}
}
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
if err != nil {
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
@@ -561,7 +641,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
var newMessages []string
for i, msg := range messages {
if i < len(messages)-1 {
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), log)
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
if newMsg != "" {
newMessages = append(newMessages, newMsg)
}
@@ -579,7 +659,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
}
}
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, log wrapper.Log) string {
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
log.Debugf("single sse message: %s", sseMessage)
subMessages := strings.Split(sseMessage, "\n")
var message string
@@ -600,6 +680,26 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
}
bodyJson = strings.TrimPrefix(bodyJson, " ")
bodyJson = strings.TrimSuffix(bodyJson, "\n")
// If tailReference is true, only check if this is the last message
if tailReference {
// Check if this is the last message in the stream (finish_reason is "stop")
finishReason := gjson.Get(bodyJson, "choices.0.finish_reason").String()
if finishReason == "stop" {
// This is the last message, append references at the end
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", deltaContent+fmt.Sprintf("\n\n%s", references))
if err != nil {
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
}
// Not the last message, return original message
return sseMessage
}
// Original head reference logic
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
// Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
if deltaContent == "" {
@@ -615,7 +715,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") {
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
if err != nil {
log.Errorf("update messsage failed:%s", err)
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
@@ -629,7 +729,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
if err != nil {
log.Errorf("update messsage failed:%s", err)
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
@@ -644,7 +744,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
// Return the content before the partial match
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
if err != nil {
log.Errorf("update messsage failed:%s", err)
log.Errorf("update message failed:%s", err)
}
return fmt.Sprintf("data: %s", modifiedMessage)
}


@@ -185,7 +185,7 @@ none
4.2.2. 根据问题所属领域将问题拆分成多组关键词的组合同时组合中的关键词个数尽量不要超过3个
5. Final: 按照下面**回复内容示例**进行回复,注意:
- 不要输出思考过程
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在5次以内
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
- 查询搜索引擎时,需要以"internet:"开头
- 查询Arxiv论文时需要以Arxiv的Category值开头例如"cs.AI:"
- 查询Arxiv论文时优先用英文表述关键词进行搜索


@@ -0,0 +1,39 @@
# 目标
你需要分析**用户发送的消息**,是否需要查询中文搜索引擎,并按照如下情况回复相应内容:
## 情况一:不需要查询搜索引擎
### 情况举例:
1. **用户发送的消息**不是在提问或寻求帮助
2. **用户发送的消息**是要求翻译文字
### 思考过程
根据上面的**情况举例**,如果符合,则按照下面**回复内容示例**进行回复,注意不要输出思考过程
### 回复内容示例:
none
## 情况二:需要查询搜索引擎
### 情况举例:
1. 答复**用户发送的消息**,需依赖互联网上最新的资料
2. 答复**用户发送的消息**,需依赖论文等专业资料
3. 通过查询资料,可以更好地答复**用户发送的消息**
### 思考过程
根据上面的**情况举例**,以及其他需要查询资料的情况,如果符合,按照以下步骤思考,并按照下面**回复内容示例**进行回复,注意不要输出思考过程:
1. What: 分析要答复**用户发送的消息**,需要了解什么知识和资料
2. How: 分析对于要查询的知识和资料,应该提出什么样的问题
3. Adjust: 明确查询什么问题后,用一句话概括问题,并且针对搜索引擎做问题优化
4. Final: 按照下面**回复内容示例**进行回复,注意:
- 不要输出思考过程
- 可以查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
- 需要以"internet:"开头
- 即使**用户发送的消息**使用了中文以外的其他语言,也用中文向搜索引擎查询问题,但注意不要翻译专有名词
### 回复内容示例:
#### 查询多次搜索引擎
internet: 黄金价格走势
internet: 历史黄金价格高点
# 用户发送的消息为:
{question}


@@ -186,7 +186,7 @@ none
4.3.2. 根据问题所属领域将问题拆分成多组关键词的组合同时组合中的关键词个数尽量不要超过3个
5. Final: 按照下面**回复内容示例**进行回复,注意:
- 不要输出思考过程
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在5次以内
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
- 查询搜索引擎时,需要以"internet:"开头
- 查询私有知识库时,需要以"private:"开头
- 查询Arxiv论文时需要以Arxiv的Category值开头例如"cs.AI:"


@@ -25,7 +25,7 @@ none
3. Adjust: 明确查询什么问题后,用一句话概括问题,并且针对搜索引擎做问题优化
4. Final: 按照下面**回复内容示例**进行回复,注意:
- 不要输出思考过程
- 可以查询多次,多个查询用换行分隔,总查询次数控制在5次以内
- 可以查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
- 需要以"internet:"开头
- 尽量满足**用户发送的消息**中的搜索要求,例如用户要求用英文搜索,则需用英文表述问题和关键词
- 用户如果没有要求搜索语言,则用和**用户发送的消息**一致的语言表述问题和关键词


@@ -28,7 +28,7 @@ none
4.2. 向私有知识库提问:用一句话概括问题,私有知识库不需要对关键词进行拆分
5. Final: 按照下面**回复内容示例**进行回复,注意:
- 不要输出思考过程
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在5次以内
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
- 查询搜索引擎时,需要以"internet:"开头
- 查询私有知识库时,需要以"private:"开头
- 当用多个关键词查询时,关键词之间用","分隔


@@ -7,7 +7,8 @@ def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(description='AI Search Test Script')
parser.add_argument('--question', required=True, help='The question to analyze')
parser.add_argument('--prompt', required=True, help='The prompt file to analyze')
parser.add_argument('--prompt', required=True, help='The prompt file to analyze')
parser.add_argument('--count', required=True, help='The max search count')
args = parser.parse_args()
# Read and parse the prompts.md template
@@ -17,6 +18,7 @@ def main():
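# Replace the {question} and {max_count} placeholders in the template.
# The second replace must start from `prompt`, not `prompt_template`,
# otherwise the {question} substitution is discarded.
prompt = prompt_template.replace('{question}', args.question)
prompt = prompt.replace('{max_count}', args.count)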
# Prepare the request data
headers = {