More optimizations for the ai-search plugin (#1896)

This commit is contained in:
澄潭
2025-03-14 23:24:22 +08:00
committed by GitHub
parent f09e029a6b
commit 34b3fc3114
9 changed files with 229 additions and 22 deletions

View File

@@ -46,16 +46,19 @@ type SearchRewrite struct {
modelName string
timeoutMillisecond uint32
prompt string
promptTemplate string // Original prompt template before replacing placeholders
maxCount int
}
type Config struct {
engine []engine.SearchEngine
promptTemplate string
referenceFormat string
defaultLanguage string
needReference bool
searchRewrite *SearchRewrite
defaultEnable bool
engine []engine.SearchEngine
promptTemplate string
referenceFormat string
defaultLanguage string
needReference bool
referenceLocation string // "head" or "tail"
searchRewrite *SearchRewrite
defaultEnable bool
}
const (
@@ -74,6 +77,9 @@ var internetSearchPrompts string
//go:embed prompts/private.md
var privateSearchPrompts string
//go:embed prompts/chinese-internet.md
var chineseInternetSearchPrompts string
func main() {
wrapper.SetCtx(
"ai-search",
@@ -99,6 +105,13 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
} else if !strings.Contains(config.referenceFormat, "%s") {
return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
}
config.referenceLocation = json.Get("referenceLocation").String()
if config.referenceLocation == "" {
config.referenceLocation = "head" // Default to head if not specified
} else if config.referenceLocation != "head" && config.referenceLocation != "tail" {
return fmt.Errorf("invalid referenceLocation:%s, must be 'head' or 'tail'", config.referenceLocation)
}
}
config.defaultLanguage = json.Get("defaultLang").String()
config.promptTemplate = json.Get("promptTemplate").String()
@@ -144,6 +157,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
}
var internetExists, privateExists, arxivExists bool
var onlyQuark bool = true
for _, e := range json.Get("searchFrom").Array() {
switch e.Get("type").String() {
case "bing":
@@ -153,6 +167,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
internetExists = true
onlyQuark = false
case "google":
searchEngine, err := google.NewGoogleSearch(&e)
if err != nil {
@@ -160,6 +175,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
internetExists = true
onlyQuark = false
case "arxiv":
searchEngine, err := arxiv.NewArxivSearch(&e)
if err != nil {
@@ -167,6 +183,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
arxivExists = true
onlyQuark = false
case "elasticsearch":
searchEngine, err := elasticsearch.NewElasticsearchSearch(&e)
if err != nil {
@@ -174,6 +191,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
}
config.engine = append(config.engine, searchEngine)
privateExists = true
onlyQuark = false
case "quark":
searchEngine, err := quark.NewQuarkSearch(&e)
if err != nil {
@@ -217,6 +235,12 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
llmTimeout = 30000
}
searchRewrite.timeoutMillisecond = uint32(llmTimeout)
// maxCount is the value substituted for the {max_count} placeholder in the
// search-rewrite prompt. A missing, zero, or negative setting falls back to
// the default so an invalid config value is never rendered into the prompt.
maxCount := searchRewriteJson.Get("maxCount").Int()
if maxCount <= 0 {
	maxCount = 3 // Default value
}
searchRewrite.maxCount = int(maxCount)
// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
if arxivExists {
if privateExists {
@@ -231,8 +255,18 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
searchRewrite.prompt = privateSearchPrompts
} else if internetExists {
// only internet
searchRewrite.prompt = internetSearchPrompts
if onlyQuark {
// When only quark is used, use chinese-internet.md
searchRewrite.prompt = chineseInternetSearchPrompts
} else {
searchRewrite.prompt = internetSearchPrompts
}
}
// Store the original prompt template before replacing placeholders
searchRewrite.promptTemplate = searchRewrite.prompt
// Replace {max_count} placeholder in the prompt with the configured value
searchRewrite.prompt = strings.Replace(searchRewrite.prompt, "{max_count}", fmt.Sprintf("%d", searchRewrite.maxCount), -1)
config.searchRewrite = searchRewrite
}
if len(config.engine) == 0 {
@@ -260,9 +294,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Lo
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
// Check if plugin should be enabled based on config and request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !config.defaultEnable {
// When defaultEnable is false, we need to check if web_search_options exists in the request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !webSearchOptions.Exists() {
log.Debugf("Plugin disabled by config and no web_search_options in request")
return types.ActionContinue
@@ -286,6 +320,36 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
}
searchRewrite := config.searchRewrite
if searchRewrite != nil {
// Check if web_search_options.search_context_size exists and adjust maxCount accordingly.
//
// searchRewrite is the pointer held by the plugin config, so it is shared by
// every request handled with that config. The override is therefore applied to
// a per-request copy; writing through the shared pointer would make one
// request's search_context_size stick for all later requests.
if webSearchOptions.Exists() {
	searchContextSize := webSearchOptions.Get("search_context_size").String()
	if searchContextSize != "" {
		requestMaxCount := searchRewrite.maxCount
		switch searchContextSize {
		case "low":
			requestMaxCount = 1
			log.Debugf("Setting maxCount to 1 based on search_context_size=low")
		case "medium":
			requestMaxCount = 3
			log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
		case "high":
			requestMaxCount = 5
			log.Debugf("Setting maxCount to 5 based on search_context_size=high")
		default:
			log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
				searchContextSize, searchRewrite.maxCount)
		}
		if requestMaxCount != searchRewrite.maxCount {
			// Shallow copy: only maxCount and prompt differ from the configured values.
			perRequest := *searchRewrite
			perRequest.maxCount = requestMaxCount
			// Regenerate the prompt from the template, which still contains the placeholder.
			if perRequest.promptTemplate != "" {
				perRequest.prompt = strings.Replace(
					perRequest.promptTemplate,
					"{max_count}",
					fmt.Sprintf("%d", requestMaxCount),
					-1)
			}
			// Rebind the local variable only; config.searchRewrite is left untouched.
			searchRewrite = &perRequest
		}
	}
}
startTime := time.Now()
rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf(
@@ -509,16 +573,32 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
}
content := gjson.GetBytes(body, "choices.0.message.content").String()
var modifiedContent string
formattedReferences := fmt.Sprintf(config.referenceFormat, references)
if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
thinkEnd := strings.Index(content, "</think>")
if thinkEnd != -1 {
modifiedContent = content[:thinkEnd+8] +
fmt.Sprintf("\n%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content[thinkEnd+8:])
if config.referenceLocation == "tail" {
// Add references at the end
modifiedContent = content + fmt.Sprintf("\n\n%s", formattedReferences)
} else {
// Default: add references after </think> tag
modifiedContent = content[:thinkEnd+8] +
fmt.Sprintf("\n%s\n\n%s", formattedReferences, content[thinkEnd+8:])
}
}
}
if modifiedContent == "" {
modifiedContent = fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content)
if config.referenceLocation == "tail" {
// Add references at the end
modifiedContent = fmt.Sprintf("%s\n\n%s", content, formattedReferences)
} else {
// Default: add references at the beginning
modifiedContent = fmt.Sprintf("%s\n\n%s", formattedReferences, content)
}
}
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
if err != nil {
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
@@ -561,7 +641,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
var newMessages []string
for i, msg := range messages {
if i < len(messages)-1 {
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), log)
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
if newMsg != "" {
newMessages = append(newMessages, newMsg)
}
@@ -579,7 +659,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
}
}
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, log wrapper.Log) string {
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
log.Debugf("single sse message: %s", sseMessage)
subMessages := strings.Split(sseMessage, "\n")
var message string
@@ -600,6 +680,26 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
}
bodyJson = strings.TrimPrefix(bodyJson, " ")
bodyJson = strings.TrimSuffix(bodyJson, "\n")
// If tailReference is true, references are appended only to the last message
// of the stream; every other message is forwarded unchanged.
if tailReference {
	// The last message in the stream is recognized by finish_reason == "stop".
	finishReason := gjson.Get(bodyJson, "choices.0.finish_reason").String()
	if finishReason != "stop" {
		// Not the last message, return original message
		return sseMessage
	}
	// This is the last message, append references at the end of its delta content.
	deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
	modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", deltaContent+"\n\n"+references)
	if err != nil {
		// Forward the original event instead of emitting "data: " followed by
		// whatever sjson returned with the error, and do not mark the
		// references as appended since they were not.
		log.Errorf("update message failed:%s", err)
		return sseMessage
	}
	ctx.SetContext("ReferenceAppended", true)
	return fmt.Sprintf("data: %s", modifiedMessage)
}
// Original head reference logic
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
// Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
if deltaContent == "" {
@@ -615,7 +715,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") {
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
if err != nil {
log.Errorf("update messsage failed:%s", err)
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
@@ -629,7 +729,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
if err != nil {
log.Errorf("update messsage failed:%s", err)
log.Errorf("update message failed:%s", err)
}
ctx.SetContext("ReferenceAppended", true)
return fmt.Sprintf("data: %s", modifiedMessage)
@@ -644,7 +744,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
// Return the content before the partial match
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
if err != nil {
log.Errorf("update messsage failed:%s", err)
log.Errorf("update message failed:%s", err)
}
return fmt.Sprintf("data: %s", modifiedMessage)
}