mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
more optimize of ai search plugin (#1896)
This commit is contained in:
@@ -46,16 +46,19 @@ type SearchRewrite struct {
|
||||
modelName string
|
||||
timeoutMillisecond uint32
|
||||
prompt string
|
||||
promptTemplate string // Original prompt template before replacing placeholders
|
||||
maxCount int
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
engine []engine.SearchEngine
|
||||
promptTemplate string
|
||||
referenceFormat string
|
||||
defaultLanguage string
|
||||
needReference bool
|
||||
searchRewrite *SearchRewrite
|
||||
defaultEnable bool
|
||||
engine []engine.SearchEngine
|
||||
promptTemplate string
|
||||
referenceFormat string
|
||||
defaultLanguage string
|
||||
needReference bool
|
||||
referenceLocation string // "head" or "tail"
|
||||
searchRewrite *SearchRewrite
|
||||
defaultEnable bool
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -74,6 +77,9 @@ var internetSearchPrompts string
|
||||
//go:embed prompts/private.md
|
||||
var privateSearchPrompts string
|
||||
|
||||
//go:embed prompts/chinese-internet.md
|
||||
var chineseInternetSearchPrompts string
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-search",
|
||||
@@ -99,6 +105,13 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
} else if !strings.Contains(config.referenceFormat, "%s") {
|
||||
return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat)
|
||||
}
|
||||
|
||||
config.referenceLocation = json.Get("referenceLocation").String()
|
||||
if config.referenceLocation == "" {
|
||||
config.referenceLocation = "head" // Default to head if not specified
|
||||
} else if config.referenceLocation != "head" && config.referenceLocation != "tail" {
|
||||
return fmt.Errorf("invalid referenceLocation:%s, must be 'head' or 'tail'", config.referenceLocation)
|
||||
}
|
||||
}
|
||||
config.defaultLanguage = json.Get("defaultLang").String()
|
||||
config.promptTemplate = json.Get("promptTemplate").String()
|
||||
@@ -144,6 +157,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate)
|
||||
}
|
||||
var internetExists, privateExists, arxivExists bool
|
||||
var onlyQuark bool = true
|
||||
for _, e := range json.Get("searchFrom").Array() {
|
||||
switch e.Get("type").String() {
|
||||
case "bing":
|
||||
@@ -153,6 +167,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
}
|
||||
config.engine = append(config.engine, searchEngine)
|
||||
internetExists = true
|
||||
onlyQuark = false
|
||||
case "google":
|
||||
searchEngine, err := google.NewGoogleSearch(&e)
|
||||
if err != nil {
|
||||
@@ -160,6 +175,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
}
|
||||
config.engine = append(config.engine, searchEngine)
|
||||
internetExists = true
|
||||
onlyQuark = false
|
||||
case "arxiv":
|
||||
searchEngine, err := arxiv.NewArxivSearch(&e)
|
||||
if err != nil {
|
||||
@@ -167,6 +183,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
}
|
||||
config.engine = append(config.engine, searchEngine)
|
||||
arxivExists = true
|
||||
onlyQuark = false
|
||||
case "elasticsearch":
|
||||
searchEngine, err := elasticsearch.NewElasticsearchSearch(&e)
|
||||
if err != nil {
|
||||
@@ -174,6 +191,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
}
|
||||
config.engine = append(config.engine, searchEngine)
|
||||
privateExists = true
|
||||
onlyQuark = false
|
||||
case "quark":
|
||||
searchEngine, err := quark.NewQuarkSearch(&e)
|
||||
if err != nil {
|
||||
@@ -217,6 +235,12 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
llmTimeout = 30000
|
||||
}
|
||||
searchRewrite.timeoutMillisecond = uint32(llmTimeout)
|
||||
|
||||
maxCount := searchRewriteJson.Get("maxCount").Int()
|
||||
if maxCount == 0 {
|
||||
maxCount = 3 // Default value
|
||||
}
|
||||
searchRewrite.maxCount = int(maxCount)
|
||||
// The consideration here is that internet searches are generally available, but arxiv and private sources may not be.
|
||||
if arxivExists {
|
||||
if privateExists {
|
||||
@@ -231,8 +255,18 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
searchRewrite.prompt = privateSearchPrompts
|
||||
} else if internetExists {
|
||||
// only internet
|
||||
searchRewrite.prompt = internetSearchPrompts
|
||||
if onlyQuark {
|
||||
// When only quark is used, use chinese-internet.md
|
||||
searchRewrite.prompt = chineseInternetSearchPrompts
|
||||
} else {
|
||||
searchRewrite.prompt = internetSearchPrompts
|
||||
}
|
||||
}
|
||||
|
||||
// Store the original prompt template before replacing placeholders
|
||||
searchRewrite.promptTemplate = searchRewrite.prompt
|
||||
// Replace {max_count} placeholder in the prompt with the configured value
|
||||
searchRewrite.prompt = strings.Replace(searchRewrite.prompt, "{max_count}", fmt.Sprintf("%d", searchRewrite.maxCount), -1)
|
||||
config.searchRewrite = searchRewrite
|
||||
}
|
||||
if len(config.engine) == 0 {
|
||||
@@ -260,9 +294,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Lo
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
|
||||
// Check if plugin should be enabled based on config and request
|
||||
webSearchOptions := gjson.GetBytes(body, "web_search_options")
|
||||
if !config.defaultEnable {
|
||||
// When defaultEnable is false, we need to check if web_search_options exists in the request
|
||||
webSearchOptions := gjson.GetBytes(body, "web_search_options")
|
||||
if !webSearchOptions.Exists() {
|
||||
log.Debugf("Plugin disabled by config and no web_search_options in request")
|
||||
return types.ActionContinue
|
||||
@@ -286,6 +320,36 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
||||
}
|
||||
searchRewrite := config.searchRewrite
|
||||
if searchRewrite != nil {
|
||||
// Check if web_search_options.search_context_size exists and adjust maxCount accordingly
|
||||
if webSearchOptions.Exists() {
|
||||
searchContextSize := webSearchOptions.Get("search_context_size").String()
|
||||
if searchContextSize != "" {
|
||||
originalMaxCount := searchRewrite.maxCount
|
||||
switch searchContextSize {
|
||||
case "low":
|
||||
searchRewrite.maxCount = 1
|
||||
log.Debugf("Setting maxCount to 1 based on search_context_size=low")
|
||||
case "medium":
|
||||
searchRewrite.maxCount = 3
|
||||
log.Debugf("Setting maxCount to 3 based on search_context_size=medium")
|
||||
case "high":
|
||||
searchRewrite.maxCount = 5
|
||||
log.Debugf("Setting maxCount to 5 based on search_context_size=high")
|
||||
default:
|
||||
log.Warnf("Unknown search_context_size value: %s, using configured maxCount: %d",
|
||||
searchContextSize, searchRewrite.maxCount)
|
||||
}
|
||||
|
||||
// If maxCount changed, regenerate the prompt from the template
|
||||
if originalMaxCount != searchRewrite.maxCount && searchRewrite.promptTemplate != "" {
|
||||
searchRewrite.prompt = strings.Replace(
|
||||
searchRewrite.promptTemplate,
|
||||
"{max_count}",
|
||||
fmt.Sprintf("%d", searchRewrite.maxCount),
|
||||
-1)
|
||||
}
|
||||
}
|
||||
}
|
||||
startTime := time.Now()
|
||||
rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1)
|
||||
rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf(
|
||||
@@ -509,16 +573,32 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
||||
}
|
||||
content := gjson.GetBytes(body, "choices.0.message.content").String()
|
||||
var modifiedContent string
|
||||
formattedReferences := fmt.Sprintf(config.referenceFormat, references)
|
||||
|
||||
if strings.HasPrefix(strings.TrimLeftFunc(content, unicode.IsSpace), "<think>") {
|
||||
thinkEnd := strings.Index(content, "</think>")
|
||||
if thinkEnd != -1 {
|
||||
modifiedContent = content[:thinkEnd+8] +
|
||||
fmt.Sprintf("\n%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content[thinkEnd+8:])
|
||||
if config.referenceLocation == "tail" {
|
||||
// Add references at the end
|
||||
modifiedContent = content + fmt.Sprintf("\n\n%s", formattedReferences)
|
||||
} else {
|
||||
// Default: add references after </think> tag
|
||||
modifiedContent = content[:thinkEnd+8] +
|
||||
fmt.Sprintf("\n%s\n\n%s", formattedReferences, content[thinkEnd+8:])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if modifiedContent == "" {
|
||||
modifiedContent = fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content)
|
||||
if config.referenceLocation == "tail" {
|
||||
// Add references at the end
|
||||
modifiedContent = fmt.Sprintf("%s\n\n%s", content, formattedReferences)
|
||||
} else {
|
||||
// Default: add references at the beginning
|
||||
modifiedContent = fmt.Sprintf("%s\n\n%s", formattedReferences, content)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent)
|
||||
if err != nil {
|
||||
log.Errorf("modify response message content failed, err:%v, body:%s", err, body)
|
||||
@@ -561,7 +641,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
||||
var newMessages []string
|
||||
for i, msg := range messages {
|
||||
if i < len(messages)-1 {
|
||||
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), log)
|
||||
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
|
||||
if newMsg != "" {
|
||||
newMessages = append(newMessages, newMsg)
|
||||
}
|
||||
@@ -579,7 +659,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
||||
}
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, log wrapper.Log) string {
|
||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
|
||||
log.Debugf("single sse message: %s", sseMessage)
|
||||
subMessages := strings.Split(sseMessage, "\n")
|
||||
var message string
|
||||
@@ -600,6 +680,26 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
||||
}
|
||||
bodyJson = strings.TrimPrefix(bodyJson, " ")
|
||||
bodyJson = strings.TrimSuffix(bodyJson, "\n")
|
||||
|
||||
// If tailReference is true, only check if this is the last message
|
||||
if tailReference {
|
||||
// Check if this is the last message in the stream (finish_reason is "stop")
|
||||
finishReason := gjson.Get(bodyJson, "choices.0.finish_reason").String()
|
||||
if finishReason == "stop" {
|
||||
// This is the last message, append references at the end
|
||||
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
|
||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", deltaContent+fmt.Sprintf("\n\n%s", references))
|
||||
if err != nil {
|
||||
log.Errorf("update message failed:%s", err)
|
||||
}
|
||||
ctx.SetContext("ReferenceAppended", true)
|
||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||
}
|
||||
// Not the last message, return original message
|
||||
return sseMessage
|
||||
}
|
||||
|
||||
// Original head reference logic
|
||||
deltaContent := gjson.Get(bodyJson, "choices.0.delta.content").String()
|
||||
// Skip the preceding content that might be empty due to the presence of a separate reasoning_content field.
|
||||
if deltaContent == "" {
|
||||
@@ -615,7 +715,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
||||
if !strings.Contains(strings.TrimLeftFunc(bufferContent, unicode.IsSpace), "<think>") {
|
||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, bufferContent))
|
||||
if err != nil {
|
||||
log.Errorf("update messsage failed:%s", err)
|
||||
log.Errorf("update message failed:%s", err)
|
||||
}
|
||||
ctx.SetContext("ReferenceAppended", true)
|
||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||
@@ -629,7 +729,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
||||
fmt.Sprintf("\n%s\n\n%s", references, bufferContent[thinkEnd+8:])
|
||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", modifiedContent)
|
||||
if err != nil {
|
||||
log.Errorf("update messsage failed:%s", err)
|
||||
log.Errorf("update message failed:%s", err)
|
||||
}
|
||||
ctx.SetContext("ReferenceAppended", true)
|
||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||
@@ -644,7 +744,7 @@ func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references st
|
||||
// Return the content before the partial match
|
||||
modifiedMessage, err := sjson.Set(bodyJson, "choices.0.delta.content", bufferContent[:len(bufferContent)-i])
|
||||
if err != nil {
|
||||
log.Errorf("update messsage failed:%s", err)
|
||||
log.Errorf("update message failed:%s", err)
|
||||
}
|
||||
return fmt.Sprintf("data: %s", modifiedMessage)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user