mirror of
https://github.com/alibaba/higress.git
synced 2026-06-10 05:07:30 +08:00
more optimize of ai search plugin (#1896)
This commit is contained in:
@@ -20,6 +20,7 @@ description: higress 支持通过集成搜索引擎(Google/Bing/Arxiv/Elastics
|
|||||||
| defaultEnable | bool | 选填 | true | 插件功能默认是否开启。设置为false时,仅当请求中包含web_search_options字段时才启用插件功能 |
|
| defaultEnable | bool | 选填 | true | 插件功能默认是否开启。设置为false时,仅当请求中包含web_search_options字段时才启用插件功能 |
|
||||||
| needReference | bool | 选填 | false | 是否在回答中添加引用来源 |
|
| needReference | bool | 选填 | false | 是否在回答中添加引用来源 |
|
||||||
| referenceFormat | string | 选填 | `"**References:**\n%s"` | 引用内容格式,必须包含%s占位符 |
|
| referenceFormat | string | 选填 | `"**References:**\n%s"` | 引用内容格式,必须包含%s占位符 |
|
||||||
|
| referenceLocation | string | 选填 | "head" | 引用位置:"head"在回答开头,"tail"在回答结尾 |
|
||||||
| defaultLang | string | 选填 | - | 默认搜索语言代码(如zh-CN/en-US) |
|
| defaultLang | string | 选填 | - | 默认搜索语言代码(如zh-CN/en-US) |
|
||||||
| promptTemplate | string | 选填 | 内置模板 | 提示模板,必须包含`{search_results}`和`{question}`占位符 |
|
| promptTemplate | string | 选填 | 内置模板 | 提示模板,必须包含`{search_results}`和`{question}`占位符 |
|
||||||
| searchFrom | array of object | 必填 | - | 参考下面搜索引擎配置,至少配置一个引擎 |
|
| searchFrom | array of object | 必填 | - | 参考下面搜索引擎配置,至少配置一个引擎 |
|
||||||
@@ -45,6 +46,7 @@ description: higress 支持通过集成搜索引擎(Google/Bing/Arxiv/Elastics
|
|||||||
| llmUrl | string | 必填 | - | LLM服务API地址 |
|
| llmUrl | string | 必填 | - | LLM服务API地址 |
|
||||||
| llmModelName | string | 必填 | - | LLM模型名称 |
|
| llmModelName | string | 必填 | - | LLM模型名称 |
|
||||||
| timeoutMillisecond | number | 选填 | 30000 | API调用超时时间(毫秒) |
|
| timeoutMillisecond | number | 选填 | 30000 | API调用超时时间(毫秒) |
|
||||||
|
| maxCount | number | 选填 | 3 | 搜索重写生成的最大查询次数 |
|
||||||
|
|
||||||
## 搜索引擎通用配置
|
## 搜索引擎通用配置
|
||||||
|
|
||||||
@@ -225,6 +227,18 @@ searchFrom:
|
|||||||
servicePort: 8080
|
servicePort: 8080
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 自定义引用位置
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
needReference: true
|
||||||
|
referenceLocation: "tail" # 在回答结尾添加引用,而不是开头
|
||||||
|
searchFrom:
|
||||||
|
- type: bing
|
||||||
|
apiKey: "your-bing-key"
|
||||||
|
serviceName: "search-service.dns"
|
||||||
|
servicePort: 8080
|
||||||
|
```
|
||||||
|
|
||||||
### 搜索重写配置
|
### 搜索重写配置
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -259,6 +273,25 @@ searchFrom:
|
|||||||
|
|
||||||
这种配置适用于支持web_search选项的模型,例如OpenAI的gpt-4o-search-preview模型。当请求中包含`web_search_options`字段时,即使是空对象(`"web_search_options": {}`),插件也会被激活。
|
这种配置适用于支持web_search选项的模型,例如OpenAI的gpt-4o-search-preview模型。当请求中包含`web_search_options`字段时,即使是空对象(`"web_search_options": {}`),插件也会被激活。
|
||||||
|
|
||||||
|
### 搜索上下文大小配置
|
||||||
|
|
||||||
|
通过在请求中的`web_search_options`字段中添加`search_context_size`参数,可以动态调整搜索重写生成的最大查询次数:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"web_search_options": {
|
||||||
|
"search_context_size": "medium"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`search_context_size`支持三个级别:
|
||||||
|
- `low`: 最多生成1个搜索查询(适合简单问题)
|
||||||
|
- `medium`: 最多生成3个搜索查询(与`maxCount`的默认值相同)
|
||||||
|
- `high`: 最多生成5个搜索查询(适合复杂问题)
|
||||||
|
|
||||||
|
这个设置会覆盖配置中的`maxCount`值,允许客户端根据问题复杂度动态调整搜索深度。该设置仅在启用搜索重写时生效;如果取值不是`low`、`medium`或`high`,则忽略该参数并继续使用配置中的`maxCount`值。
|
||||||
|
|
||||||
## 注意事项
|
## 注意事项
|
||||||
|
|
||||||
1. 提示词模版必须包含`{search_results}`和`{question}`占位符,可选使用`{cur_date}`插入当前日期(格式:2006年1月2日)
|
1. 提示词模版必须包含`{search_results}`和`{question}`占位符,可选使用`{cur_date}`插入当前日期(格式:2006年1月2日)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ Plugin execution priority: `440`
|
|||||||
| defaultEnable | bool | Optional | true | Whether the plugin functionality is enabled by default. When set to false, the plugin will only be activated when the request contains a web_search_options field |
|
| defaultEnable | bool | Optional | true | Whether the plugin functionality is enabled by default. When set to false, the plugin will only be activated when the request contains a web_search_options field |
|
||||||
| needReference | bool | Optional | false | Whether to add reference sources in the response |
|
| needReference | bool | Optional | false | Whether to add reference sources in the response |
|
||||||
| referenceFormat | string | Optional | `"**References:**\n%s"` | Reference content format, must include %s placeholder |
|
| referenceFormat | string | Optional | `"**References:**\n%s"` | Reference content format, must include %s placeholder |
|
||||||
|
| referenceLocation | string | Optional | "head" | Reference position: "head" at the beginning of the response, "tail" at the end of the response |
|
||||||
| defaultLang | string | Optional | - | Default search language code (e.g. zh-CN/en-US) |
|
| defaultLang | string | Optional | - | Default search language code (e.g. zh-CN/en-US) |
|
||||||
| promptTemplate | string | Optional | Built-in template | Prompt template, must include `{search_results}` and `{question}` placeholders |
|
| promptTemplate | string | Optional | Built-in template | Prompt template, must include `{search_results}` and `{question}` placeholders |
|
||||||
| searchFrom | array of object | Required | - | Refer to search engine configuration below, at least one engine must be configured |
|
| searchFrom | array of object | Required | - | Refer to search engine configuration below, at least one engine must be configured |
|
||||||
@@ -45,6 +46,7 @@ It is strongly recommended to enable this feature when using Arxiv or Elasticsea
|
|||||||
| llmUrl | string | Required | - | LLM service API URL |
|
| llmUrl | string | Required | - | LLM service API URL |
|
||||||
| llmModelName | string | Required | - | LLM model name |
|
| llmModelName | string | Required | - | LLM model name |
|
||||||
| timeoutMillisecond | number | Optional | 30000 | API call timeout (milliseconds) |
|
| timeoutMillisecond | number | Optional | 30000 | API call timeout (milliseconds) |
|
||||||
|
| maxCount | number | Optional | 3 | Maximum number of search queries generated by the search rewrite |
|
||||||
|
|
||||||
## Search Engine Common Configuration
|
## Search Engine Common Configuration
|
||||||
|
|
||||||
@@ -224,6 +226,18 @@ searchFrom:
|
|||||||
servicePort: 8080
|
servicePort: 8080
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Custom Reference Location
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
needReference: true
|
||||||
|
referenceLocation: "tail" # Add references at the end of the response instead of the beginning
|
||||||
|
searchFrom:
|
||||||
|
- type: bing
|
||||||
|
apiKey: "your-bing-key"
|
||||||
|
serviceName: "search-service.dns"
|
||||||
|
servicePort: 8080
|
||||||
|
```
|
||||||
|
|
||||||
### Search Rewrite Configuration
|
### Search Rewrite Configuration
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -258,6 +272,25 @@ searchFrom:
|
|||||||
|
|
||||||
This configuration is suitable for models that support web search options, such as OpenAI's gpt-4o-search-preview model. When the request contains a `web_search_options` field, even if it's an empty object (`"web_search_options": {}`), the plugin will be activated.
|
This configuration is suitable for models that support web search options, such as OpenAI's gpt-4o-search-preview model. When the request contains a `web_search_options` field, even if it's an empty object (`"web_search_options": {}`), the plugin will be activated.
|
||||||
|
|
||||||
|
### Search Context Size Configuration
|
||||||
|
|
||||||
|
You can dynamically adjust the maximum number of search queries generated by the search rewrite by adding a `search_context_size` parameter in the `web_search_options` field of the request:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"web_search_options": {
|
||||||
|
"search_context_size": "medium"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `search_context_size` supports three levels:
|
||||||
|
- `low`: Generates at most 1 search query (suitable for simple questions)
|
||||||
|
- `medium`: Generates at most 3 search queries (same as the default `maxCount`)
|
||||||
|
- `high`: Generates at most 5 search queries (suitable for complex questions)
|
||||||
|
|
||||||
|
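This setting overrides the `maxCount` value in the configuration, allowing clients to dynamically adjust search depth based on question complexity. It only takes effect when search rewrite is enabled; if the value is not `low`, `medium`, or `high`, the parameter is ignored and the configured `maxCount` is used.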
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
1. The prompt template must include `{search_results}` and `{question}` placeholders, optionally use `{cur_date}` to insert current date (format: January 2, 2006)
|
1. The prompt template must include `{search_results}` and `{question}` placeholders, optionally use `{cur_date}` to insert current date (format: January 2, 2006)
|
||||||
|
|||||||
@@ -46,16 +46,19 @@ type SearchRewrite struct {
|
|||||||
modelName string
|
modelName string
|
||||||
timeoutMillisecond uint32
|
timeoutMillisecond uint32
|
||||||
prompt string
|
prompt string
|
||||||
|
promptTemplate string // Original prompt template before replacing placeholders
|
||||||
|
maxCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
engine []engine.SearchEngine
|
engine []engine.SearchEngine
|
||||||
promptTemplate string
|
promptTemplate string
|
||||||
referenceFormat string
|
referenceFormat string
|
||||||
defaultLanguage string
|
defaultLanguage string
|
||||||
needReference bool
|
needReference bool
|
||||||
searchRewrite *SearchRewrite
|
referenceLocation string // "head" or "tail"
|
||||||
defaultEnable bool
|
searchRewrite *SearchRewrite
|
||||||
|
defaultEnable bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -74,6 +77,9 @@ var internetSearchPrompts string
|
|||||||
//go:embed prompts/private.md
|
//go:embed prompts/private.md
|
||||||
var privateSearchPrompts string
|
var privateSearchPrompts string
|
||||||
|
|
||||||
|
//go:embed prompts/chinese-internet.md
|
||||||
|
var chineseInternetSearchPrompts string
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
wrapper.SetCtx(
|
wrapper.SetCtx(
|
||||||
"ai-search",
|
"ai-search",
|
||||||
@@ -99,6 +105,13 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
} else if !strings.Contains(config.referenceFormat, "%s") {
|
} else if !strings.Contains(config.referenceFormat, "%s") {
|
||||||
return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
|
return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.referenceLocation = json.Get("referenceLocation").String()
|
||||||
|
if config.referenceLocation == "" {
|
||||||
|
config.referenceLocation = "head" // Default to head if not specified
|
||||||
|
} else if config.referenceLocation != "head" && config.referenceLocation != "tail" {
|
||||||
|
return fmt.Errorf("invalid referenceLocation:%s, must be 'head' or 'tail'", config.referenceLocation)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
config.defaultLanguage = json.Get("defaultLang").String()
|
config.defaultLanguage = json.Get("defaultLang").String()
|
||||||
config.promptTemplate = json.Get("promptTemplate").String()
|
config.promptTemplate = json.Get("promptTemplate").String()
|
||||||
@@ -144,6 +157,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
|
return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
|
||||||
}
|
}
|
||||||
var internetExists, privateExists, arxivExists bool
|
var internetExists, privateExists, arxivExists bool
|
||||||
|
var onlyQuark bool = true
|
||||||
for _, e := range json.Get("searchFrom").Array() {
|
for _, e := range json.Get("searchFrom").Array() {
|
||||||
switch e.Get("type").String() {
|
switch e.Get("type").String() {
|
||||||
case "bing":
|
case "bing":
|
||||||
@@ -153,6 +167,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
}
|
}
|
||||||
config.engine = append(config.engine, searchEngine)
|
config.engine = append(config.engine, searchEngine)
|
||||||
internetExists = true
|
internetExists = true
|
||||||
|
onlyQuark = false
|
||||||
case "google":
|
case "google":
|
||||||
searchEngine, err := google.NewGoogleSearch(&e)
|
searchEngine, err := google.NewGoogleSearch(&e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -160,6 +175,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
}
|
}
|
||||||
config.engine = append(config.engine, searchEngine)
|
config.engine = append(config.engine, searchEngine)
|
||||||
internetExists = true
|
internetExists = true
|
||||||
|
onlyQuark = false
|
||||||
case "arxiv":
|
case "arxiv":
|
||||||
searchEngine, err := arxiv.NewArxivSearch(&e)
|
searchEngine, err := arxiv.NewArxivSearch(&e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -167,6 +183,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
}
|
}
|
||||||
config.engine = append(config.engine, searchEngine)
|
config.engine = append(config.engine, searchEngine)
|
||||||
arxivExists = true
|
arxivExists = true
|
||||||
|
onlyQuark = false
|
||||||
case "elasticsearch":
|
case "elasticsearch":
|
||||||
searchEngine, err := elasticsearch.NewElasticsearchSearch(&e)
|
searchEngine, err := elasticsearch.NewElasticsearchSearch(&e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -174,6 +191,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
}
|
}
|
||||||
config.engine = append(config.engine, searchEngine)
|
config.engine = append(config.engine, searchEngine)
|
||||||
privateExists = true
|
privateExists = true
|
||||||
|
onlyQuark = false
|
||||||
case "quark":
|
case "quark":
|
||||||
searchEngine, err := quark.NewQuarkSearch(&e)
|
searchEngine, err := quark.NewQuarkSearch(&e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -217,6 +235,12 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
llmTimeout = 30000
|
llmTimeout = 30000
|
||||||
}
|
}
|
||||||
searchRewrite.timeoutMillisecond = uint32(llmTimeout)
|
searchRewrite.timeoutMillisecond = uint32(llmTimeout)
|
||||||
|
|
||||||
|
maxCount := searchRewriteJson.Get("maxCount").Int()
|
||||||
|
if maxCount == 0 {
|
||||||
|
maxCount = 3 // Default value
|
||||||
|
}
|
||||||
|
searchRewrite.maxCount = int(maxCount)
|
||||||
// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
|
// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
|
||||||
if arxivExists {
|
if arxivExists {
|
||||||
if privateExists {
|
if privateExists {
|
||||||
@@ -231,8 +255,18 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
|||||||
searchRewrite.prompt = privateSearchPrompts
|
searchRewrite.prompt = privateSearchPrompts
|
||||||
} else if internetExists {
|
} else if internetExists {
|
||||||
// only internet
|
// only internet
|
||||||
searchRewrite.prompt = internetSearchPrompts
|
if onlyQuark {
|
||||||
|
// When only quark is used, use chinese-internet.md
|
||||||
|
searchRewrite.prompt = chineseInternetSearchPrompts
|
||||||
|
} else {
|
||||||
|
searchRewrite.prompt = internetSearchPrompts
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store the original prompt template before replacing placeholders
|
||||||
|
searchRewrite.promptTemplate = searchRewrite.prompt
|
||||||
|
// Replace {max_count} placeholder in the prompt with the configured value
|
||||||
|
searchRewrite.prompt = strings.Replace(searchRewrite.prompt, "{max_count}", fmt.Sprintf("%d", searchRewrite.maxCount), -1)
|
||||||
config.searchRewrite = searchRewrite
|
config.searchRewrite = searchRewrite
|
||||||
}
|
}
|
||||||
if len(config.engine) == 0 {
|
if len(config.engine) == 0 {
|
||||||
@@ -260,9 +294,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Lo
|
|||||||
|
|
||||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
|
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
|
||||||
// Check if plugin should be enabled based on config and request
|
// Check if plugin should be enabled based on config and request
|
||||||
|
webSearchOptions := gjson.GetBytes(body, "web_search_options")
|
||||||
if !config.defaultEnable {
|
if !config.defaultEnable {
|
||||||
// When defaultEnable is false, we need to check if web_search_options exists in the request
|
// When defaultEnable is false, we need to check if web_search_options exists in the request
|
||||||
webSearchOptions := gjson.GetBytes(body, "web_search_options")
|
|
||||||
if !webSearchOptions.Exists() {
|
if !webSearchOptions.Exists() {
|
||||||
log.Debugf("Plugin disabled by config and no web_search_options in request")
|
log.Debugf("Plugin disabled by config and no web_search_options in request")
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
@@ -286,6 +320,36 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
|||||||
}
|
}
|
||||||
searchRewrite := config.searchRewrite
|
searchRewrite := config.searchRewrite
|
||||||
if searchRewrite != nil {
|
if searchRewrite != nil {
|
||||||
|
// Check if web_search_options.search_context_size exists and adjust maxCount accordingly
|
||||||
|
if webSearchOptions.Exists() {
|
||||||
|
searchContextSize := webSearchOptions.Get("search_context_size").String()
|
||||||
|
if searchContextSize != "" {
|
||||||
|
originalMaxCount := searchRewrite.maxCount
|
||||||
|
switch searchContextSize {
|
||||||
|
case "low":
|
||||||
|
searchRewrite.maxCount = 1
|
||||||
|
log.Debugf("Setting maxCount to 1 based on search_context_size=low")
|
||||||
|
case "medium":
|
||||||
|
searchRewrite.maxCount = 3
|
||||||
|
log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
|
||||||
|
case "high":
|
||||||
|
searchRewrite.maxCount = 5
|
||||||
|
log.Debugf("Setting maxCount to 5 based on search_context_size=high")
|
||||||
|
default:
|
||||||
|
log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
|
||||||
|
searchContextSize, searchRewrite.maxCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If maxCount changed, regenerate the prompt from the template
|
||||||
|
if originalMaxCount != searchRewrite.maxCount && searchRewrite.promptTemplate != "" {
|
||||||
|
searchRewrite.prompt = strings.Replace(
|
||||||
|
searchRewrite.promptTemplate,
|
||||||
|
"{max_count}",
|
||||||
|
fmt.Sprintf("%d", searchRewrite.maxCount),
|
||||||
|
-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
|
rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
|
||||||
rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf(
|
rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf(
|
||||||
@@ -509,16 +573,32 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
|||||||
}
|
}
|
||||||
content := gjson.GetBytes(body, "choices.0.message.content").String()
|
content := gjson.GetBytes(body, "choices.0.message.content").String()
|
||||||
var modifiedContent string
|
var modifiedContent string
|
||||||
|
formattedReferences := fmt.Sprintf(config.referenceFormat, references)
|
||||||
|
|
||||||
if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
|
if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
|
||||||
thinkEnd := strings.Index(content, "</think>")
|
thinkEnd := strings.Index(content, "</think>")
|
||||||
if thinkEnd != -1 {
|
if thinkEnd != -1 {
|
||||||
modifiedContent = content[:thinkEnd+8] +
|
if config.referenceLocation == "tail" {
|
||||||
fmt.Sprintf("\n%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content[thinkEnd+8:])
|
// Add references at the end
|
||||||
|
modifiedContent = content + fmt.Sprintf("\n\n%s", formattedReferences)
|
||||||
|
} else {
|
||||||
|
// Default: add references after </think> tag
|
||||||
|
modifiedContent = content[:thinkEnd+8] +
|
||||||
|
fmt.Sprintf("\n%s\n\n%s", formattedReferences, content[thinkEnd+8:])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if modifiedContent == "" {
|
if modifiedContent == "" {
|
||||||
modifiedContent = fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content)
|
if config.referenceLocation == "tail" {
|
||||||
|
// Add references at the end
|
||||||
|
modifiedContent = fmt.Sprintf("%s\n\n%s", content, formattedReferences)
|
||||||
|
} else {
|
||||||
|
// Default: add references at the beginning
|
||||||
|
modifiedContent = fmt.Sprintf("%s\n\n%s", formattedReferences, content)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
|
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
|
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
|
||||||
@@ -561,7 +641,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
|||||||
var newMessages []string
|
var newMessages []string
|
||||||
for i, msg := range messages {
|
for i, msg := range messages {
|
||||||
if i < len(messages)-1 {
|
if i < len(messages)-1 {
|
||||||
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), log)
|
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
|
||||||
if newMsg != "" {
|
if newMsg != "" {
|
||||||
newMessages = append(newMessages, newMsg)
|
newMessages = append(newMessages, newMsg)
|
||||||
}
|
}
|
||||||
@@ -579,7 +659,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, log wrapper.Log) string {
|
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
|
||||||
log.Debugf("single sse message: %s", sseMessage)
|
log.Debugf("single sse message: %s", sseMessage)
|
||||||
subMessages := strings.Split(sseMessage, "\n")
|
subMessages := strings.Split(sseMessage, "\n")
|
||||||
var message string
|
var message string
|
||||||
@@ -600,6 +680,26 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
|||||||
}
|
}
|
||||||
bodyJson = strings.TrimPrefix(bodyJson, " ")
|
bodyJson = strings.TrimPrefix(bodyJson, " ")
|
||||||
bodyJson = strings.TrimSuffix(bodyJson, "\n")
|
bodyJson = strings.TrimSuffix(bodyJson, "\n")
|
||||||
|
|
||||||
|
// If tailReference is true, only check if this is the last message
|
||||||
|
if tailReference {
|
||||||
|
// Check if this is the last message in the stream (finish_reason is "stop")
|
||||||
|
finishReason := gjson.Get(bodyJson, "choices.0.finish_reason").String()
|
||||||
|
if finishReason == "stop" {
|
||||||
|
// This is the last message, append references at the end
|
||||||
|
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
|
||||||
|
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", deltaContent+fmt.Sprintf("\n\n%s", references))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("update message failed:%s", err)
|
||||||
|
}
|
||||||
|
ctx.SetContext("ReferenceAppended", true)
|
||||||
|
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||||
|
}
|
||||||
|
// Not the last message, return original message
|
||||||
|
return sseMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original head reference logic
|
||||||
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
|
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
|
||||||
// Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
|
// Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
|
||||||
if deltaContent == "" {
|
if deltaContent == "" {
|
||||||
@@ -615,7 +715,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
|||||||
if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") {
|
if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") {
|
||||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
|
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("update messsage failed:%s", err)
|
log.Errorf("update message failed:%s", err)
|
||||||
}
|
}
|
||||||
ctx.SetContext("ReferenceAppended", true)
|
ctx.SetContext("ReferenceAppended", true)
|
||||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||||
@@ -629,7 +729,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
|||||||
fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
|
fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
|
||||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
|
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("update messsage failed:%s", err)
|
log.Errorf("update message failed:%s", err)
|
||||||
}
|
}
|
||||||
ctx.SetContext("ReferenceAppended", true)
|
ctx.SetContext("ReferenceAppended", true)
|
||||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||||
@@ -644,7 +744,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
|||||||
// Return the content before the partial match
|
// Return the content before the partial match
|
||||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
|
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("update messsage failed:%s", err)
|
log.Errorf("update message failed:%s", err)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ none
|
|||||||
4.2.2. 根据问题所属领域,将问题拆分成多组关键词的组合,同时组合中的关键词个数尽量不要超过3个
|
4.2.2. 根据问题所属领域,将问题拆分成多组关键词的组合,同时组合中的关键词个数尽量不要超过3个
|
||||||
5. Final: 按照下面**回复内容示例**进行回复,注意:
|
5. Final: 按照下面**回复内容示例**进行回复,注意:
|
||||||
- 不要输出思考过程
|
- 不要输出思考过程
|
||||||
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在5次以内
|
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
|
||||||
- 查询搜索引擎时,需要以"internet:"开头
|
- 查询搜索引擎时,需要以"internet:"开头
|
||||||
- 查询Arxiv论文时,需要以Arxiv的Category值开头,例如"cs.AI:"
|
- 查询Arxiv论文时,需要以Arxiv的Category值开头,例如"cs.AI:"
|
||||||
- 查询Arxiv论文时,优先用英文表述关键词进行搜索
|
- 查询Arxiv论文时,优先用英文表述关键词进行搜索
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
# 目标
|
||||||
|
你需要分析**用户发送的消息**,是否需要查询中文搜索引擎,并按照如下情况回复相应内容:
|
||||||
|
|
||||||
|
## 情况一:不需要查询搜索引擎
|
||||||
|
### 情况举例:
|
||||||
|
1. **用户发送的消息**不是在提问或寻求帮助
|
||||||
|
2. **用户发送的消息**是要求翻译文字
|
||||||
|
|
||||||
|
### 思考过程
|
||||||
|
根据上面的**情况举例**,如果符合,则按照下面**回复内容示例**进行回复,注意不要输出思考过程
|
||||||
|
|
||||||
|
### 回复内容示例:
|
||||||
|
none
|
||||||
|
|
||||||
|
## 情况二:需要查询搜索引擎
|
||||||
|
### 情况举例:
|
||||||
|
1. 答复**用户发送的消息**,需依赖互联网上最新的资料
|
||||||
|
2. 答复**用户发送的消息**,需依赖论文等专业资料
|
||||||
|
3. 通过查询资料,可以更好地答复**用户发送的消息**
|
||||||
|
|
||||||
|
### 思考过程
|
||||||
|
根据上面的**情况举例**,以及其他需要查询资料的情况,如果符合,按照以下步骤思考,并按照下面**回复内容示例**进行回复,注意不要输出思考过程:
|
||||||
|
1. What: 分析要答复**用户发送的消息**,需要了解什么知识和资料
|
||||||
|
2. How: 分析对于要查询的知识和资料,应该提出什么样的问题
|
||||||
|
3. Adjust: 明确查询什么问题后,用一句话概括问题,并且针对搜索引擎做问题优化
|
||||||
|
4. Final: 按照下面**回复内容示例**进行回复,注意:
|
||||||
|
- 不要输出思考过程
|
||||||
|
- 可以查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
|
||||||
|
- 需要以"internet:"开头
|
||||||
|
- 即使**用户发送的消息**使用了中文以外的其他语言,也用中文向搜索引擎查询问题,但注意不要翻译专有名词
|
||||||
|
|
||||||
|
### 回复内容示例:
|
||||||
|
|
||||||
|
#### 查询多次搜索引擎
|
||||||
|
internet: 黄金价格走势
|
||||||
|
internet: 历史黄金价格高点
|
||||||
|
|
||||||
|
# 用户发送的消息为:
|
||||||
|
{question}
|
||||||
@@ -186,7 +186,7 @@ none
|
|||||||
4.3.2. 根据问题所属领域,将问题拆分成多组关键词的组合,同时组合中的关键词个数尽量不要超过3个
|
4.3.2. 根据问题所属领域,将问题拆分成多组关键词的组合,同时组合中的关键词个数尽量不要超过3个
|
||||||
5. Final: 按照下面**回复内容示例**进行回复,注意:
|
5. Final: 按照下面**回复内容示例**进行回复,注意:
|
||||||
- 不要输出思考过程
|
- 不要输出思考过程
|
||||||
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在5次以内
|
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
|
||||||
- 查询搜索引擎时,需要以"internet:"开头
|
- 查询搜索引擎时,需要以"internet:"开头
|
||||||
- 查询私有知识库时,需要以"private:"开头
|
- 查询私有知识库时,需要以"private:"开头
|
||||||
- 查询Arxiv论文时,需要以Arxiv的Category值开头,例如"cs.AI:"
|
- 查询Arxiv论文时,需要以Arxiv的Category值开头,例如"cs.AI:"
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ none
|
|||||||
3. Adjust: 明确查询什么问题后,用一句话概括问题,并且针对搜索引擎做问题优化
|
3. Adjust: 明确查询什么问题后,用一句话概括问题,并且针对搜索引擎做问题优化
|
||||||
4. Final: 按照下面**回复内容示例**进行回复,注意:
|
4. Final: 按照下面**回复内容示例**进行回复,注意:
|
||||||
- 不要输出思考过程
|
- 不要输出思考过程
|
||||||
- 可以查询多次,多个查询用换行分隔,总查询次数控制在5次以内
|
- 可以查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
|
||||||
- 需要以"internet:"开头
|
- 需要以"internet:"开头
|
||||||
- 尽量满足**用户发送的消息**中的搜索要求,例如用户要求用英文搜索,则需用英文表述问题和关键词
|
- 尽量满足**用户发送的消息**中的搜索要求,例如用户要求用英文搜索,则需用英文表述问题和关键词
|
||||||
- 用户如果没有要求搜索语言,则用和**用户发送的消息**一致的语言表述问题和关键词
|
- 用户如果没有要求搜索语言,则用和**用户发送的消息**一致的语言表述问题和关键词
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ none
|
|||||||
4.2. 向私有知识库提问:用一句话概括问题,私有知识库不需要对关键词进行拆分
|
4.2. 向私有知识库提问:用一句话概括问题,私有知识库不需要对关键词进行拆分
|
||||||
5. Final: 按照下面**回复内容示例**进行回复,注意:
|
5. Final: 按照下面**回复内容示例**进行回复,注意:
|
||||||
- 不要输出思考过程
|
- 不要输出思考过程
|
||||||
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在5次以内
|
- 可以向多个查询目标分别查询多次,多个查询用换行分隔,总查询次数控制在{max_count}次以内
|
||||||
- 查询搜索引擎时,需要以"internet:"开头
|
- 查询搜索引擎时,需要以"internet:"开头
|
||||||
- 查询私有知识库时,需要以"private:"开头
|
- 查询私有知识库时,需要以"private:"开头
|
||||||
- 当用多个关键词查询时,关键词之间用","分隔
|
- 当用多个关键词查询时,关键词之间用","分隔
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ def main():
|
|||||||
# 解析命令行参数
|
# 解析命令行参数
|
||||||
parser = argparse.ArgumentParser(description='AI Search Test Script')
|
parser = argparse.ArgumentParser(description='AI Search Test Script')
|
||||||
parser.add_argument('--question', required=True, help='The question to analyze')
|
parser.add_argument('--question', required=True, help='The question to analyze')
|
||||||
parser.add_argument('--prompt', required=True, help='The prompt file to analyze')
|
parser.add_argument('--prompt', required=True, help='The prompt file to analyze')
|
||||||
|
parser.add_argument('--count', required=True, help='The max search count')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# 读取并解析prompts.md模板
|
# 读取并解析prompts.md模板
|
||||||
@@ -17,6 +18,7 @@ def main():
|
|||||||
|
|
||||||
# 替换模板中的{question}变量
|
# 替换模板中的{question}变量
|
||||||
prompt = prompt_template.replace('{question}', args.question)
|
prompt = prompt_template.replace('{question}', args.question)
|
||||||
|
prompt = prompt_template.replace('{max_count}', args.count)
|
||||||
|
|
||||||
# 准备请求数据
|
# 准备请求数据
|
||||||
headers = {
|
headers = {
|
||||||
|
|||||||
Reference in New Issue
Block a user