Optimize ai-rag plugin (#1170)

This commit is contained in:
rinfx
2024-08-08 18:00:02 +08:00
committed by GitHub
parent 5e95f6f057
commit a17ac9e4c6
3 changed files with 90 additions and 35 deletions

View File

@@ -1,34 +1,42 @@
# 简介
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
![](https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg)
<img src="https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg" width=600>
# 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
| `dashscope.serviceFQDN` | string | 必填 | - | 通义千问服务名 |
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
| `dashscope.serviceHost` | string | 必填 | - | 访问通义千问服务时域名 |
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
| `dashvector.serviceFQDN` | string | 必填 | - | 阿里云向量检索服务名 |
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
| `dashvector.serviceHost` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
| `dashvector.topk` | int | 必填 | - | 阿里云向量检索时获取向量数 |
| `dashvector.threshold` | float | 必填 | - | 向量距离阈值,高于该阈值的文档会被过滤掉 |
| `dashvector.field` | string | 必填 | - | 阿里云向量检索存储文档的字段名 |
插件开启后,在使用链路追踪功能时,会在span的attribute中添加rag检索到的文档id信息,供排查问题使用。
# 示例
```yaml
dashscope:
apiKey: xxxxxxxxxxxxxxx
serviceName: dashscope
serviceFQDN: dashscope
servicePort: 443
domain: dashscope.aliyuncs.com
serviceHost: dashscope.aliyuncs.com
dashvector:
apiKey: xxxxxxxxxxxxxxxxxxxx
serviceName: dashvector
serviceFQDN: dashvector
servicePort: 443
domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
serviceHost: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
collection: xxxxxxxxxxxxxxx
topk: 1
threshold: 0.4
field: raw
```
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。

View File

@@ -1,12 +1,9 @@
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"ai-rag/dashscope"
"ai-rag/dashvector"
@@ -20,6 +21,7 @@ func main() {
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
)
}
@@ -29,6 +31,9 @@ type AIRagConfig struct {
DashVectorClient wrapper.HttpClient
DashVectorAPIKey string
DashVectorCollection string
DashVectorTopK int32
DashVectorThreshold float64
DashVectorField string
}
type Request struct {
@@ -47,29 +52,46 @@ type Message struct {
}
// parseConfig validates the plugin configuration and fills in config.
//
// It first verifies that every key listed in checkList is present in the
// supplied configuration and returns an error naming the first missing key.
// Only presence is checked; the type and value of each entry are not validated.
// On success it copies the API keys, collection name, top-k, threshold and
// field name into config, builds the DashScope and DashVector cluster clients,
// and returns nil.
//
// Note that the parameter name json shadows the encoding/json import inside
// this function; json.Get here is the gjson.Result method.
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
// Keys that must exist in the plugin configuration.
checkList := []string{
"dashscope.apiKey",
"dashscope.serviceFQDN",
"dashscope.servicePort",
"dashscope.serviceHost",
"dashvector.apiKey",
"dashvector.collection",
"dashvector.serviceFQDN",
"dashvector.servicePort",
"dashvector.serviceHost",
"dashvector.topk",
"dashvector.threshold",
"dashvector.field",
}
// Fail on the first missing key so the error message points at it.
for _, checkEntry := range checkList {
if !json.Get(checkEntry).Exists() {
return fmt.Errorf("%s not found in plugin config!", checkEntry)
}
}
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
// NOTE(review): the DnsCluster literal opened on the next line is never closed
// before the FQDNCluster assignment that follows, and it reads keys
// (dashscope.serviceName, dashscope.domain) that checkList does not require.
// This looks like pre-change lines left over from a diff — confirm and remove.
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: json.Get("dashscope.serviceName").String(),
Port: json.Get("dashscope.servicePort").Int(),
Domain: json.Get("dashscope.domain").String(),
// Client for the DashScope service, built from the FQDN, port and host entries.
config.DashScopeClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: json.Get("dashscope.serviceFQDN").String(),
Port: json.Get("dashscope.servicePort").Int(),
Host: json.Get("dashscope.serviceHost").String(),
})
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
config.DashVectorCollection = json.Get("dashvector.collection").String()
// NOTE(review): same situation as above — this DnsCluster literal is unclosed
// and appears to be the pre-change version of the assignment below; confirm.
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: json.Get("dashvector.serviceName").String(),
Port: json.Get("dashvector.servicePort").Int(),
Domain: json.Get("dashvector.domain").String(),
// Client for the DashVector service, built from the FQDN, port and host entries.
config.DashVectorClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: json.Get("dashvector.serviceFQDN").String(),
Port: json.Get("dashvector.servicePort").Int(),
Host: json.Get("dashvector.serviceHost").String(),
})
// The top-k value is converted to int32 to match the DashVectorTopK field.
config.DashVectorTopK = int32(json.Get("dashvector.topk").Int())
config.DashVectorThreshold = json.Get("dashvector.threshold").Float()
config.DashVectorField = json.Get("dashvector.field").String()
return nil
}
// onHttpRequestHeaders decides, from the request path, whether this request
// takes part in body processing.
//
// For the chat-completions path it removes the content-length request header
// (presumably because the body may be rewritten later and its length would no
// longer match — confirm). For every other path it tells the context not to
// read the request body. The action returned is always types.ActionContinue.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
	// The lookup error is deliberately ignored: a failed lookup leaves the path
	// empty, which is handled like any other non-matching path.
	requestPath, _ := proxywasm.GetHttpRequestHeader(":path")
	if requestPath == "/api/openai/v1/chat/completions" {
		proxywasm.RemoveHttpRequestHeader("content-length")
	} else {
		ctx.DontReadRequestBody()
	}
	return types.ActionContinue
}
@@ -78,9 +100,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
var rawRequest Request
_ = json.Unmarshal(body, &rawRequest)
messageLength := len(rawRequest.Messages)
if messageLength == 0 {
return types.ActionContinue
}
rawContent := rawRequest.Messages[messageLength-1].Content
requestEmbedding := dashscope.Request{
Model: "text-embedding-v1",
Model: "text-embedding-v2",
Input: dashscope.Input{
Texts: []string{rawContent},
},
@@ -90,7 +115,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
// log.Info(string(reqEmbeddingSerialized))
config.DashScopeClient.Post(
"/api/v1/services/embeddings/text-embedding/text-embedding",
headers,
@@ -99,8 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
var responseEmbedding dashscope.Response
_ = json.Unmarshal(responseBody, &responseEmbedding)
requestQuery := dashvector.Request{
TopK: 1,
OutputFileds: []string{"raw"},
TopK: config.DashVectorTopK,
OutputFileds: []string{config.DashVectorField},
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
}
requestQuerySerialized, _ := json.Marshal(requestQuery)
@@ -111,11 +135,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
var response dashvector.Response
_ = json.Unmarshal(responseBody, &response)
doc := response.Output[0].Fields.Raw
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息你可以自行选择是否使用这些参考信息现在请回答以下问题\n%s", doc, rawContent)
newBody, _ := json.Marshal(rawRequest)
// log.Info(string(newBody))
proxywasm.ReplaceHttpRequestBody(newBody)
recallDocIds := []string{}
recallDocs := []string{}
for _, output := range response.Output {
log.Debugf("Score: %f, Doc: %s", output.Score, output.Fields.Raw)
if output.Score <= float32(config.DashVectorThreshold) {
recallDocs = append(recallDocs, output.Fields.Raw)
recallDocIds = append(recallDocIds, output.ID)
}
}
if len(recallDocs) > 0 {
rawRequest.Messages = rawRequest.Messages[:messageLength-1]
traceStr := strings.Join(recallDocIds, ", ")
proxywasm.SetProperty([]string{"trace_span_tag.rag_docs"}, []byte(traceStr))
for _, doc := range recallDocs {
rawRequest.Messages = append(rawRequest.Messages, Message{"user", doc})
}
rawRequest.Messages = append(rawRequest.Messages, Message{"user", fmt.Sprintf("现在,请回答以下问题:\n%s", rawContent)})
newBody, _ := json.Marshal(rawRequest)
proxywasm.ReplaceHttpRequestBody(newBody)
ctx.SetContext("x-envoy-rag-recall", true)
}
proxywasm.ResumeHttpRequest()
},
)
@@ -124,3 +164,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
)
return types.ActionPause
}
// onHttpResponseHeaders adds the x-envoy-rag-recall response header.
//
// The header value is "true" only when the context entry stored under
// "x-envoy-rag-recall" is a bool equal to true; when the entry is missing,
// has another type, or is false, the value is "false". The action returned is
// always types.ActionContinue.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
	headerValue := "false"
	if recalled, isBool := ctx.GetContext("x-envoy-rag-recall").(bool); isBool && recalled {
		headerValue = "true"
	}
	proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", headerValue)
	return types.ActionContinue
}