mirror of
https://github.com/alibaba/higress.git
synced 2026-03-06 09:30:49 +08:00
127 lines
4.3 KiB
Go
127 lines
4.3 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
|
||
"ai-rag/dashscope"
|
||
"ai-rag/dashvector"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
func main() {
|
||
wrapper.SetCtx(
|
||
"ai-rag",
|
||
wrapper.ParseConfigBy(parseConfig),
|
||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||
)
|
||
}
|
||
|
||
type AIRagConfig struct {
|
||
DashScopeClient wrapper.HttpClient
|
||
DashScopeAPIKey string
|
||
DashVectorClient wrapper.HttpClient
|
||
DashVectorAPIKey string
|
||
DashVectorCollection string
|
||
}
|
||
|
||
type Request struct {
|
||
Model string `json:"model"`
|
||
Messages []Message `json:"messages"`
|
||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||
PresencePenalty float64 `json:"presence_penalty"`
|
||
Stream bool `json:"stream"`
|
||
Temperature float64 `json:"temperature"`
|
||
Topp int32 `json:"top_p"`
|
||
}
|
||
|
||
type Message struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
}
|
||
|
||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
|
||
|
||
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||
ServiceName: json.Get("dashscope.serviceName").String(),
|
||
Port: json.Get("dashscope.servicePort").Int(),
|
||
Domain: json.Get("dashscope.domain").String(),
|
||
})
|
||
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
|
||
config.DashVectorCollection = json.Get("dashvector.collection").String()
|
||
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||
ServiceName: json.Get("dashvector.serviceName").String(),
|
||
Port: json.Get("dashvector.servicePort").Int(),
|
||
Domain: json.Get("dashvector.domain").String(),
|
||
})
|
||
return nil
|
||
}
|
||
|
||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||
p, _ := proxywasm.GetHttpRequestHeader(":path")
|
||
if p != "/api/openai/v1/chat/completions" {
|
||
ctx.DontReadRequestBody()
|
||
return types.ActionContinue
|
||
}
|
||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||
return types.ActionContinue
|
||
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
|
||
var rawRequest Request
|
||
_ = json.Unmarshal(body, &rawRequest)
|
||
messageLength := len(rawRequest.Messages)
|
||
rawContent := rawRequest.Messages[messageLength-1].Content
|
||
requestEmbedding := dashscope.Request{
|
||
Model: "text-embedding-v1",
|
||
Input: dashscope.Input{
|
||
Texts: []string{rawContent},
|
||
},
|
||
Parameter: dashscope.Parameter{
|
||
TextType: "query",
|
||
},
|
||
}
|
||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
|
||
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
|
||
// log.Info(string(reqEmbeddingSerialized))
|
||
config.DashScopeClient.Post(
|
||
"/api/v1/services/embeddings/text-embedding/text-embedding",
|
||
headers,
|
||
reqEmbeddingSerialized,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
var responseEmbedding dashscope.Response
|
||
_ = json.Unmarshal(responseBody, &responseEmbedding)
|
||
requestQuery := dashvector.Request{
|
||
TopK: 1,
|
||
OutputFileds: []string{"raw"},
|
||
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
|
||
}
|
||
requestQuerySerialized, _ := json.Marshal(requestQuery)
|
||
config.DashVectorClient.Post(
|
||
fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
|
||
[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
|
||
requestQuerySerialized,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
var response dashvector.Response
|
||
_ = json.Unmarshal(responseBody, &response)
|
||
doc := response.Output[0].Fields.Raw
|
||
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息,你可以自行选择是否使用这些参考信息,现在请回答以下问题:\n%s", doc, rawContent)
|
||
newBody, _ := json.Marshal(rawRequest)
|
||
// log.Info(string(newBody))
|
||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||
proxywasm.ResumeHttpRequest()
|
||
},
|
||
)
|
||
},
|
||
50000,
|
||
)
|
||
return types.ActionPause
|
||
}
|