// Copyright (c) 2022 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( _ "embed" "errors" "fmt" "net/http" "strings" "time" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/elasticsearch" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/google" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/quark" ) type SearchRewrite struct { client wrapper.HttpClient url string apiKey string modelName string timeoutMillisecond uint32 prompt string } type Config struct { engine []engine.SearchEngine promptTemplate string referenceFormat string defaultLanguage string needReference bool searchRewrite *SearchRewrite } const ( DEFAULT_MAX_BODY_BYTES uint32 = 100 * 1024 * 1024 ) //go:embed prompts/full.md var fullSearchPrompts string //go:embed prompts/arxiv.md var arxivSearchPrompts string //go:embed prompts/internet.md var internetSearchPrompts string //go:embed prompts/private.md var privateSearchPrompts string func main() { wrapper.SetCtx( "ai-search", wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody), wrapper.ProcessResponseBodyBy(onHttpResponseBody), ) } func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error { config.needReference = json.Get("needReference").Bool() if config.needReference { config.referenceFormat = json.Get("referenceFormat").String() if config.referenceFormat == "" { config.referenceFormat = "**References:**\n%s" } else if !strings.Contains(config.referenceFormat, "%s") { return fmt.Errorf("invalid referenceFormat:%s", config.referenceFormat) } } config.defaultLanguage = json.Get("defaultLang").String() config.promptTemplate = json.Get("promptTemplate").String() if config.promptTemplate == "" { if config.needReference { config.promptTemplate = `# 以下内容是基于用户发送的消息的搜索结果: {search_results} 在我给你的搜索结果中,每个结果都是[webpage X begin]...[webpage X end]格式的,X代表每篇文章的数字索引。请在适当的情况下在句子末尾引用上下文。请按照引用编号[X]的格式在答案中对应部分引用上下文。如果一句话源自多个上下文,请列出所有相关的引用编号,例如[3][5],切记不要将引用集中在最后返回引用编号,而是在答案对应部分列出。 在回答时,请注意以下几点: - 今天是北京时间:{cur_date}。 - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。 - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。 - 对于创作类的问题(如写论文),请务必在正文的段落中引用对应的参考编号,例如[3][5],不能只在文章末尾引用。你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。 - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。 - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。 - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。 - 你的回答应该综合多个相关网页来回答,不能重复引用一个网页。 - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。 # 用户消息为: {question}` } else { config.promptTemplate = `# 以下内容是基于用户发送的消息的搜索结果: {search_results} 在我给你的搜索结果中,每个结果都是[webpage begin]...[webpage end]格式的。 在回答时,请注意以下几点: - 今天是北京时间:{cur_date}。 - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。 - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内。如非必要,不要主动告诉用户搜索结果未提供的内容。 - 对于创作类的问题(如写论文),你需要解读并概括用户的题目要求,选择合适的格式,充分利用搜索结果并抽取重要信息,生成符合用户要求、极具思想深度、富有创造力与专业性的答案。你的创作篇幅需要尽可能延长,对于每一个要点的论述要推测用户的意图,给出尽可能多角度的回答要点,且务必信息量大、论述详尽。 - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。 - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。 - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。 - 你的回答应该综合多个相关网页来回答,但回答中不要给出网页的引用来源。 - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。 # 用户消息为: {question}` } } if !strings.Contains(config.promptTemplate, "{search_results}") || !strings.Contains(config.promptTemplate, "{question}") { return fmt.Errorf("invalid promptTemplate, must contains {search_results} and {question}:%s", config.promptTemplate) } var internetExists, privateExists, arxivExists bool for _, e := range json.Get("searchFrom").Array() { switch e.Get("type").String() { case "bing": searchEngine, err := bing.NewBingSearch(&e) if err != nil { return fmt.Errorf("bing search engine init failed:%s", err) } config.engine = append(config.engine, searchEngine) internetExists = true case "google": searchEngine, err := google.NewGoogleSearch(&e) if err != nil { return fmt.Errorf("google search engine init failed:%s", err) } config.engine = append(config.engine, searchEngine) internetExists = true case "arxiv": searchEngine, err := arxiv.NewArxivSearch(&e) if err != nil { return fmt.Errorf("arxiv search engine init failed:%s", err) } config.engine = append(config.engine, searchEngine) arxivExists = true case "elasticsearch": searchEngine, err := elasticsearch.NewElasticsearchSearch(&e) if err != nil { return fmt.Errorf("elasticsearch search engine init failed:%s", err) } config.engine = append(config.engine, searchEngine) privateExists = true case "quark": searchEngine, err := quark.NewQuarkSearch(&e) if err != nil { return fmt.Errorf("elasticsearch search engine init failed:%s", err) } config.engine = append(config.engine, searchEngine) internetExists = true default: return fmt.Errorf("unkown search engine:%s", e.Get("type").String()) } } searchRewriteJson := json.Get("searchRewrite") if searchRewriteJson.Exists() { searchRewrite := &SearchRewrite{} llmServiceName := searchRewriteJson.Get("llmServiceName").String() if llmServiceName == "" { return errors.New("llm_service_name not found") } llmServicePort := searchRewriteJson.Get("llmServicePort").Int() if llmServicePort == 0 { return errors.New("llmServicePort not found") } searchRewrite.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: llmServiceName, Port: llmServicePort, }) llmApiKey := searchRewriteJson.Get("llmApiKey").String() if llmApiKey == "" { return errors.New("llmApiKey not found") } searchRewrite.apiKey = llmApiKey llmUrl := searchRewriteJson.Get("llmUrl").String() if llmUrl == "" { return errors.New("llmUrl not found") } searchRewrite.url = llmUrl llmModelName := searchRewriteJson.Get("llmModelName").String() if llmModelName == "" { return errors.New("llmModelName not found") } searchRewrite.modelName = llmModelName llmTimeout := searchRewriteJson.Get("timeoutMillisecond").Uint() if llmTimeout == 0 { llmTimeout = 30000 } searchRewrite.timeoutMillisecond = uint32(llmTimeout) // The consideration here is that internet searches are generally available, but arxiv and private sources may not be. if arxivExists { if privateExists { // private + internet + arxiv searchRewrite.prompt = fullSearchPrompts } else { // internet + arxiv searchRewrite.prompt = arxivSearchPrompts } } else if privateExists { // private + internet searchRewrite.prompt = privateSearchPrompts } else if internetExists { // only internet searchRewrite.prompt = internetSearchPrompts } config.searchRewrite = searchRewrite } if len(config.engine) == 0 { return fmt.Errorf("no avaliable search engine found") } log.Debugf("ai search enabled, config: %#v", config) return nil } func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpRequestHeader("content-type") // The request does not have a body. if contentType == "" { return types.ActionContinue } if !strings.Contains(contentType, "application/json") { log.Warnf("content is not json, can't process: %s", contentType) ctx.DontReadRequestBody() return types.ActionContinue } ctx.SetRequestBodyBufferLimit(DEFAULT_MAX_BODY_BYTES) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action { var queryIndex int var query string messages := gjson.GetBytes(body, "messages").Array() for i := len(messages) - 1; i >= 0; i-- { if messages[i].Get("role").String() == "user" { queryIndex = i query = messages[i].Get("content").String() break } } if query == "" { log.Errorf("not found user query in body:%s", body) return types.ActionContinue } searchRewrite := config.searchRewrite if searchRewrite != nil { startTime := time.Now() rewritePrompt := strings.Replace(searchRewrite.prompt, "{question}", query, 1) rewriteBody, _ := sjson.SetBytes([]byte(fmt.Sprintf( `{"stream":false,"max_tokens":100,"model":"%s","messages":[{"role":"user","content":""}]}`, searchRewrite.modelName)), "messages.0.content", rewritePrompt) err := searchRewrite.client.Post(searchRewrite.url, [][2]string{ {"Content-Type", "application/json"}, {"Authorization", fmt.Sprintf("Bearer %s", searchRewrite.apiKey)}, }, rewriteBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != http.StatusOK { log.Errorf("search rewrite failed, status: %d", statusCode) // After a rewrite failure, no further search is performed, thus quickly identifying the failure. proxywasm.ResumeHttpRequest() return } content := gjson.GetBytes(responseBody, "choices.0.message.content").String() log.Infof("LLM rewritten query response: %s (took %v), original search query:%s", strings.ReplaceAll(content, "\n", `\n`), time.Since(startTime), query) if strings.Contains(content, "none") { log.Debugf("no search required") proxywasm.ResumeHttpRequest() return } // Parse search queries from LLM response var searchContexts []engine.SearchContext for _, line := range strings.Split(content, "\n") { line = strings.TrimSpace(line) if line == "" { continue } parts := strings.SplitN(line, ":", 2) if len(parts) != 2 { continue } engineType := strings.TrimSpace(parts[0]) queryStr := strings.TrimSpace(parts[1]) var ctx engine.SearchContext ctx.Language = config.defaultLanguage switch { case engineType == "internet": ctx.EngineType = engineType ctx.Querys = []string{queryStr} case engineType == "private": ctx.EngineType = engineType ctx.Querys = strings.Split(queryStr, ",") for i := range ctx.Querys { ctx.Querys[i] = strings.TrimSpace(ctx.Querys[i]) } default: // Arxiv category ctx.EngineType = "arxiv" ctx.ArxivCategory = engineType ctx.Querys = strings.Split(queryStr, ",") for i := range ctx.Querys { ctx.Querys[i] = strings.TrimSpace(ctx.Querys[i]) } } if len(ctx.Querys) > 0 { searchContexts = append(searchContexts, ctx) if ctx.ArxivCategory != "" { // Conduct i/nquiries in all areas to increase recall. backupCtx := ctx backupCtx.ArxivCategory = "" searchContexts = append(searchContexts, backupCtx) } } } if len(searchContexts) == 0 { log.Errorf("no valid search contexts found") proxywasm.ResumeHttpRequest() return } if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) { proxywasm.ResumeHttpRequest() } }, searchRewrite.timeoutMillisecond) if err != nil { log.Errorf("search rewrite call llm service failed:%s", err) // After a rewrite failure, no further search is performed, thus quickly identifying the failure. return types.ActionContinue } return types.ActionPause } // Execute search without rewrite return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{ Querys: []string{query}, Language: config.defaultLanguage, }}, log) } func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log wrapper.Log) types.Action { searchResultGroups := make([][]engine.SearchResult, len(config.engine)) var finished int var searching int for i := 0; i < len(config.engine); i++ { configEngine := config.engine[i] // Check if engine needs to execute for any of the search contexts var needsExecute bool for _, searchCtx := range searchContexts { if configEngine.NeedExectue(searchCtx) { needsExecute = true break } } if !needsExecute { continue } // Process all search contexts for this engine for _, searchCtx := range searchContexts { if !configEngine.NeedExectue(searchCtx) { continue } args := configEngine.CallArgs(searchCtx) index := i err := configEngine.Client().Call(args.Method, args.Url, args.Headers, args.Body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { defer func() { finished++ if finished == searching { // Merge search results from all engines with deduplication var mergedResults []engine.SearchResult seenLinks := make(map[string]bool) for _, results := range searchResultGroups { for _, result := range results { if !seenLinks[result.Link] { seenLinks[result.Link] = true mergedResults = append(mergedResults, result) } } } // Format search results for prompt template var formattedResults []string var formattedReferences []string for j, result := range mergedResults { if config.needReference { formattedResults = append(formattedResults, fmt.Sprintf("[webpage %d begin]\n%s\n[webpage %d end]", j+1, result.Content, j+1)) formattedReferences = append(formattedReferences, fmt.Sprintf("[%d] [%s](%s)", j+1, result.Title, result.Link)) } else { formattedResults = append(formattedResults, fmt.Sprintf("[webpage begin]\n%s\n[webpage end]", result.Content)) } } // Prepare template variables curDate := time.Now().In(time.FixedZone("CST", 8*3600)).Format("2006年1月2日") searchResults := strings.Join(formattedResults, "\n") log.Debugf("searchResults: %s", searchResults) // Fill prompt template prompt := strings.Replace(config.promptTemplate, "{search_results}", searchResults, 1) prompt = strings.Replace(prompt, "{question}", searchContexts[0].Querys[0], 1) prompt = strings.Replace(prompt, "{cur_date}", curDate, 1) // Update request body with processed prompt modifiedBody, err := sjson.SetBytes(body, fmt.Sprintf("messages.%d.content", queryIndex), prompt) if err != nil { log.Errorf("modify request message content failed, err:%v, body:%s", err, body) } else { log.Debugf("modifeid body:%s", modifiedBody) proxywasm.ReplaceHttpRequestBody(modifiedBody) if config.needReference { ctx.SetContext("References", strings.Join(formattedReferences, "\n")) } } proxywasm.ResumeHttpRequest() } }() if statusCode != http.StatusOK { log.Errorf("search call failed, status: %d, engine: %#v", statusCode, configEngine) return } // Append results to existing slice for this engine searchResultGroups[index] = append(searchResultGroups[index], configEngine.ParseResult(searchCtx, responseBody)...) }, args.TimeoutMillisecond) if err != nil { log.Errorf("search call failed, engine: %#v", configEngine) continue } searching++ } } if searching > 0 { return types.ActionPause } return types.ActionContinue } func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action { if !config.needReference { ctx.DontReadResponseBody() return types.ActionContinue } proxywasm.RemoveHttpResponseHeader("content-length") contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { if err != nil { log.Errorf("unable to load content-type header from response: %v", err) } ctx.BufferResponseBody() ctx.SetResponseBodyBufferLimit(DEFAULT_MAX_BODY_BYTES) } return types.ActionContinue } func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action { references := ctx.GetStringContext("References", "") if references == "" { return types.ActionContinue } content := gjson.GetBytes(body, "choices.0.message.content") modifiedContent := fmt.Sprintf("%s\n\n%s", fmt.Sprintf(config.referenceFormat, references), content) body, err := sjson.SetBytes(body, "choices.0.message.content", modifiedContent) if err != nil { log.Errorf("modify response message content failed, err:%v, body:%s", err, body) return types.ActionContinue } proxywasm.ReplaceHttpResponseBody(body) return types.ActionContinue } func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { if ctx.GetBoolContext("ReferenceAppended", false) { return chunk } references := ctx.GetStringContext("References", "") if references == "" { return chunk } modifiedChunk, responseReady := setReferencesToFirstMessage(ctx, chunk, fmt.Sprintf(config.referenceFormat, references), log) if responseReady { ctx.SetContext("ReferenceAppended", true) return modifiedChunk } else { return []byte("") } } const PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" func setReferencesToFirstMessage(ctx wrapper.HttpContext, chunk []byte, references string, log wrapper.Log) ([]byte, bool) { if len(chunk) == 0 { log.Debugf("chunk is empty") return nil, false } var partialMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) if partialMessageI != nil { if pMsg, ok := partialMessageI.([]byte); ok { partialMessage = append(pMsg, chunk...) } else { log.Warnf("invalid partial message type: %T", partialMessageI) partialMessage = chunk } } else { partialMessage = chunk } if len(partialMessage) == 0 { log.Debugf("partial message is empty") return nil, false } messages := strings.Split(string(partialMessage), "\n\n") if len(messages) > 1 { firstMessage := messages[0] log.Debugf("first message: %s", firstMessage) firstMessage = strings.TrimPrefix(firstMessage, "data:") firstMessage = strings.TrimPrefix(firstMessage, " ") firstMessage = strings.TrimSuffix(firstMessage, "\n") deltaContent := gjson.Get(firstMessage, "choices.0.delta.content") modifiedMessage, err := sjson.Set(firstMessage, "choices.0.delta.content", fmt.Sprintf("%s\n\n%s", references, deltaContent)) if err != nil { log.Errorf("modify response delta content failed, err:%v", err) return partialMessage, true } modifiedMessage = fmt.Sprintf("data: %s", modifiedMessage) log.Debugf("modified message: %s", firstMessage) messages[0] = string(modifiedMessage) return []byte(strings.Join(messages, "\n\n")), true } ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, partialMessage) return nil, false }