mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
add plugin: ai-rag (#1038)
This commit is contained in:
3
plugins/wasm-go/extensions/ai-rag/.gitignore
vendored
Normal file
3
plugins/wasm-go/extensions/ai-rag/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
config.yaml
|
||||||
|
main.wasm
|
||||||
|
tmp/
|
||||||
49
plugins/wasm-go/extensions/ai-rag/README.md
Normal file
49
plugins/wasm-go/extensions/ai-rag/README.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# 简介
|
||||||
|
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
# 配置说明
|
||||||
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
|
||||||
|
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
|
||||||
|
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
|
||||||
|
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
|
||||||
|
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
|
||||||
|
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
|
||||||
|
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
|
||||||
|
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
|
||||||
|
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时使用的域名 |
| `dashvector.collection` | string | 必填 | - | 阿里云向量检索服务中用于检索的 Collection 名称 |
|
||||||
|
|
||||||
|
# 示例
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
dashscope:
  apiKey: xxxxxxxxxxxxxxx
  serviceName: dashscope
  servicePort: 443
  domain: dashscope.aliyuncs.com
dashvector:
  apiKey: xxxxxxxxxxxxxxxxxxxx
  serviceName: dashvector
  servicePort: 443
  domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
  collection: xxxxxxxxxxxxxxx
|
||||||
|
```
|
||||||
|
|
||||||
|
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。
|
||||||
|
|
||||||
|
以下为使用RAG进行增强的例子,原始请求为:
|
||||||
|
```
|
||||||
|
海南追尾事故,发生在哪里?原因是什么?人员伤亡情况如何?
|
||||||
|
```
|
||||||
|
|
||||||
|
未经过RAG插件处理LLM返回的结果为:
|
||||||
|
```
|
||||||
|
抱歉,作为AI模型,我无法实时获取和更新新闻事件的具体信息,包括地点、原因、人员伤亡等细节。对于此类具体事件,建议您查阅最新的新闻报道或官方通报以获取准确信息。您可以访问主流媒体网站、使用新闻应用或者关注相关政府部门的公告来获取这类动态资讯。
|
||||||
|
```
|
||||||
|
|
||||||
|
经过RAG插件处理后LLM返回的结果为:
|
||||||
|
```
|
||||||
|
海南追尾事故发生在海文高速公路文昌至海口方向37公里处。关于事故的具体原因,交警部门当时仍在进一步调查中,所以根据提供的信息无法确定事故的确切原因。人员伤亡情况是1人死亡(司机当场死亡),另有8人受伤(包括2名儿童和6名成人),所有受伤人员都被解救并送往医院进行治疗。
|
||||||
|
```
|
||||||
36
plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Normal file
36
plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package dashscope
|
||||||
|
|
||||||
|
// DashScope embedding service: Request
|
||||||
|
type Request struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input Input `json:"input"`
|
||||||
|
Parameter Parameter `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Input is the "input" object of an embedding Request.
type Input struct {
	// Texts holds the strings to embed, encoded under the "texts" key.
	Texts []string `json:"texts"`
}
|
||||||
|
|
||||||
|
// Parameter is the "parameters" object of an embedding Request.
type Parameter struct {
	// TextType is encoded under the "text_type" key.
	TextType string `json:"text_type"`
}
|
||||||
|
|
||||||
|
// DashScope embedding service: Response
|
||||||
|
type Response struct {
|
||||||
|
Output Output `json:"output"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Output struct {
|
||||||
|
Embeddings []Embedding `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding is one element of Output.Embeddings.
type Embedding struct {
	// Embedding is the vector itself, decoded from the "embedding" key.
	Embedding []float32 `json:"embedding"`
	// TextIndex is decoded from "text_index"; presumably the position of the
	// source text in Input.Texts (confirm against the DashScope API docs).
	TextIndex int32 `json:"text_index"`
}
|
||||||
|
|
||||||
|
// Usage is the "usage" object of an embedding Response.
type Usage struct {
	// TotalTokens is decoded from the "total_tokens" key.
	TotalTokens int32 `json:"total_tokens"`
}
|
||||||
26
plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Normal file
26
plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package dashvector
|
||||||
|
|
||||||
|
// Request is the JSON body of a DashVector document search.
type Request struct {
	// TopK is encoded under the "topk" key.
	TopK int32 `json:"topk"`
	// OutputFileds lists the document fields to return with each hit. The Go
	// identifier keeps its historical spelling so existing callers still
	// compile; the JSON key is spelled "output_fields", which is the name the
	// DashVector query API defines. The previous "output_fileds" key was not a
	// recognised parameter.
	OutputFileds []string `json:"output_fields"`
	// Vector is the query vector, encoded under the "vector" key.
	Vector []float32 `json:"vector"`
}
|
||||||
|
|
||||||
|
// DashVecotor document search: Response
|
||||||
|
type Response struct {
|
||||||
|
Code int32 `json:"code"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Output []OutputObject `json:"output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutputObject struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Fields FieldObject `json:"fields"`
|
||||||
|
Score float32 `json:"score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FieldObject is the "fields" object of an OutputObject.
type FieldObject struct {
	// Raw is decoded from the "raw" key.
	Raw string `json:"raw"`
}
|
||||||
19
plugins/wasm-go/extensions/ai-rag/go.mod
Normal file
19
plugins/wasm-go/extensions/ai-rag/go.mod
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
module ai-rag
|
||||||
|
|
||||||
|
go 1.18
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.5
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
|
||||||
|
github.com/tidwall/gjson v1.14.3
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||||
|
github.com/magefile/mage v1.14.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
github.com/tidwall/resp v0.1.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5
|
||||||
|
)
|
||||||
25
plugins/wasm-go/extensions/ai-rag/go.sum
Normal file
25
plugins/wasm-go/extensions/ai-rag/go.sum
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
|
||||||
|
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||||
|
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||||
|
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||||
|
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||||
|
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||||
|
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||||
|
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
126
plugins/wasm-go/extensions/ai-rag/main.go
Normal file
126
plugins/wasm-go/extensions/ai-rag/main.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"myplugin/dashscope"
|
||||||
|
"myplugin/dashvector"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
wrapper.SetCtx(
|
||||||
|
"ai-rag",
|
||||||
|
wrapper.ParseConfigBy(parseConfig),
|
||||||
|
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||||
|
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIRagConfig struct {
|
||||||
|
DashScopeClient wrapper.HttpClient
|
||||||
|
DashScopeAPIKey string
|
||||||
|
DashVectorClient wrapper.HttpClient
|
||||||
|
DashVectorAPIKey string
|
||||||
|
DashVectorCollection string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request is the subset of a chat-completion request body that this plugin
// decodes, modifies and re-encodes. Keys not listed here are dropped when the
// body is re-encoded.
//
// NOTE(review): fields without omitempty are always emitted on re-encode, so a
// key the client omitted (e.g. "temperature") is sent upstream as its zero
// value. Confirm that is acceptable for the target model.
type Request struct {
	Model            string    `json:"model"`
	Messages         []Message `json:"messages"`
	FrequencyPenalty float64   `json:"frequency_penalty"`
	PresencePenalty  float64   `json:"presence_penalty"`
	Stream           bool      `json:"stream"`
	Temperature      float64   `json:"temperature"`
	// Topp is fractional, so it must be a float: with an integer type,
	// decoding a value such as 0.9 fails and the re-encoded body carries 0.
	// omitempty keeps an absent top_p absent instead of forcing it to 0.
	Topp float64 `json:"top_p,omitempty"`
}

// Message is one entry of Request.Messages.
type Message struct {
	// Role is decoded from the "role" key.
	Role string `json:"role"`
	// Content is the message text, decoded from the "content" key.
	Content string `json:"content"`
}
|
||||||
|
|
||||||
|
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||||
|
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||||||
|
|
||||||
|
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||||
|
ServiceName: json.Get("dashscope.serviceName").String(),
|
||||||
|
Port: json.Get("dashscope.servicePort").Int(),
|
||||||
|
Domain: json.Get("dashscope.domain").String(),
|
||||||
|
})
|
||||||
|
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||||||
|
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||||||
|
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||||
|
ServiceName: json.Get("dashvector.serviceName").String(),
|
||||||
|
Port: json.Get("dashvector.servicePort").Int(),
|
||||||
|
Domain: json.Get("dashvector.domain").String(),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||||
|
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||||
|
if p != "/api/openai/v1/chat/completions" {
|
||||||
|
ctx.DontReadRequestBody()
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
|
||||||
|
var rawRequest Request
|
||||||
|
_ = json.Unmarshal(body, &rawRequest)
|
||||||
|
messageLength := len(rawRequest.Messages)
|
||||||
|
rawContent := rawRequest.Messages[messageLength-1].Content
|
||||||
|
requestEmbedding := dashscope.Request{
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Input: dashscope.Input{
|
||||||
|
Texts: []string{rawContent},
|
||||||
|
},
|
||||||
|
Parameter: dashscope.Parameter{
|
||||||
|
TextType: "query",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||||||
|
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||||||
|
// log.Info(string(reqEmbeddingSerialized))
|
||||||
|
config.DashScopeClient.Post(
|
||||||
|
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||||
|
headers,
|
||||||
|
reqEmbeddingSerialized,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
var responseEmbedding dashscope.Response
|
||||||
|
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||||||
|
requestQuery := dashvector.Request{
|
||||||
|
TopK: 1,
|
||||||
|
OutputFileds: []string{"raw"},
|
||||||
|
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||||||
|
}
|
||||||
|
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||||||
|
config.DashVectorClient.Post(
|
||||||
|
fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
|
||||||
|
[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
|
||||||
|
requestQuerySerialized,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
var response dashvector.Response
|
||||||
|
_ = json.Unmarshal(responseBody, &response)
|
||||||
|
doc := response.Output[0].Fields.Raw
|
||||||
|
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
||||||
|
newBody, _ := json.Marshal(rawRequest)
|
||||||
|
// log.Info(string(newBody))
|
||||||
|
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||||
|
proxywasm.ResumeHttpRequest()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
50000,
|
||||||
|
)
|
||||||
|
return types.ActionPause
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user