From 174350d3fb686b2279f6a63ef340bcf53f0f3296 Mon Sep 17 00:00:00 2001
From: rinfx <893383980@qq.com>
Date: Mon, 17 Jun 2024 15:37:00 +0800
Subject: [PATCH] add plugin: ai-rag (#1038)

---
 plugins/wasm-go/extensions/ai-rag/.gitignore  |   3 +
 plugins/wasm-go/extensions/ai-rag/README.md   |  50 ++++
 .../extensions/ai-rag/dashscope/types.go      |  43 +++
 .../extensions/ai-rag/dashvector/types.go     |  33 +++
 plugins/wasm-go/extensions/ai-rag/go.mod      |  19 ++
 plugins/wasm-go/extensions/ai-rag/go.sum      |  25 ++
 plugins/wasm-go/extensions/ai-rag/main.go     | 234 ++++++++++++++++++
 7 files changed, 407 insertions(+)
 create mode 100644 plugins/wasm-go/extensions/ai-rag/.gitignore
 create mode 100644 plugins/wasm-go/extensions/ai-rag/README.md
 create mode 100644 plugins/wasm-go/extensions/ai-rag/dashscope/types.go
 create mode 100644 plugins/wasm-go/extensions/ai-rag/dashvector/types.go
 create mode 100644 plugins/wasm-go/extensions/ai-rag/go.mod
 create mode 100644 plugins/wasm-go/extensions/ai-rag/go.sum
 create mode 100644 plugins/wasm-go/extensions/ai-rag/main.go

diff --git a/plugins/wasm-go/extensions/ai-rag/.gitignore b/plugins/wasm-go/extensions/ai-rag/.gitignore
new file mode 100644
index 000000000..c9f9dc52b
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/.gitignore
@@ -0,0 +1,3 @@
+config.yaml
+main.wasm
+tmp/
\ No newline at end of file
diff --git a/plugins/wasm-go/extensions/ai-rag/README.md b/plugins/wasm-go/extensions/ai-rag/README.md
new file mode 100644
index 000000000..00f406f4b
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/README.md
@@ -0,0 +1,50 @@
+# 简介
+通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
+
+![](https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg)
+
+# 配置说明
+| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
+|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
+| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
+| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
+| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
+| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
+| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
+| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
+| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
+| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
+| `dashvector.collection` | string | 必填 | - | 阿里云向量检索服务的 Collection 名称 |
+
+# 示例
+
+```yaml
+dashscope:
+  apiKey: xxxxxxxxxxxxxxx
+  serviceName: dashscope
+  servicePort: 443
+  domain: dashscope.aliyuncs.com
+dashvector:
+  apiKey: xxxxxxxxxxxxxxxxxxxx
+  serviceName: dashvector
+  servicePort: 443
+  domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
+  collection: xxxxxxxxxxxxxxx
+```
+
+[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
+
+以下为使用RAG进行增强的例子,原始请求为:
+```
+海南追尾事故,发生在哪里?原因是什么?人员伤亡情况如何?
+```
+
+未经过RAG插件处理LLM返回的结果为:
+```
+抱歉,作为AI模型,我无法实时获取和更新新闻事件的具体信息,包括地点、原因、人员伤亡等细节。对于此类具体事件,建议您查阅最新的新闻报道或官方通报以获取准确信息。您可以访问主流媒体网站、使用新闻应用或者关注相关政府部门的公告来获取这类动态资讯。
+```
+
+经过RAG插件处理后LLM返回的结果为:
+```
+海南追尾事故发生在海文高速公路文昌至海口方向37公里处。关于事故的具体原因,交警部门当时仍在进一步调查中,所以根据提供的信息无法确定事故的确切原因。人员伤亡情况是1人死亡(司机当场死亡),另有8人受伤(包括2名儿童和6名成人),所有受伤人员都被解救并送往医院进行治疗。
+```
\ No newline at end of file
diff --git a/plugins/wasm-go/extensions/ai-rag/dashscope/types.go b/plugins/wasm-go/extensions/ai-rag/dashscope/types.go
new file mode 100644
index 000000000..79da649ca
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/dashscope/types.go
@@ -0,0 +1,43 @@
+// Package dashscope defines the request and response types of the DashScope
+// text-embedding service.
+package dashscope
+
+// Request is the payload sent to the DashScope text-embedding service.
+type Request struct {
+	Model     string    `json:"model"`
+	Input     Input     `json:"input"`
+	Parameter Parameter `json:"parameters"`
+}
+
+// Input carries the texts to embed.
+type Input struct {
+	Texts []string `json:"texts"`
+}
+
+// Parameter carries the embedding options.
+type Parameter struct {
+	TextType string `json:"text_type"`
+}
+
+// Response is the payload returned by the DashScope text-embedding service.
+type Response struct {
+	Output    Output `json:"output"`
+	Usage     Usage  `json:"usage"`
+	RequestID string `json:"request_id"`
+}
+
+// Output holds the returned embeddings.
+type Output struct {
+	Embeddings []Embedding `json:"embeddings"`
+}
+
+// Embedding is a single vector together with the index of its input text.
+type Embedding struct {
+	Embedding []float32 `json:"embedding"`
+	TextIndex int32     `json:"text_index"`
+}
+
+// Usage reports the total token count returned with the response.
+type Usage struct {
+	TotalTokens int32 `json:"total_tokens"`
+}
diff --git a/plugins/wasm-go/extensions/ai-rag/dashvector/types.go b/plugins/wasm-go/extensions/ai-rag/dashvector/types.go
new file mode 100644
index 000000000..52b10e56d
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/dashvector/types.go
@@ -0,0 +1,33 @@
+// Package dashvector defines the request and response types of the DashVector
+// document query API.
+package dashvector
+
+// Request is the payload of a DashVector document query.
+//
+// The field name OutputFileds keeps its historical spelling for existing
+// callers; the JSON key is the "output_fields" expected by the service.
+type Request struct {
+	TopK         int32     `json:"topk"`
+	OutputFileds []string  `json:"output_fields"`
+	Vector       []float32 `json:"vector"`
+}
+
+// Response is the payload returned by a DashVector document query.
+type Response struct {
+	Code      int32          `json:"code"`
+	RequestID string         `json:"request_id"`
+	Message   string         `json:"message"`
+	Output    []OutputObject `json:"output"`
+}
+
+// OutputObject is a single matched document.
+type OutputObject struct {
+	ID     string      `json:"id"`
+	Fields FieldObject `json:"fields"`
+	Score  float32     `json:"score"`
+}
+
+// FieldObject holds the document fields decoded from a match.
+type FieldObject struct {
+	Raw string `json:"raw"`
+}
diff --git a/plugins/wasm-go/extensions/ai-rag/go.mod b/plugins/wasm-go/extensions/ai-rag/go.mod
new file mode 100644
index 000000000..332d15466
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/go.mod
@@ -0,0 +1,19 @@
+module ai-rag
+
+go 1.18
+
+require (
+	github.com/alibaba/higress/plugins/wasm-go v1.3.5
+	github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
+	github.com/tidwall/gjson v1.14.3
+)
+
+require (
+	github.com/google/uuid v1.3.0 // indirect
+	github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
+	github.com/magefile/mage v1.14.0 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.0 // indirect
+	github.com/tidwall/resp v0.1.1 // indirect
+	github.com/tidwall/sjson v1.2.5 // indirect
+)
diff --git a/plugins/wasm-go/extensions/ai-rag/go.sum b/plugins/wasm-go/extensions/ai-rag/go.sum
new file mode 100644
index 000000000..94fcae90c
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/go.sum
@@ -0,0 +1,25 @@
+github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
+github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
+github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
+github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
+github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
+github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
+github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
+github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
+github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/plugins/wasm-go/extensions/ai-rag/main.go b/plugins/wasm-go/extensions/ai-rag/main.go
new file mode 100644
index 000000000..8c9ee5a32
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-rag/main.go
@@ -0,0 +1,234 @@
+package main
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+
+	"ai-rag/dashscope"
+	"ai-rag/dashvector"
+
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
+	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
+	"github.com/tidwall/gjson"
+)
+
+func main() {
+	wrapper.SetCtx(
+		"ai-rag",
+		wrapper.ParseConfigBy(parseConfig),
+		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
+		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
+	)
+}
+
+// AIRagConfig holds the service clients and credentials read from the plugin
+// configuration.
+type AIRagConfig struct {
+	DashScopeClient      wrapper.HttpClient
+	DashScopeAPIKey      string
+	DashVectorClient     wrapper.HttpClient
+	DashVectorAPIKey     string
+	DashVectorCollection string
+}
+
+// Request is the chat-completion request body as decoded and re-encoded by the
+// plugin. Only the fields declared here survive the rewrite. Topp is a float64
+// because top_p is fractional; decoding it into an integer type fails.
+type Request struct {
+	Model            string    `json:"model"`
+	Messages         []Message `json:"messages"`
+	FrequencyPenalty float64   `json:"frequency_penalty"`
+	PresencePenalty  float64   `json:"presence_penalty"`
+	Stream           bool      `json:"stream"`
+	Temperature      float64   `json:"temperature"`
+	Topp             float64   `json:"top_p"`
+}
+
+// Message is a single chat message.
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// requiredConfigPaths lists the configuration keys that must be present and
+// non-empty for the plugin to work.
+var requiredConfigPaths = []string{
+	"dashscope.apiKey",
+	"dashscope.serviceName",
+	"dashscope.servicePort",
+	"dashscope.domain",
+	"dashvector.apiKey",
+	"dashvector.serviceName",
+	"dashvector.servicePort",
+	"dashvector.domain",
+	"dashvector.collection",
+}
+
+// parseConfig validates the plugin configuration and builds the DashScope and
+// DashVector clients from it. It returns an error naming the first required
+// key that is missing or empty.
+func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
+	for _, path := range requiredConfigPaths {
+		if json.Get(path).String() == "" {
+			return fmt.Errorf("missing required config %q", path)
+		}
+	}
+
+	config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
+	config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
+		ServiceName: json.Get("dashscope.serviceName").String(),
+		Port:        json.Get("dashscope.servicePort").Int(),
+		Domain:      json.Get("dashscope.domain").String(),
+	})
+
+	config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
+	config.DashVectorCollection = json.Get("dashvector.collection").String()
+	config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
+		ServiceName: json.Get("dashvector.serviceName").String(),
+		Port:        json.Get("dashvector.servicePort").Int(),
+		Domain:      json.Get("dashvector.domain").String(),
+	})
+	return nil
+}
+
+// onHttpRequestHeaders enables body processing for chat-completion requests
+// only. Content-Length is removed for those because the body is rewritten.
+func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
+	// On a failed lookup p is expected to be empty, so the request is skipped.
+	p, _ := proxywasm.GetHttpRequestHeader(":path")
+	if p != "/api/openai/v1/chat/completions" {
+		ctx.DontReadRequestBody()
+		return types.ActionContinue
+	}
+	proxywasm.RemoveHttpRequestHeader("content-length")
+	return types.ActionContinue
+}
+
+// onHttpRequestBody prepends a document retrieved from DashVector to the last
+// chat message: it embeds the message through DashScope, queries DashVector
+// with the embedding, rewrites the body and resumes the request. On any
+// failure the request is forwarded unmodified instead of crashing the plugin.
+func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
+	var rawRequest Request
+	if err := json.Unmarshal(body, &rawRequest); err != nil {
+		log.Errorf("decode chat request: %v", err)
+		return types.ActionContinue
+	}
+	messageLength := len(rawRequest.Messages)
+	if messageLength == 0 {
+		log.Warnf("chat request has no messages, skipping retrieval")
+		return types.ActionContinue
+	}
+	rawContent := rawRequest.Messages[messageLength-1].Content
+
+	reqEmbeddingSerialized, err := json.Marshal(dashscope.Request{
+		Model:     "text-embedding-v1",
+		Input:     dashscope.Input{Texts: []string{rawContent}},
+		Parameter: dashscope.Parameter{TextType: "query"},
+	})
+	if err != nil {
+		log.Errorf("encode embedding request: %v", err)
+		return types.ActionContinue
+	}
+
+	// onDocument is the final step: it rewrites the body when a document was
+	// found and always resumes the paused request.
+	onDocument := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+		defer proxywasm.ResumeHttpRequest()
+		doc, err := topDocument(statusCode, responseBody)
+		if err != nil {
+			log.Errorf("document lookup: %v", err)
+			return
+		}
+		rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
+		newBody, err := json.Marshal(rawRequest)
+		if err != nil {
+			log.Errorf("encode chat request: %v", err)
+			return
+		}
+		if err := proxywasm.ReplaceHttpRequestBody(newBody); err != nil {
+			log.Errorf("replace request body: %v", err)
+		}
+	}
+
+	// onEmbedding forwards the embedding to DashVector. When that cannot be
+	// done no further callback fires, so it resumes the request itself.
+	onEmbedding := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+		vector, err := firstEmbedding(statusCode, responseBody)
+		if err == nil {
+			err = queryDashVector(config, vector, onDocument)
+		}
+		if err != nil {
+			log.Errorf("embedding lookup: %v", err)
+			proxywasm.ResumeHttpRequest()
+		}
+	}
+
+	err = config.DashScopeClient.Post(
+		"/api/v1/services/embeddings/text-embedding/text-embedding",
+		[][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}},
+		reqEmbeddingSerialized,
+		onEmbedding,
+		50000,
+	)
+	if err != nil {
+		// Nothing was dispatched, so pausing would stall the request forever.
+		log.Errorf("dispatch embedding request: %v", err)
+		return types.ActionContinue
+	}
+	return types.ActionPause
+}
+
+// firstEmbedding extracts the first embedding vector from a DashScope
+// text-embedding response.
+func firstEmbedding(statusCode int, responseBody []byte) ([]float32, error) {
+	if statusCode != http.StatusOK {
+		return nil, fmt.Errorf("dashscope returned status %d: %s", statusCode, responseBody)
+	}
+	var resp dashscope.Response
+	if err := json.Unmarshal(responseBody, &resp); err != nil {
+		return nil, fmt.Errorf("decode dashscope response: %w", err)
+	}
+	if len(resp.Output.Embeddings) == 0 {
+		return nil, fmt.Errorf("dashscope response has no embeddings: %s", responseBody)
+	}
+	return resp.Output.Embeddings[0].Embedding, nil
+}
+
+// queryDashVector sends a top-1 similarity query for vector to the configured
+// collection. A non-nil error means the query was not dispatched; cb is then
+// not expected to fire.
+func queryDashVector(config AIRagConfig, vector []float32, cb func(int, http.Header, []byte)) error {
+	requestQuerySerialized, err := json.Marshal(dashvector.Request{
+		TopK:         1,
+		OutputFileds: []string{"raw"},
+		Vector:       vector,
+	})
+	if err != nil {
+		return fmt.Errorf("encode dashvector query: %w", err)
+	}
+	return config.DashVectorClient.Post(
+		fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
+		[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
+		requestQuerySerialized,
+		cb,
+	)
+}
+
+// topDocument extracts the raw text of the best match from a DashVector query
+// response.
+func topDocument(statusCode int, responseBody []byte) (string, error) {
+	if statusCode != http.StatusOK {
+		return "", fmt.Errorf("dashvector returned status %d: %s", statusCode, responseBody)
+	}
+	var resp dashvector.Response
+	if err := json.Unmarshal(responseBody, &resp); err != nil {
+		return "", fmt.Errorf("decode dashvector response: %w", err)
+	}
+	if len(resp.Output) == 0 {
+		return "", fmt.Errorf("dashvector returned no documents (code %d, message %q)", resp.Code, resp.Message)
+	}
+	return resp.Output[0].Fields.Raw, nil
+}