Files
higress/plugins/wasm-go/extensions/ai-rag/main.go
2024-06-19 13:46:17 +08:00

127 lines
4.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"encoding/json"
"fmt"
"net/http"
"ai-rag/dashscope"
"ai-rag/dashvector"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// main registers the plugin with the Higress wasm-go wrapper, wiring up the
// configuration parser and the request header/body handlers defined in this
// file.
func main() {
	const pluginName = "ai-rag"

	wrapper.SetCtx(
		pluginName,
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
	)
}
// AIRagConfig holds the per-plugin settings produced by parseConfig: one HTTP
// client plus credentials for each of the two backing services.
type AIRagConfig struct {
	// DashScopeClient is the client used to call the DashScope text-embedding
	// endpoint.
	DashScopeClient wrapper.HttpClient
	// DashScopeAPIKey is sent to DashScope as a Bearer token in the
	// Authorization header.
	DashScopeAPIKey string

	// DashVectorClient is the client used to query the DashVector collection.
	DashVectorClient wrapper.HttpClient
	// DashVectorAPIKey is sent to DashVector in the dashvector-auth-token
	// header.
	DashVectorAPIKey string
	// DashVectorCollection is the collection name interpolated into the
	// DashVector query path.
	DashVectorCollection string
}
// Request models the subset of a chat-completion request body that this
// plugin decodes and re-encodes.
//
// NOTE(review): none of the fields use omitempty and unknown JSON keys are not
// captured, so a decode/encode round trip emits zero values for absent fields
// and drops any key that is not listed here.
type Request struct {
	Model            string    `json:"model"`
	Messages         []Message `json:"messages"`
	FrequencyPenalty float64   `json:"frequency_penalty"`
	PresencePenalty  float64   `json:"presence_penalty"`
	Stream           bool      `json:"stream"`
	Temperature      float64   `json:"temperature"`
	// Topp is a float64 because top_p is a fractional value (e.g. 0.9);
	// decoding such a value into an integer field fails with a
	// json.UnmarshalTypeError and leaves the field at zero.
	Topp float64 `json:"top_p"`
}

// Message is a single chat message: the speaker's role and the text content.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// parseConfig fills config from the plugin's JSON configuration.
//
// It reads the "dashscope" and "dashvector" sections (apiKey, serviceName,
// servicePort, domain, plus dashvector.collection) and builds one cluster
// client per service. It returns an error when a field that every request
// depends on is missing, so that a broken configuration is reported when it is
// loaded instead of on each request. The domain fields are passed through
// without validation.
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
	// These values are used verbatim in headers, paths and cluster lookups;
	// gjson yields "" for an absent key.
	requiredStrings := []string{
		"dashscope.apiKey",
		"dashscope.serviceName",
		"dashvector.apiKey",
		"dashvector.serviceName",
		"dashvector.collection",
	}
	for _, key := range requiredStrings {
		if json.Get(key).String() == "" {
			return fmt.Errorf("missing required config field %q", key)
		}
	}
	// gjson yields 0 for an absent or non-numeric key.
	for _, key := range []string{"dashscope.servicePort", "dashvector.servicePort"} {
		if port := json.Get(key).Int(); port <= 0 || port > 65535 {
			return fmt.Errorf("config field %q must be a port in 1-65535, got %d", key, port)
		}
	}

	config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
	config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
		ServiceName: json.Get("dashscope.serviceName").String(),
		Port:        json.Get("dashscope.servicePort").Int(),
		Domain:      json.Get("dashscope.domain").String(),
	})

	config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
	config.DashVectorCollection = json.Get("dashvector.collection").String()
	config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
		ServiceName: json.Get("dashvector.serviceName").String(),
		Port:        json.Get("dashvector.servicePort").Int(),
		Domain:      json.Get("dashvector.domain").String(),
	})
	return nil
}
// onHttpRequestHeaders decides, from the request path alone, whether the body
// handler should run for this request.
//
// For the chat-completions path the content-length header is removed
// (presumably because the body handler may replace the body with one of a
// different length). For every other path body reading is disabled. The
// request always continues.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
	const chatCompletionsPath = "/api/openai/v1/chat/completions"

	// The lookup error is deliberately ignored: the returned value is simply
	// compared against the expected path.
	requestPath, _ := proxywasm.GetHttpRequestHeader(":path")
	if requestPath == chatCompletionsPath {
		_ = proxywasm.RemoveHttpRequestHeader("content-length")
	} else {
		ctx.DontReadRequestBody()
	}
	return types.ActionContinue
}
// onHttpRequestBody prepends a retrieved reference document to the last chat
// message before the request is forwarded.
//
// Steps:
//  1. Decode the body and take the content of the last message.
//  2. Ask DashScope for an embedding of that content.
//  3. Query DashVector (TopK 1) with the embedding and read the "raw" field of
//     the first hit.
//  4. Rewrite the last message to contain the document followed by the
//     original content, replace the request body, and resume the request.
//
// The function returns ActionPause only when the embedding call was actually
// dispatched. On every failure (undecodable body, no messages, dispatch error,
// non-200 status, empty result) the problem is logged and the original
// request is let through unmodified, so a retrieval failure can neither panic
// the plugin nor leave the request paused forever.
//
// NOTE(review): the new body is produced by re-marshaling Request, so JSON
// keys that Request does not declare are not forwarded.
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
	var rawRequest Request
	if err := json.Unmarshal(body, &rawRequest); err != nil {
		log.Errorf("decoding chat completion request: %v", err)
		return types.ActionContinue
	}
	messageLength := len(rawRequest.Messages)
	if messageLength == 0 {
		log.Warn("chat completion request has no messages, skipping retrieval")
		return types.ActionContinue
	}
	rawContent := rawRequest.Messages[messageLength-1].Content

	reqEmbeddingSerialized, err := json.Marshal(dashscope.Request{
		Model: "text-embedding-v1",
		Input: dashscope.Input{
			Texts: []string{rawContent},
		},
		Parameter: dashscope.Parameter{
			TextType: "query",
		},
	})
	if err != nil {
		log.Errorf("encoding embedding request: %v", err)
		return types.ActionContinue
	}

	// passThrough logs the failure and resumes the paused request with its
	// original, unmodified body.
	passThrough := func(format string, args ...interface{}) {
		log.Errorf(format, args...)
		proxywasm.ResumeHttpRequest()
	}

	// onQueryResponse handles the DashVector reply. The request is resumed on
	// every path; the body is replaced only when a document was retrieved.
	onQueryResponse := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		defer proxywasm.ResumeHttpRequest()

		if statusCode != http.StatusOK {
			log.Errorf("vector query returned status %d: %s", statusCode, responseBody)
			return
		}
		var response dashvector.Response
		if err := json.Unmarshal(responseBody, &response); err != nil {
			log.Errorf("decoding vector query response: %v", err)
			return
		}
		if len(response.Output) == 0 {
			log.Warn("vector query returned no documents, forwarding request unchanged")
			return
		}
		doc := response.Output[0].Fields.Raw
		rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息你可以自行选择是否使用这些参考信息现在请回答以下问题\n%s", doc, rawContent)
		newBody, err := json.Marshal(rawRequest)
		if err != nil {
			log.Errorf("encoding augmented chat completion request: %v", err)
			return
		}
		if err := proxywasm.ReplaceHttpRequestBody(newBody); err != nil {
			log.Errorf("replacing request body: %v", err)
		}
	}

	// onEmbeddingResponse handles the DashScope reply and chains the vector
	// query. It resumes the request itself only when the chain cannot go on.
	onEmbeddingResponse := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		if statusCode != http.StatusOK {
			passThrough("embedding request returned status %d: %s", statusCode, responseBody)
			return
		}
		var responseEmbedding dashscope.Response
		if err := json.Unmarshal(responseBody, &responseEmbedding); err != nil {
			passThrough("decoding embedding response: %v", err)
			return
		}
		if len(responseEmbedding.Output.Embeddings) == 0 {
			passThrough("embedding response contains no embeddings")
			return
		}
		requestQuerySerialized, err := json.Marshal(dashvector.Request{
			TopK:         1,
			OutputFileds: []string{"raw"},
			Vector:       responseEmbedding.Output.Embeddings[0].Embedding,
		})
		if err != nil {
			passThrough("encoding vector query: %v", err)
			return
		}
		err = config.DashVectorClient.Post(
			fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
			[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
			requestQuerySerialized,
			onQueryResponse,
		)
		if err != nil {
			// The callback will never fire, so resume here.
			passThrough("dispatching vector query: %v", err)
		}
	}

	headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
	err = config.DashScopeClient.Post(
		"/api/v1/services/embeddings/text-embedding/text-embedding",
		headers,
		reqEmbeddingSerialized,
		onEmbeddingResponse,
		50000,
	)
	if err != nil {
		// Nothing was dispatched, so no callback would ever resume the
		// request; continue instead of pausing.
		log.Errorf("dispatching embedding request: %v", err)
		return types.ActionContinue
	}
	return types.ActionPause
}