mirror of
https://github.com/alibaba/higress.git
synced 2026-06-10 05:07:30 +08:00
Optimize ai-rag plugin (#1170)
This commit is contained in:
@@ -1,34 +1,42 @@
|
|||||||
# 简介
|
# 简介
|
||||||
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
|
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
|
||||||
|
|
||||||

|
<img src="https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg" width=600>
|
||||||
|
|
||||||
# 配置说明
|
# 配置说明
|
||||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
||||||
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
|
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
|
||||||
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
|
| `dashscope.serviceFQDN` | string | 必填 | - | 通义千问服务名 |
|
||||||
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
|
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
|
||||||
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
|
| `dashscope.serviceHost` | string | 必填 | - | 访问通义千问服务时域名 |
|
||||||
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
|
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
|
||||||
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
|
| `dashvector.serviceFQDN` | string | 必填 | - | 阿里云向量检索服务名 |
|
||||||
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
|
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
|
||||||
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
|
| `dashvector.serviceHost` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
|
||||||
|
| `dashvector.topk` | int | 必填 | - | 阿里云向量检索时获取向量数 |
|
||||||
|
| `dashvector.threshold` | float | 必填 | - | 向量距离阈值,高于该阈值的文档会被过滤掉 |
|
||||||
|
| `dashvector.field` | string | 必填 | - | 阿里云向量检索存储文档的字段名 |
|
||||||
|
|
||||||
|
插件开启后,在使用链路追踪功能时,会在span的attribute中添加rag检索到的文档id信息,供排查问题使用。
|
||||||
|
|
||||||
# 示例
|
# 示例
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
dashscope:
|
dashscope:
|
||||||
apiKey: xxxxxxxxxxxxxxx
|
apiKey: xxxxxxxxxxxxxxx
|
||||||
serviceName: dashscope
|
serviceFQDN: dashscope
|
||||||
servicePort: 443
|
servicePort: 443
|
||||||
domain: dashscope.aliyuncs.com
|
serviceHost: dashscope.aliyuncs.com
|
||||||
dashvector:
|
dashvector:
|
||||||
apiKey: xxxxxxxxxxxxxxxxxxxx
|
apiKey: xxxxxxxxxxxxxxxxxxxx
|
||||||
serviceName: dashvector
|
serviceFQDN: dashvector
|
||||||
servicePort: 443
|
servicePort: 443
|
||||||
domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
|
serviceHost: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
|
||||||
collection: xxxxxxxxxxxxxxx
|
collection: xxxxxxxxxxxxxxx
|
||||||
|
topk: 1
|
||||||
|
threshold: 0.4
|
||||||
|
field: raw
|
||||||
```
|
```
|
||||||
|
|
||||||
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
|
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
|
||||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
|
||||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"ai-rag/dashscope"
|
"ai-rag/dashscope"
|
||||||
"ai-rag/dashvector"
|
"ai-rag/dashvector"
|
||||||
@@ -20,6 +21,7 @@ func main() {
|
|||||||
wrapper.ParseConfigBy(parseConfig),
|
wrapper.ParseConfigBy(parseConfig),
|
||||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||||
|
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +31,9 @@ type AIRagConfig struct {
|
|||||||
DashVectorClient wrapper.HttpClient
|
DashVectorClient wrapper.HttpClient
|
||||||
DashVectorAPIKey string
|
DashVectorAPIKey string
|
||||||
DashVectorCollection string
|
DashVectorCollection string
|
||||||
|
DashVectorTopK int32
|
||||||
|
DashVectorThreshold float64
|
||||||
|
DashVectorField string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
@@ -47,29 +52,46 @@ type Message struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||||
|
checkList := []string{
|
||||||
|
"dashscope.apiKey",
|
||||||
|
"dashscope.serviceFQDN",
|
||||||
|
"dashscope.servicePort",
|
||||||
|
"dashscope.serviceHost",
|
||||||
|
"dashvector.apiKey",
|
||||||
|
"dashvector.collection",
|
||||||
|
"dashvector.serviceFQDN",
|
||||||
|
"dashvector.servicePort",
|
||||||
|
"dashvector.serviceHost",
|
||||||
|
"dashvector.topk",
|
||||||
|
"dashvector.threshold",
|
||||||
|
"dashvector.field",
|
||||||
|
}
|
||||||
|
for _, checkEntry := range checkList {
|
||||||
|
if !json.Get(checkEntry).Exists() {
|
||||||
|
return fmt.Errorf("%s not found in plugin config!", checkEntry)
|
||||||
|
}
|
||||||
|
}
|
||||||
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||||||
|
|
||||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
config.DashScopeClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
ServiceName: json.Get("dashscope.serviceName").String(),
|
FQDN: json.Get("dashscope.serviceFQDN").String(),
|
||||||
Port: json.Get("dashscope.servicePort").Int(),
|
Port: json.Get("dashscope.servicePort").Int(),
|
||||||
Domain: json.Get("dashscope.domain").String(),
|
Host: json.Get("dashscope.serviceHost").String(),
|
||||||
})
|
})
|
||||||
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||||||
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
config.DashVectorClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
ServiceName: json.Get("dashvector.serviceName").String(),
|
FQDN: json.Get("dashvector.serviceFQDN").String(),
|
||||||
Port: json.Get("dashvector.servicePort").Int(),
|
Port: json.Get("dashvector.servicePort").Int(),
|
||||||
Domain: json.Get("dashvector.domain").String(),
|
Host: json.Get("dashvector.serviceHost").String(),
|
||||||
})
|
})
|
||||||
|
config.DashVectorTopK = int32(json.Get("dashvector.topk").Int())
|
||||||
|
config.DashVectorThreshold = json.Get("dashvector.threshold").Float()
|
||||||
|
config.DashVectorField = json.Get("dashvector.field").String()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||||
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
|
||||||
if p != "/api/openai/v1/chat/completions" {
|
|
||||||
ctx.DontReadRequestBody()
|
|
||||||
return types.ActionContinue
|
|
||||||
}
|
|
||||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
@@ -78,9 +100,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
|||||||
var rawRequest Request
|
var rawRequest Request
|
||||||
_ = json.Unmarshal(body, &rawRequest)
|
_ = json.Unmarshal(body, &rawRequest)
|
||||||
messageLength := len(rawRequest.Messages)
|
messageLength := len(rawRequest.Messages)
|
||||||
|
if messageLength == 0 {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
rawContent := rawRequest.Messages[messageLength-1].Content
|
rawContent := rawRequest.Messages[messageLength-1].Content
|
||||||
requestEmbedding := dashscope.Request{
|
requestEmbedding := dashscope.Request{
|
||||||
Model: "text-embedding-v1",
|
Model: "text-embedding-v2",
|
||||||
Input: dashscope.Input{
|
Input: dashscope.Input{
|
||||||
Texts: []string{rawContent},
|
Texts: []string{rawContent},
|
||||||
},
|
},
|
||||||
@@ -90,7 +115,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
|||||||
}
|
}
|
||||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||||||
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||||||
// log.Info(string(reqEmbeddingSerialized))
|
|
||||||
config.DashScopeClient.Post(
|
config.DashScopeClient.Post(
|
||||||
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||||
headers,
|
headers,
|
||||||
@@ -99,8 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
|||||||
var responseEmbedding dashscope.Response
|
var responseEmbedding dashscope.Response
|
||||||
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||||||
requestQuery := dashvector.Request{
|
requestQuery := dashvector.Request{
|
||||||
TopK: 1,
|
TopK: config.DashVectorTopK,
|
||||||
OutputFileds: []string{"raw"},
|
OutputFileds: []string{config.DashVectorField},
|
||||||
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||||||
}
|
}
|
||||||
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||||||
@@ -111,11 +135,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
|||||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
var response dashvector.Response
|
var response dashvector.Response
|
||||||
_ = json.Unmarshal(responseBody, &response)
|
_ = json.Unmarshal(responseBody, &response)
|
||||||
doc := response.Output[0].Fields.Raw
|
recallDocIds := []string{}
|
||||||
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
recallDocs := []string{}
|
||||||
newBody, _ := json.Marshal(rawRequest)
|
for _, output := range response.Output {
|
||||||
// log.Info(string(newBody))
|
log.Debugf("Score: %f, Doc: %s", output.Score, output.Fields.Raw)
|
||||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
if output.Score <= float32(config.DashVectorThreshold) {
|
||||||
|
recallDocs = append(recallDocs, output.Fields.Raw)
|
||||||
|
recallDocIds = append(recallDocIds, output.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(recallDocs) > 0 {
|
||||||
|
rawRequest.Messages = rawRequest.Messages[:messageLength-1]
|
||||||
|
traceStr := strings.Join(recallDocIds, ", ")
|
||||||
|
proxywasm.SetProperty([]string{"trace_span_tag.rag_docs"}, []byte(traceStr))
|
||||||
|
for _, doc := range recallDocs {
|
||||||
|
rawRequest.Messages = append(rawRequest.Messages, Message{"user", doc})
|
||||||
|
}
|
||||||
|
rawRequest.Messages = append(rawRequest.Messages, Message{"user", fmt.Sprintf("现在,请回答以下问题:\n%s", rawContent)})
|
||||||
|
newBody, _ := json.Marshal(rawRequest)
|
||||||
|
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||||
|
ctx.SetContext("x-envoy-rag-recall", true)
|
||||||
|
}
|
||||||
proxywasm.ResumeHttpRequest()
|
proxywasm.ResumeHttpRequest()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -124,3 +164,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
|||||||
)
|
)
|
||||||
return types.ActionPause
|
return types.ActionPause
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||||
|
recall, ok := ctx.GetContext("x-envoy-rag-recall").(bool)
|
||||||
|
if ok && recall {
|
||||||
|
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "true")
|
||||||
|
} else {
|
||||||
|
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "false")
|
||||||
|
}
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user