mirror of
https://github.com/alibaba/higress.git
synced 2026-03-09 19:20:51 +08:00
Optimize ai-rag plugin (#1170)
This commit is contained in:
@@ -1,34 +1,42 @@
|
||||
# 简介
|
||||
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
|
||||
|
||||

|
||||
<img src="https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg" width=600>
|
||||
|
||||
# 配置说明
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
||||
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
|
||||
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
|
||||
| `dashscope.serviceFQDN` | string | 必填 | - | 通义千问服务名 |
|
||||
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
|
||||
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
|
||||
| `dashscope.serviceHost` | string | 必填 | - | 访问通义千问服务时域名 |
|
||||
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
|
||||
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
|
||||
| `dashvector.serviceFQDN` | string | 必填 | - | 阿里云向量检索服务名 |
|
||||
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
|
||||
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
|
||||
| `dashvector.serviceHost` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
|
||||
| `dashvector.topk` | int | 必填 | - | 阿里云向量检索时获取向量数 |
|
||||
| `dashvector.threshold` | float | 必填 | - | 向量距离阈值,高于该阈值的文档会被过滤掉 |
|
||||
| `dashvector.field` | string | 必填 | - | 阿里云向量检索存储文档的字段名 |
|
||||
|
||||
插件开启后,在使用链路追踪功能时,会在span的attribute中添加rag检索到的文档id信息,供排查问题使用。
|
||||
|
||||
# 示例
|
||||
|
||||
```yaml
|
||||
dashscope:
|
||||
apiKey: xxxxxxxxxxxxxxx
|
||||
serviceName: dashscope
|
||||
serviceFQDN: dashscope
|
||||
servicePort: 443
|
||||
domain: dashscope.aliyuncs.com
|
||||
serviceHost: dashscope.aliyuncs.com
|
||||
dashvector:
|
||||
apiKey: xxxxxxxxxxxxxxxxxxxx
|
||||
serviceName: dashvector
|
||||
serviceFQDN: dashvector
|
||||
servicePort: 443
|
||||
domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
|
||||
serviceHost: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
|
||||
collection: xxxxxxxxxxxxxxx
|
||||
topk: 1
|
||||
threshold: 0.4
|
||||
field: raw
|
||||
```
|
||||
|
||||
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ai-rag/dashscope"
|
||||
"ai-rag/dashvector"
|
||||
@@ -20,6 +21,7 @@ func main() {
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,6 +31,9 @@ type AIRagConfig struct {
|
||||
DashVectorClient wrapper.HttpClient
|
||||
DashVectorAPIKey string
|
||||
DashVectorCollection string
|
||||
DashVectorTopK int32
|
||||
DashVectorThreshold float64
|
||||
DashVectorField string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
@@ -47,29 +52,46 @@ type Message struct {
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||
checkList := []string{
|
||||
"dashscope.apiKey",
|
||||
"dashscope.serviceFQDN",
|
||||
"dashscope.servicePort",
|
||||
"dashscope.serviceHost",
|
||||
"dashvector.apiKey",
|
||||
"dashvector.collection",
|
||||
"dashvector.serviceFQDN",
|
||||
"dashvector.servicePort",
|
||||
"dashvector.serviceHost",
|
||||
"dashvector.topk",
|
||||
"dashvector.threshold",
|
||||
"dashvector.field",
|
||||
}
|
||||
for _, checkEntry := range checkList {
|
||||
if !json.Get(checkEntry).Exists() {
|
||||
return fmt.Errorf("%s not found in plugin config!", checkEntry)
|
||||
}
|
||||
}
|
||||
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||||
|
||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashscope.serviceName").String(),
|
||||
Port: json.Get("dashscope.servicePort").Int(),
|
||||
Domain: json.Get("dashscope.domain").String(),
|
||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: json.Get("dashscope.serviceFQDN").String(),
|
||||
Port: json.Get("dashscope.servicePort").Int(),
|
||||
Host: json.Get("dashscope.serviceHost").String(),
|
||||
})
|
||||
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||||
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashvector.serviceName").String(),
|
||||
Port: json.Get("dashvector.servicePort").Int(),
|
||||
Domain: json.Get("dashvector.domain").String(),
|
||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: json.Get("dashvector.serviceFQDN").String(),
|
||||
Port: json.Get("dashvector.servicePort").Int(),
|
||||
Host: json.Get("dashvector.serviceHost").String(),
|
||||
})
|
||||
config.DashVectorTopK = int32(json.Get("dashvector.topk").Int())
|
||||
config.DashVectorThreshold = json.Get("dashvector.threshold").Float()
|
||||
config.DashVectorField = json.Get("dashvector.field").String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
if p != "/api/openai/v1/chat/completions" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -78,9 +100,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
var rawRequest Request
|
||||
_ = json.Unmarshal(body, &rawRequest)
|
||||
messageLength := len(rawRequest.Messages)
|
||||
if messageLength == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
rawContent := rawRequest.Messages[messageLength-1].Content
|
||||
requestEmbedding := dashscope.Request{
|
||||
Model: "text-embedding-v1",
|
||||
Model: "text-embedding-v2",
|
||||
Input: dashscope.Input{
|
||||
Texts: []string{rawContent},
|
||||
},
|
||||
@@ -90,7 +115,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
}
|
||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||||
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||||
// log.Info(string(reqEmbeddingSerialized))
|
||||
config.DashScopeClient.Post(
|
||||
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
headers,
|
||||
@@ -99,8 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
var responseEmbedding dashscope.Response
|
||||
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||||
requestQuery := dashvector.Request{
|
||||
TopK: 1,
|
||||
OutputFileds: []string{"raw"},
|
||||
TopK: config.DashVectorTopK,
|
||||
OutputFileds: []string{config.DashVectorField},
|
||||
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||||
}
|
||||
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||||
@@ -111,11 +135,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var response dashvector.Response
|
||||
_ = json.Unmarshal(responseBody, &response)
|
||||
doc := response.Output[0].Fields.Raw
|
||||
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
||||
newBody, _ := json.Marshal(rawRequest)
|
||||
// log.Info(string(newBody))
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
recallDocIds := []string{}
|
||||
recallDocs := []string{}
|
||||
for _, output := range response.Output {
|
||||
log.Debugf("Score: %f, Doc: %s", output.Score, output.Fields.Raw)
|
||||
if output.Score <= float32(config.DashVectorThreshold) {
|
||||
recallDocs = append(recallDocs, output.Fields.Raw)
|
||||
recallDocIds = append(recallDocIds, output.ID)
|
||||
}
|
||||
}
|
||||
if len(recallDocs) > 0 {
|
||||
rawRequest.Messages = rawRequest.Messages[:messageLength-1]
|
||||
traceStr := strings.Join(recallDocIds, ", ")
|
||||
proxywasm.SetProperty([]string{"trace_span_tag.rag_docs"}, []byte(traceStr))
|
||||
for _, doc := range recallDocs {
|
||||
rawRequest.Messages = append(rawRequest.Messages, Message{"user", doc})
|
||||
}
|
||||
rawRequest.Messages = append(rawRequest.Messages, Message{"user", fmt.Sprintf("现在,请回答以下问题:\n%s", rawContent)})
|
||||
newBody, _ := json.Marshal(rawRequest)
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
ctx.SetContext("x-envoy-rag-recall", true)
|
||||
}
|
||||
proxywasm.ResumeHttpRequest()
|
||||
},
|
||||
)
|
||||
@@ -124,3 +164,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
)
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
recall, ok := ctx.GetContext("x-envoy-rag-recall").(bool)
|
||||
if ok && recall {
|
||||
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "true")
|
||||
} else {
|
||||
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "false")
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user