From a17ac9e4c626c378137b2be04636f463f2c222bc Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Thu, 8 Aug 2024 18:00:02 +0800 Subject: [PATCH] Optimize ai-rag plugin (#1170) --- plugins/wasm-go/extensions/ai-rag/README.md | 26 ++++-- plugins/wasm-go/extensions/ai-rag/go.sum | 5 +- plugins/wasm-go/extensions/ai-rag/main.go | 94 ++++++++++++++++----- 3 files changed, 90 insertions(+), 35 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-rag/README.md b/plugins/wasm-go/extensions/ai-rag/README.md index 00f406f4b..204ad9ee0 100644 --- a/plugins/wasm-go/extensions/ai-rag/README.md +++ b/plugins/wasm-go/extensions/ai-rag/README.md @@ -1,34 +1,42 @@ # 简介 通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示: -![](https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg) + # 配置说明 | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |----------------|-----------------|------|-----|----------------------------------------------------------------------------------| | `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 | -| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 | +| `dashscope.serviceFQDN` | string | 必填 | - | 通义千问服务名 | | `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 | -| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 | +| `dashscope.serviceHost` | string | 必填 | - | 访问通义千问服务时域名 | | `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 | -| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 | +| `dashvector.serviceFQDN` | string | 必填 | - | 阿里云向量检索服务名 | | `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 | -| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 | +| `dashvector.serviceHost` | string | 必填 | - | 访问阿里云向量检索服务时域名 | +| `dashvector.topk` | int | 必填 | - | 阿里云向量检索时获取向量数 | +| `dashvector.threshold` | float | 必填 | - | 向量距离阈值,高于该阈值的文档会被过滤掉 | +| `dashvector.field` | string | 必填 | - | 阿里云向量检索存储文档的字段名 | + +插件开启后,在使用链路追踪功能时,会在span的attribute中添加rag检索到的文档id信息,供排查问题使用。 # 示例 ```yaml dashscope: apiKey: xxxxxxxxxxxxxxx - serviceName: dashscope + serviceFQDN: dashscope servicePort: 443 - domain: dashscope.aliyuncs.com + serviceHost: dashscope.aliyuncs.com dashvector: apiKey: xxxxxxxxxxxxxxxxxxxx - serviceName: dashvector + serviceFQDN: dashvector servicePort: 443 - domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com + serviceHost: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com collection: xxxxxxxxxxxxxxx + topk: 1 + threshold: 0.4 + field: raw ``` [CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。 diff --git a/plugins/wasm-go/extensions/ai-rag/go.sum b/plugins/wasm-go/extensions/ai-rag/go.sum index 61ac3d38f..f473e12b2 100644 --- a/plugins/wasm-go/extensions/ai-rag/go.sum +++ b/plugins/wasm-go/extensions/ai-rag/go.sum @@ -1,12 +1,9 @@ -github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY= -github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= diff --git a/plugins/wasm-go/extensions/ai-rag/main.go b/plugins/wasm-go/extensions/ai-rag/main.go index 4eddf830d..b7be598c7 100644 --- a/plugins/wasm-go/extensions/ai-rag/main.go +++ b/plugins/wasm-go/extensions/ai-rag/main.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "ai-rag/dashscope" "ai-rag/dashvector" @@ -20,6 +21,7 @@ func main() { wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), + wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), ) } @@ -29,6 +31,9 @@ type AIRagConfig struct { DashVectorClient wrapper.HttpClient DashVectorAPIKey string DashVectorCollection string + DashVectorTopK int32 + DashVectorThreshold float64 + DashVectorField string } type Request struct { @@ -47,29 +52,46 @@ type Message struct { } func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error { + checkList := []string{ + "dashscope.apiKey", + "dashscope.serviceFQDN", + "dashscope.servicePort", + "dashscope.serviceHost", + "dashvector.apiKey", + "dashvector.collection", + "dashvector.serviceFQDN", + "dashvector.servicePort", + "dashvector.serviceHost", + "dashvector.topk", + "dashvector.threshold", + "dashvector.field", + } + for _, checkEntry := range checkList { + if !json.Get(checkEntry).Exists() { + return fmt.Errorf("%s not found in plugin config!", checkEntry) + } + } config.DashScopeAPIKey = json.Get("dashscope.apiKey").String() - config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: json.Get("dashscope.serviceName").String(), - Port: json.Get("dashscope.servicePort").Int(), - Domain: json.Get("dashscope.domain").String(), + config.DashScopeClient = wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: json.Get("dashscope.serviceFQDN").String(), + Port: json.Get("dashscope.servicePort").Int(), + Host: json.Get("dashscope.serviceHost").String(), }) config.DashVectorAPIKey = json.Get("dashvector.apiKey").String() config.DashVectorCollection = json.Get("dashvector.collection").String() - config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: json.Get("dashvector.serviceName").String(), - Port: json.Get("dashvector.servicePort").Int(), - Domain: json.Get("dashvector.domain").String(), + config.DashVectorClient = wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: json.Get("dashvector.serviceFQDN").String(), + Port: json.Get("dashvector.servicePort").Int(), + Host: json.Get("dashvector.serviceHost").String(), }) + config.DashVectorTopK = int32(json.Get("dashvector.topk").Int()) + config.DashVectorThreshold = json.Get("dashvector.threshold").Float() + config.DashVectorField = json.Get("dashvector.field").String() return nil } func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action { - p, _ := proxywasm.GetHttpRequestHeader(":path") - if p != "/api/openai/v1/chat/completions" { - ctx.DontReadRequestBody() - return types.ActionContinue - } proxywasm.RemoveHttpRequestHeader("content-length") return types.ActionContinue } @@ -78,9 +100,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, var rawRequest Request _ = json.Unmarshal(body, &rawRequest) messageLength := len(rawRequest.Messages) + if messageLength == 0 { + return types.ActionContinue + } rawContent := rawRequest.Messages[messageLength-1].Content requestEmbedding := dashscope.Request{ - Model: "text-embedding-v1", + Model: "text-embedding-v2", Input: dashscope.Input{ Texts: []string{rawContent}, }, @@ -90,7 +115,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, } headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}} reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding) - // log.Info(string(reqEmbeddingSerialized)) config.DashScopeClient.Post( "/api/v1/services/embeddings/text-embedding/text-embedding", headers, @@ -99,8 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, var responseEmbedding dashscope.Response _ = json.Unmarshal(responseBody, &responseEmbedding) requestQuery := dashvector.Request{ - TopK: 1, - OutputFileds: []string{"raw"}, + TopK: config.DashVectorTopK, + OutputFileds: []string{config.DashVectorField}, Vector: responseEmbedding.Output.Embeddings[0].Embedding, } requestQuerySerialized, _ := json.Marshal(requestQuery) @@ -111,11 +135,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, func(statusCode int, responseHeaders http.Header, responseBody []byte) { var response dashvector.Response _ = json.Unmarshal(responseBody, &response) - doc := response.Output[0].Fields.Raw - rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent) - newBody, _ := json.Marshal(rawRequest) - // log.Info(string(newBody)) - proxywasm.ReplaceHttpRequestBody(newBody) + recallDocIds := []string{} + recallDocs := []string{} + for _, output := range response.Output { + log.Debugf("Score: %f, Doc: %s", output.Score, output.Fields.Raw) + if output.Score <= float32(config.DashVectorThreshold) { + recallDocs = append(recallDocs, output.Fields.Raw) + recallDocIds = append(recallDocIds, output.ID) + } + } + if len(recallDocs) > 0 { + rawRequest.Messages = rawRequest.Messages[:messageLength-1] + traceStr := strings.Join(recallDocIds, ", ") + proxywasm.SetProperty([]string{"trace_span_tag.rag_docs"}, []byte(traceStr)) + for _, doc := range recallDocs { + rawRequest.Messages = append(rawRequest.Messages, Message{"user", doc}) + } + rawRequest.Messages = append(rawRequest.Messages, Message{"user", fmt.Sprintf("现在,请回答以下问题:\n%s", rawContent)}) + newBody, _ := json.Marshal(rawRequest) + proxywasm.ReplaceHttpRequestBody(newBody) + ctx.SetContext("x-envoy-rag-recall", true) + } proxywasm.ResumeHttpRequest() }, ) @@ -124,3 +164,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, ) return types.ActionPause } + +func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action { + recall, ok := ctx.GetContext("x-envoy-rag-recall").(bool) + if ok && recall { + proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "true") + } else { + proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "false") + } + return types.ActionContinue +}