mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
Optimize ai-rag plugin (#1170)
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ai-rag/dashscope"
|
||||
"ai-rag/dashvector"
|
||||
@@ -20,6 +21,7 @@ func main() {
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,6 +31,9 @@ type AIRagConfig struct {
|
||||
DashVectorClient wrapper.HttpClient
|
||||
DashVectorAPIKey string
|
||||
DashVectorCollection string
|
||||
DashVectorTopK int32
|
||||
DashVectorThreshold float64
|
||||
DashVectorField string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
@@ -47,29 +52,46 @@ type Message struct {
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||
checkList := []string{
|
||||
"dashscope.apiKey",
|
||||
"dashscope.serviceFQDN",
|
||||
"dashscope.servicePort",
|
||||
"dashscope.serviceHost",
|
||||
"dashvector.apiKey",
|
||||
"dashvector.collection",
|
||||
"dashvector.serviceFQDN",
|
||||
"dashvector.servicePort",
|
||||
"dashvector.serviceHost",
|
||||
"dashvector.topk",
|
||||
"dashvector.threshold",
|
||||
"dashvector.field",
|
||||
}
|
||||
for _, checkEntry := range checkList {
|
||||
if !json.Get(checkEntry).Exists() {
|
||||
return fmt.Errorf("%s not found in plugin config!", checkEntry)
|
||||
}
|
||||
}
|
||||
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||||
|
||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashscope.serviceName").String(),
|
||||
Port: json.Get("dashscope.servicePort").Int(),
|
||||
Domain: json.Get("dashscope.domain").String(),
|
||||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: json.Get("dashscope.serviceFQDN").String(),
|
||||
Port: json.Get("dashscope.servicePort").Int(),
|
||||
Host: json.Get("dashscope.serviceHost").String(),
|
||||
})
|
||||
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||||
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: json.Get("dashvector.serviceName").String(),
|
||||
Port: json.Get("dashvector.servicePort").Int(),
|
||||
Domain: json.Get("dashvector.domain").String(),
|
||||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: json.Get("dashvector.serviceFQDN").String(),
|
||||
Port: json.Get("dashvector.servicePort").Int(),
|
||||
Host: json.Get("dashvector.serviceHost").String(),
|
||||
})
|
||||
config.DashVectorTopK = int32(json.Get("dashvector.topk").Int())
|
||||
config.DashVectorThreshold = json.Get("dashvector.threshold").Float()
|
||||
config.DashVectorField = json.Get("dashvector.field").String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
if p != "/api/openai/v1/chat/completions" {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -78,9 +100,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
var rawRequest Request
|
||||
_ = json.Unmarshal(body, &rawRequest)
|
||||
messageLength := len(rawRequest.Messages)
|
||||
if messageLength == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
rawContent := rawRequest.Messages[messageLength-1].Content
|
||||
requestEmbedding := dashscope.Request{
|
||||
Model: "text-embedding-v1",
|
||||
Model: "text-embedding-v2",
|
||||
Input: dashscope.Input{
|
||||
Texts: []string{rawContent},
|
||||
},
|
||||
@@ -90,7 +115,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
}
|
||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||||
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||||
// log.Info(string(reqEmbeddingSerialized))
|
||||
config.DashScopeClient.Post(
|
||||
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
headers,
|
||||
@@ -99,8 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
var responseEmbedding dashscope.Response
|
||||
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||||
requestQuery := dashvector.Request{
|
||||
TopK: 1,
|
||||
OutputFileds: []string{"raw"},
|
||||
TopK: config.DashVectorTopK,
|
||||
OutputFileds: []string{config.DashVectorField},
|
||||
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||||
}
|
||||
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||||
@@ -111,11 +135,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var response dashvector.Response
|
||||
_ = json.Unmarshal(responseBody, &response)
|
||||
doc := response.Output[0].Fields.Raw
|
||||
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
||||
newBody, _ := json.Marshal(rawRequest)
|
||||
// log.Info(string(newBody))
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
recallDocIds := []string{}
|
||||
recallDocs := []string{}
|
||||
for _, output := range response.Output {
|
||||
log.Debugf("Score: %f, Doc: %s", output.Score, output.Fields.Raw)
|
||||
if output.Score <= float32(config.DashVectorThreshold) {
|
||||
recallDocs = append(recallDocs, output.Fields.Raw)
|
||||
recallDocIds = append(recallDocIds, output.ID)
|
||||
}
|
||||
}
|
||||
if len(recallDocs) > 0 {
|
||||
rawRequest.Messages = rawRequest.Messages[:messageLength-1]
|
||||
traceStr := strings.Join(recallDocIds, ", ")
|
||||
proxywasm.SetProperty([]string{"trace_span_tag.rag_docs"}, []byte(traceStr))
|
||||
for _, doc := range recallDocs {
|
||||
rawRequest.Messages = append(rawRequest.Messages, Message{"user", doc})
|
||||
}
|
||||
rawRequest.Messages = append(rawRequest.Messages, Message{"user", fmt.Sprintf("现在,请回答以下问题:\n%s", rawContent)})
|
||||
newBody, _ := json.Marshal(rawRequest)
|
||||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||||
ctx.SetContext("x-envoy-rag-recall", true)
|
||||
}
|
||||
proxywasm.ResumeHttpRequest()
|
||||
},
|
||||
)
|
||||
@@ -124,3 +164,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
)
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
recall, ok := ctx.GetContext("x-envoy-rag-recall").(bool)
|
||||
if ok && recall {
|
||||
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "true")
|
||||
} else {
|
||||
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "false")
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user