mirror of
https://github.com/alibaba/higress.git
synced 2026-06-03 17:47:25 +08:00
[ai-cache] Implement a WASM plugin for LLM result retrieval based on vector similarity (#1290)
This commit is contained in:
256
plugins/wasm-go/extensions/ai-cache/vector/dashvector.go
Normal file
256
plugins/wasm-go/extensions/ai-cache/vector/dashvector.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type dashVectorProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if len(config.apiKey) == 0 {
|
||||
return errors.New("[DashVector] apiKey is required")
|
||||
}
|
||||
if len(config.collectionID) == 0 {
|
||||
return errors.New("[DashVector] collectionID is required")
|
||||
}
|
||||
if len(config.serviceName) == 0 {
|
||||
return errors.New("[DashVector] serviceName is required")
|
||||
}
|
||||
if len(config.serviceHost) == 0 {
|
||||
return errors.New("[DashVector] serviceHost is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &DvProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: config.serviceName,
|
||||
Host: config.serviceHost,
|
||||
Port: int64(config.servicePort),
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type DvProvider struct {
|
||||
config ProviderConfig
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (d *DvProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_DASH_VECTOR
|
||||
}
|
||||
|
||||
// type embeddingRequest struct {
|
||||
// Model string `json:"model"`
|
||||
// Input input `json:"input"`
|
||||
// Parameters params `json:"parameters"`
|
||||
// }
|
||||
|
||||
// type params struct {
|
||||
// TextType string `json:"text_type"`
|
||||
// }
|
||||
|
||||
// type input struct {
|
||||
// Texts []string `json:"texts"`
|
||||
// }
|
||||
|
||||
// queryResponse 定义查询响应的结构
|
||||
type queryResponse struct {
|
||||
Code int `json:"code"`
|
||||
RequestID string `json:"request_id"`
|
||||
Message string `json:"message"`
|
||||
Output []result `json:"output"`
|
||||
}
|
||||
|
||||
// queryRequest 定义查询请求的结构
|
||||
type queryRequest struct {
|
||||
Vector []float64 `json:"vector"`
|
||||
TopK int `json:"topk"`
|
||||
IncludeVector bool `json:"include_vector"`
|
||||
}
|
||||
|
||||
// result 定义查询结果的结构
|
||||
type result struct {
|
||||
ID string `json:"id"`
|
||||
Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化
|
||||
Fields map[string]interface{} `json:"fields"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) {
|
||||
url := fmt.Sprintf("/v1/collections/%s/query", d.config.collectionID)
|
||||
|
||||
requestData := queryRequest{
|
||||
Vector: vector,
|
||||
TopK: d.config.topK,
|
||||
IncludeVector: false,
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
header := [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
{"dashvector-auth-token", d.config.apiKey},
|
||||
}
|
||||
|
||||
return url, requestBody, header, nil
|
||||
}
|
||||
|
||||
func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, error) {
|
||||
var queryResp queryResponse
|
||||
err := json.Unmarshal(responseBody, &queryResp)
|
||||
if err != nil {
|
||||
return queryResponse{}, err
|
||||
}
|
||||
return queryResp, nil
|
||||
}
|
||||
|
||||
func (d *DvProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
url, body, headers, err := d.constructEmbeddingQueryParameters(emb)
|
||||
log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to construct embedding query parameters: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.client.Post(url, headers, body,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
err = nil
|
||||
if statusCode != http.StatusOK {
|
||||
err = fmt.Errorf("failed to query embedding: %d", statusCode)
|
||||
callback(nil, ctx, log, err)
|
||||
return
|
||||
}
|
||||
log.Debugf("query embedding response: %d, %s", statusCode, responseBody)
|
||||
results, err := d.ParseQueryResponse(responseBody, ctx, log)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse query response: %v", err)
|
||||
}
|
||||
callback(results, ctx, log, err)
|
||||
},
|
||||
d.config.timeout)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to query embedding: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func getStringValue(fields map[string]interface{}, key string) string {
|
||||
if val, ok := fields[key]; ok {
|
||||
return val.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) {
|
||||
resp, err := d.parseQueryResponse(responseBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(resp.Output) == 0 {
|
||||
return nil, errors.New("no query results found in response")
|
||||
}
|
||||
|
||||
results := make([]QueryResult, 0, len(resp.Output))
|
||||
|
||||
for _, output := range resp.Output {
|
||||
result := QueryResult{
|
||||
Text: getStringValue(output.Fields, "query"),
|
||||
Embedding: output.Vector,
|
||||
Score: output.Score,
|
||||
Answer: getStringValue(output.Fields, "answer"),
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
type document struct {
|
||||
Vector []float64 `json:"vector"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
type insertRequest struct {
|
||||
Docs []document `json:"docs"`
|
||||
}
|
||||
|
||||
func (d *DvProvider) constructUploadParameters(emb []float64, queryString string, answer string) (string, []byte, [][2]string, error) {
|
||||
url := "/v1/collections/" + d.config.collectionID + "/docs"
|
||||
|
||||
doc := document{
|
||||
Vector: emb,
|
||||
Fields: map[string]string{
|
||||
"query": queryString,
|
||||
"answer": answer,
|
||||
},
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(insertRequest{Docs: []document{doc}})
|
||||
if err != nil {
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
header := [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
{"dashvector-auth-token", d.config.apiKey},
|
||||
}
|
||||
|
||||
return url, requestBody, header, err
|
||||
}
|
||||
|
||||
func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.client.Post(
|
||||
url,
|
||||
headers,
|
||||
body,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||
if statusCode != http.StatusOK {
|
||||
err = fmt.Errorf("failed to upload embedding: %d", statusCode)
|
||||
}
|
||||
callback(ctx, log, err)
|
||||
},
|
||||
d.config.timeout)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.client.Post(
|
||||
url,
|
||||
headers,
|
||||
body,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||
if statusCode != http.StatusOK {
|
||||
err = fmt.Errorf("failed to upload embedding: %d", statusCode)
|
||||
}
|
||||
callback(ctx, log, err)
|
||||
},
|
||||
d.config.timeout)
|
||||
return err
|
||||
}
|
||||
167
plugins/wasm-go/extensions/ai-cache/vector/provider.go
Normal file
167
plugins/wasm-go/extensions/ai-cache/vector/provider.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
PROVIDER_TYPE_DASH_VECTOR = "dashvector"
|
||||
PROVIDER_TYPE_CHROMA = "chroma"
|
||||
)
|
||||
|
||||
type providerInitializer interface {
|
||||
ValidateConfig(ProviderConfig) error
|
||||
CreateProvider(ProviderConfig) (Provider, error)
|
||||
}
|
||||
|
||||
var (
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{},
|
||||
// PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
// QueryResult 定义通用的查询结果的结构体
|
||||
type QueryResult struct {
|
||||
Text string // 相似的文本
|
||||
Embedding []float64 // 相似文本的向量
|
||||
Score float64 // 文本的向量相似度或距离等度量
|
||||
Answer string // 相似文本对应的LLM生成的回答
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
GetProviderType() string
|
||||
}
|
||||
|
||||
type EmbeddingQuerier interface {
|
||||
QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
}
|
||||
|
||||
type EmbeddingUploader interface {
|
||||
UploadEmbedding(
|
||||
queryString string,
|
||||
queryEmb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
}
|
||||
|
||||
type AnswerAndEmbeddingUploader interface {
|
||||
UploadAnswerAndEmbedding(
|
||||
queryString string,
|
||||
queryEmb []float64,
|
||||
answer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
}
|
||||
|
||||
type StringQuerier interface {
|
||||
QueryString(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
}
|
||||
|
||||
type SimilarityThresholdProvider interface {
|
||||
GetSimilarityThreshold() float64
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
// @Title zh-CN 向量存储服务提供者类型
|
||||
// @Description zh-CN 向量存储服务提供者类型,例如 dashvector、chroma
|
||||
typ string
|
||||
// @Title zh-CN 向量存储服务名称
|
||||
// @Description zh-CN 向量存储服务名称
|
||||
serviceName string
|
||||
// @Title zh-CN 向量存储服务域名
|
||||
// @Description zh-CN 向量存储服务域名
|
||||
serviceHost string
|
||||
// @Title zh-CN 向量存储服务端口
|
||||
// @Description zh-CN 向量存储服务端口
|
||||
servicePort int64
|
||||
// @Title zh-CN 向量存储服务 API Key
|
||||
// @Description zh-CN 向量存储服务 API Key
|
||||
apiKey string
|
||||
// @Title zh-CN 返回TopK结果
|
||||
// @Description zh-CN 返回TopK结果,默认为 1
|
||||
topK int
|
||||
// @Title zh-CN 请求超时
|
||||
// @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒
|
||||
timeout uint32
|
||||
// @Title zh-CN DashVector 向量存储服务 Collection ID
|
||||
// @Description zh-CN DashVector 向量存储服务 Collection ID
|
||||
collectionID string
|
||||
// @Title zh-CN 相似度度量阈值
|
||||
// @Description zh-CN 默认相似度度量阈值,默认为 1000。
|
||||
Threshold float64
|
||||
// @Title zh-CN 相似度度量比较方式
|
||||
// @Description zh-CN 相似度度量比较方式,默认为小于。
|
||||
// 相似度度量方式有 Cosine, DotProduct, Euclidean 等,前两者值越大相似度越高,后者值越小相似度越高。
|
||||
// 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。
|
||||
// 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于)
|
||||
ThresholdRelation string
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetProviderType() string {
|
||||
return c.typ
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.typ = json.Get("type").String()
|
||||
// DashVector
|
||||
c.serviceName = json.Get("serviceName").String()
|
||||
c.serviceHost = json.Get("serviceHost").String()
|
||||
c.servicePort = int64(json.Get("servicePort").Int())
|
||||
if c.servicePort == 0 {
|
||||
c.servicePort = 443
|
||||
}
|
||||
c.apiKey = json.Get("apiKey").String()
|
||||
c.collectionID = json.Get("collectionID").String()
|
||||
c.topK = int(json.Get("topK").Int())
|
||||
if c.topK == 0 {
|
||||
c.topK = 1
|
||||
}
|
||||
c.timeout = uint32(json.Get("timeout").Int())
|
||||
if c.timeout == 0 {
|
||||
c.timeout = 10000
|
||||
}
|
||||
c.Threshold = json.Get("threshold").Float()
|
||||
if c.Threshold == 0 {
|
||||
c.Threshold = 1000
|
||||
}
|
||||
c.ThresholdRelation = json.Get("thresholdRelation").String()
|
||||
if c.ThresholdRelation == "" {
|
||||
c.ThresholdRelation = "lt"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
if c.typ == "" {
|
||||
return errors.New("vector database service is required")
|
||||
}
|
||||
initializer, has := providerInitializers[c.typ]
|
||||
if !has {
|
||||
return errors.New("unknown vector database service provider type: " + c.typ)
|
||||
}
|
||||
if err := initializer.ValidateConfig(*c); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
initializer, has := providerInitializers[pc.typ]
|
||||
if !has {
|
||||
return nil, errors.New("unknown provider type: " + pc.typ)
|
||||
}
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
Reference in New Issue
Block a user