mirror of
https://github.com/alibaba/higress.git
synced 2026-04-20 11:37:27 +08:00
feat: Enhance ai-cache Plugin with Vector Similarity-Based LLM Cache Recall and Multi-DB Support (#1248)
This commit is contained in:
201
plugins/wasm-go/extensions/ai-cache/vector/chroma.go
Normal file
201
plugins/wasm-go/extensions/ai-cache/vector/chroma.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package vector
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type chromaProviderInitializer struct{}
|
||||
|
||||
func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if len(config.collectionID) == 0 {
|
||||
return errors.New("[Chroma] collectionID is required")
|
||||
}
|
||||
if len(config.serviceName) == 0 {
|
||||
return errors.New("[Chroma] serviceName is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &ChromaProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: config.serviceName,
|
||||
Host: config.serviceHost,
|
||||
Port: int64(config.servicePort),
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ChromaProvider struct {
|
||||
config ProviderConfig
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (c *ChromaProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_CHROMA
|
||||
}
|
||||
|
||||
func (d *ChromaProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
// 最少需要填写的参数为 collection_id, embeddings 和 ids
|
||||
// 下面是一个例子
|
||||
// {
|
||||
// "where": {}, // 用于 metadata 过滤,可选参数
|
||||
// "where_document": {}, // 用于 document 过滤,可选参数
|
||||
// "query_embeddings": [
|
||||
// [1.1, 2.3, 3.2]
|
||||
// ],
|
||||
// "limit": 5,
|
||||
// "include": [
|
||||
// "metadatas", // 可选
|
||||
// "documents", // 如果需要答案则需要
|
||||
// "distances"
|
||||
// ]
|
||||
// }
|
||||
|
||||
requestBody, err := json.Marshal(chromaQueryRequest{
|
||||
QueryEmbeddings: []chromaEmbedding{emb},
|
||||
Limit: d.config.topK,
|
||||
Include: []string{"distances", "documents"},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return d.client.Post(
|
||||
fmt.Sprintf("/api/v1/collections/%s/query", d.config.collectionID),
|
||||
[][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
},
|
||||
requestBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Debugf("[Chroma] Query embedding response: %d, %s", statusCode, responseBody)
|
||||
results, err := d.parseQueryResponse(responseBody, log)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("[Chroma] Failed to parse query response: %v", err)
|
||||
}
|
||||
callback(results, ctx, log, err)
|
||||
},
|
||||
d.config.timeout,
|
||||
)
|
||||
}
|
||||
|
||||
func (d *ChromaProvider) UploadAnswerAndEmbedding(
|
||||
queryString string,
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
// 最少需要填写的参数为 collection_id, embeddings 和 ids
|
||||
// 下面是一个例子
|
||||
// {
|
||||
// "embeddings": [
|
||||
// [1.1, 2.3, 3.2]
|
||||
// ],
|
||||
// "ids": [
|
||||
// "你吃了吗?"
|
||||
// ],
|
||||
// "documents": [
|
||||
// "我吃了。"
|
||||
// ]
|
||||
// }
|
||||
// 如果要添加 answer,则按照以下例子
|
||||
// {
|
||||
// "embeddings": [
|
||||
// [1.1, 2.3, 3.2]
|
||||
// ],
|
||||
// "documents": [
|
||||
// "answer1"
|
||||
// ],
|
||||
// "ids": [
|
||||
// "id1"
|
||||
// ]
|
||||
// }
|
||||
requestBody, err := json.Marshal(chromaInsertRequest{
|
||||
Embeddings: []chromaEmbedding{queryEmb},
|
||||
IDs: []string{queryString}, // queryString 指的是用户查询的问题
|
||||
Documents: []string{queryAnswer}, // queryAnswer 指的是用户查询的问题的答案
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.client.Post(
|
||||
fmt.Sprintf("/api/v1/collections/%s/add", d.config.collectionID),
|
||||
[][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
},
|
||||
requestBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Debugf("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||
callback(ctx, log, err)
|
||||
},
|
||||
d.config.timeout,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
type chromaEmbedding []float64
|
||||
type chromaMetadataMap map[string]string
|
||||
type chromaInsertRequest struct {
|
||||
Embeddings []chromaEmbedding `json:"embeddings"`
|
||||
Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // 可选参数
|
||||
Documents []string `json:"documents,omitempty"` // 可选参数
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
|
||||
type chromaQueryRequest struct {
|
||||
Where map[string]string `json:"where,omitempty"` // 可选参数
|
||||
WhereDocument map[string]string `json:"where_document,omitempty"` // 可选参数
|
||||
QueryEmbeddings []chromaEmbedding `json:"query_embeddings"`
|
||||
Limit int `json:"limit"`
|
||||
Include []string `json:"include"`
|
||||
}
|
||||
|
||||
type chromaQueryResponse struct {
|
||||
Ids [][]string `json:"ids"` // 第一维是 batch query,第二维是查询到的多个 ids
|
||||
Distances [][]float64 `json:"distances,omitempty"` // 与 Ids 一一对应
|
||||
Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // 可选参数
|
||||
Embeddings []chromaEmbedding `json:"embeddings,omitempty"` // 可选参数
|
||||
Documents [][]string `json:"documents,omitempty"` // 与 Ids 一一对应
|
||||
Uris []string `json:"uris,omitempty"` // 可选参数
|
||||
Data []interface{} `json:"data,omitempty"` // 可选参数
|
||||
Included []string `json:"included"`
|
||||
}
|
||||
|
||||
func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
var queryResp chromaQueryResponse
|
||||
err := json.Unmarshal(responseBody, &queryResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("[Chroma] queryResp Ids len: %d", len(queryResp.Ids))
|
||||
if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 {
|
||||
return nil, errors.New("no query results found in response")
|
||||
}
|
||||
results := make([]QueryResult, 0, len(queryResp.Ids[0]))
|
||||
for i := range queryResp.Ids[0] {
|
||||
result := QueryResult{
|
||||
Text: queryResp.Ids[0][i],
|
||||
Score: queryResp.Distances[0][i],
|
||||
Answer: queryResp.Documents[0][i],
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
Reference in New Issue
Block a user