mirror of
https://github.com/alibaba/higress.git
synced 2026-03-06 17:40:51 +08:00
add plugin: ai-rag (#1038)
This commit is contained in:
3
plugins/wasm-go/extensions/ai-rag/.gitignore
vendored
Normal file
3
plugins/wasm-go/extensions/ai-rag/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
config.yaml
|
||||
main.wasm
|
||||
tmp/
|
||||
49
plugins/wasm-go/extensions/ai-rag/README.md
Normal file
49
plugins/wasm-go/extensions/ai-rag/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# 简介
|
||||
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
|
||||
|
||||

|
||||
|
||||
# 配置说明
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
||||
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
|
||||
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
|
||||
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
|
||||
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
|
||||
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
|
||||
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
|
||||
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
|
||||
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
|
||||
|
||||
# 示例
|
||||
|
||||
```yaml
|
||||
dashscope:
|
||||
apiKey: xxxxxxxxxxxxxxx
|
||||
serviceName: dashscope
|
||||
servicePort: 443
|
||||
domain: dashscope.aliyuncs.com
|
||||
dashvector:
|
||||
apiKey: xxxxxxxxxxxxxxxxxxxx
|
||||
serviceName: dashvector
|
||||
servicePort: 443
|
||||
domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
|
||||
collection: xxxxxxxxxxxxxxx
|
||||
```
|
||||
|
||||
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
|
||||
|
||||
以下为使用RAG进行增强的例子,原始请求为:
|
||||
```
|
||||
海南追尾事故,发生在哪里?原因是什么?人员伤亡情况如何?
|
||||
```
|
||||
|
||||
未经过RAG插件处理LLM返回的结果为:
|
||||
```
|
||||
抱歉,作为AI模型,我无法实时获取和更新新闻事件的具体信息,包括地点、原因、人员伤亡等细节。对于此类具体事件,建议您查阅最新的新闻报道或官方通报以获取准确信息。您可以访问主流媒体网站、使用新闻应用或者关注相关政府部门的公告来获取这类动态资讯。
|
||||
```
|
||||
|
||||
经过RAG插件处理后LLM返回的结果为:
|
||||
```
|
||||
海南追尾事故发生在海文高速公路文昌至海口方向37公里处。关于事故的具体原因,交警部门当时仍在进一步调查中,所以根据提供的信息无法确定事故的确切原因。人员伤亡情况是1人死亡(司机当场死亡),另有8人受伤(包括2名儿童和6名成人),所有受伤人员都被解救并送往医院进行治疗。
|
||||
```
|
||||
36
plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Normal file
36
plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package dashscope
|
||||
|
||||
// DashScope embedding service: Request
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Input Input `json:"input"`
|
||||
Parameter Parameter `json:"parameters"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
TextType string `json:"text_type"`
|
||||
}
|
||||
|
||||
// DashScope embedding service: Response
|
||||
type Response struct {
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
TextIndex int32 `json:"text_index"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
TotalTokens int32 `json:"total_tokens"`
|
||||
}
|
||||
26
plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Normal file
26
plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package dashvector
|
||||
|
||||
// DashVecotor document search: Request
|
||||
type Request struct {
|
||||
TopK int32 `json:"topk"`
|
||||
OutputFileds []string `json:"output_fileds"`
|
||||
Vector []float32 `json:"vector"`
|
||||
}
|
||||
|
||||
// DashVecotor document search: Response
|
||||
type Response struct {
|
||||
Code int32 `json:"code"`
|
||||
RequestID string `json:"request_id"`
|
||||
Message string `json:"message"`
|
||||
Output []OutputObject `json:"output"`
|
||||
}
|
||||
|
||||
type OutputObject struct {
|
||||
ID string `json:"id"`
|
||||
Fields FieldObject `json:"fields"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
type FieldObject struct {
|
||||
Raw string `json:"raw"`
|
||||
}
|
||||
19
plugins/wasm-go/extensions/ai-rag/go.mod
Normal file
19
plugins/wasm-go/extensions/ai-rag/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module ai-rag
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
|
||||
github.com/tidwall/gjson v1.14.3
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
25
plugins/wasm-go/extensions/ai-rag/go.sum
Normal file
25
plugins/wasm-go/extensions/ai-rag/go.sum
Normal file
@@ -0,0 +1,25 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
126
plugins/wasm-go/extensions/ai-rag/main.go
Normal file
126
plugins/wasm-go/extensions/ai-rag/main.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"myplugin/dashscope"
|
||||
"myplugin/dashvector"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-rag",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
type AIRagConfig struct {
|
||||
DashScopeClient wrapper.HttpClient
|
||||
DashScopeAPIKey string
|
||||
DashVectorClient wrapper.HttpClient
|
||||
DashVectorAPIKey string
|
||||
DashVectorCollection string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Topp int32 `json:"top_p"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||||
|
||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashscope.serviceName").String(),
|
||||
Port: json.Get("dashscope.servicePort").Int(),
|
||||
Domain: json.Get("dashscope.domain").String(),
|
||||
})
|
||||
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||||
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashvector.serviceName").String(),
|
||||
Port: json.Get("dashvector.servicePort").Int(),
|
||||
Domain: json.Get("dashvector.domain").String(),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
if p != "/api/openai/v1/chat/completions" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
|
||||
var rawRequest Request
|
||||
_ = json.Unmarshal(body, &rawRequest)
|
||||
messageLength := len(rawRequest.Messages)
|
||||
rawContent := rawRequest.Messages[messageLength-1].Content
|
||||
requestEmbedding := dashscope.Request{
|
||||
Model: "text-embedding-v1",
|
||||
Input: dashscope.Input{
|
||||
Texts: []string{rawContent},
|
||||
},
|
||||
Parameter: dashscope.Parameter{
|
||||
TextType: "query",
|
||||
},
|
||||
}
|
||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||||
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||||
// log.Info(string(reqEmbeddingSerialized))
|
||||
config.DashScopeClient.Post(
|
||||
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
headers,
|
||||
reqEmbeddingSerialized,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var responseEmbedding dashscope.Response
|
||||
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||||
requestQuery := dashvector.Request{
|
||||
TopK: 1,
|
||||
OutputFileds: []string{"raw"},
|
||||
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||||
}
|
||||
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||||
config.DashVectorClient.Post(
|
||||
fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
|
||||
[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
|
||||
requestQuerySerialized,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var response dashvector.Response
|
||||
_ = json.Unmarshal(responseBody, &response)
|
||||
doc := response.Output[0].Fields.Raw
|
||||
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
||||
newBody, _ := json.Marshal(rawRequest)
|
||||
// log.Info(string(newBody))
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
},
|
||||
)
|
||||
},
|
||||
50000,
|
||||
)
|
||||
return types.ActionPause
|
||||
}
|
||||
Reference in New Issue
Block a user