mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
feat: Enhance ai-cache Plugin with Vector Similarity-Based LLM Cache Recall and Multi-DB Support (#1248)
This commit is contained in:
@@ -60,7 +60,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的
|
|||||||
| vector.apiKey | string | optional | "" | 向量存储服务 API Key |
|
| vector.apiKey | string | optional | "" | 向量存储服务 API Key |
|
||||||
| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 |
|
| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 |
|
||||||
| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 |
|
| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 |
|
||||||
| vector.collectionID | string | optional | "" | dashvector 向量存储服务 Collection ID |
|
| vector.collectionID | string | optional | "" | 向量存储服务 Collection ID |
|
||||||
| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 |
|
| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 |
|
||||||
| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小于等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大于等于) |
|
| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小于等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大于等于) |
|
||||||
|
|
||||||
@@ -99,6 +99,45 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的
|
|||||||
| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
|
| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
|
||||||
| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
|
| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
|
||||||
|
|
||||||
|
# 向量数据库提供商特有配置
|
||||||
|
## Chroma
|
||||||
|
Chroma 所对应的 `vector.type` 为 `chroma`。它并无特有的配置字段。需要提前创建 Collection,并填写 Collection ID 至配置项 `vector.collectionID`,一个 Collection ID 的示例为 `52bbb8b3-724c-477b-a4ce-d5b578214612`。
|
||||||
|
|
||||||
|
## DashVector
|
||||||
|
DashVector 所对应的 `vector.type` 为 `dashvector`。它并无特有的配置字段。需要提前创建 Collection,并填写 `Collection 名称` 至配置项 `vector.collectionID`。
|
||||||
|
|
||||||
|
## ElasticSearch
|
||||||
|
ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。需要提前创建 Index 并填写 Index Name 至配置项 `vector.collectionID` 。
|
||||||
|
|
||||||
|
当前依赖于 [KNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) 方法,请保证 ES 版本支持 `KNN`,当前已在 `8.16` 版本测试。
|
||||||
|
|
||||||
|
它特有的配置字段如下:
|
||||||
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|
|-------------------|----------|----------|--------|-------------------------------------------------------------------------------|
|
||||||
|
| `vector.esUsername` | string | 非必填 | - | ElasticSearch 用户名 |
|
||||||
|
| `vector.esPassword` | string | 非必填 | - | ElasticSearch 密码 |
|
||||||
|
|
||||||
|
|
||||||
|
`vector.esUsername` 和 `vector.esPassword` 用于 Basic 认证。同时也支持 API Key 认证:当填写了 `vector.apiKey` 时,则启用 API Key 认证;如果使用 SaaS 版本,`vector.apiKey` 需要填写 API Key 的 `encoded` 值。
|
||||||
|
|
||||||
|
## Milvus
|
||||||
|
Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。需要提前创建 Collection,并填写 Collection Name 至配置项 `vector.collectionID`。
|
||||||
|
|
||||||
|
## Pinecone
|
||||||
|
Pinecone 所对应的 `vector.type` 为 `pinecone`。它并无特有的配置字段。需要提前创建 Index,并填写 Index 访问域名至 `vector.serviceHost`。
|
||||||
|
|
||||||
|
Pinecone 中的 `Namespace` 参数通过插件的 `vector.collectionID` 进行配置,如果不填写 `vector.collectionID`,则默认为 Default Namespace。
|
||||||
|
|
||||||
|
## Qdrant
|
||||||
|
Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段。需要提前创建 Collection,并填写 Collection Name 至配置项 `vector.collectionID`。
|
||||||
|
|
||||||
|
## Weaviate
|
||||||
|
Weaviate 所对应的 `vector.type` 为 `weaviate`。它并无特有的配置字段。
|
||||||
|
需要提前创建 Collection,并填写 Collection Name 至配置项 `vector.collectionID`。
|
||||||
|
|
||||||
|
需要注意的是,Weaviate 会自动将 Collection 名称的首字母大写,因此在填写配置项 `vector.collectionID` 时需要将首字母设置为大写。
|
||||||
|
|
||||||
|
如果使用 SaaS 需要填写 `vector.serviceHost` 参数。
|
||||||
|
|
||||||
## 配置示例
|
## 配置示例
|
||||||
### 基础配置
|
### 基础配置
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
package embedding
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
||||||
// )
|
|
||||||
|
|
||||||
// const (
|
|
||||||
// weaviateURL = "172.17.0.1:8081"
|
|
||||||
// )
|
|
||||||
|
|
||||||
// type weaviateProviderInitializer struct {
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
|
||||||
// return &DSProvider{
|
|
||||||
// config: config,
|
|
||||||
// client: wrapper.NewClusterClient(wrapper.DnsCluster{
|
|
||||||
// ServiceName: config.ServiceName,
|
|
||||||
// Port: dashScopePort,
|
|
||||||
// Domain: dashScopeDomain,
|
|
||||||
// }),
|
|
||||||
// }, nil
|
|
||||||
// }
|
|
||||||
201
plugins/wasm-go/extensions/ai-cache/vector/chroma.go
Normal file
201
plugins/wasm-go/extensions/ai-cache/vector/chroma.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// chromaProviderInitializer validates the configuration for, and constructs,
// the Chroma vector store provider.
type chromaProviderInitializer struct{}
|
||||||
|
|
||||||
|
func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if len(config.collectionID) == 0 {
|
||||||
|
return errors.New("[Chroma] collectionID is required")
|
||||||
|
}
|
||||||
|
if len(config.serviceName) == 0 {
|
||||||
|
return errors.New("[Chroma] serviceName is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
|
return &ChromaProvider{
|
||||||
|
config: config,
|
||||||
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: config.serviceName,
|
||||||
|
Host: config.serviceHost,
|
||||||
|
Port: int64(config.servicePort),
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChromaProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client wrapper.HttpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChromaProvider) GetProviderType() string {
|
||||||
|
return PROVIDER_TYPE_CHROMA
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *ChromaProvider) QueryEmbedding(
|
||||||
|
emb []float64,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 collection_id, embeddings 和 ids
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "where": {}, // 用于 metadata 过滤,可选参数
|
||||||
|
// "where_document": {}, // 用于 document 过滤,可选参数
|
||||||
|
// "query_embeddings": [
|
||||||
|
// [1.1, 2.3, 3.2]
|
||||||
|
// ],
|
||||||
|
// "limit": 5,
|
||||||
|
// "include": [
|
||||||
|
// "metadatas", // 可选
|
||||||
|
// "documents", // 如果需要答案则需要
|
||||||
|
// "distances"
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
|
||||||
|
requestBody, err := json.Marshal(chromaQueryRequest{
|
||||||
|
QueryEmbeddings: []chromaEmbedding{emb},
|
||||||
|
Limit: d.config.topK,
|
||||||
|
Include: []string{"distances", "documents"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
fmt.Sprintf("/api/v1/collections/%s/query", d.config.collectionID),
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Chroma] Query embedding response: %d, %s", statusCode, responseBody)
|
||||||
|
results, err := d.parseQueryResponse(responseBody, log)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("[Chroma] Failed to parse query response: %v", err)
|
||||||
|
}
|
||||||
|
callback(results, ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *ChromaProvider) UploadAnswerAndEmbedding(
|
||||||
|
queryString string,
|
||||||
|
queryEmb []float64,
|
||||||
|
queryAnswer string,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 collection_id, embeddings 和 ids
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "embeddings": [
|
||||||
|
// [1.1, 2.3, 3.2]
|
||||||
|
// ],
|
||||||
|
// "ids": [
|
||||||
|
// "你吃了吗?"
|
||||||
|
// ],
|
||||||
|
// "documents": [
|
||||||
|
// "我吃了。"
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
// 如果要添加 answer,则按照以下例子
|
||||||
|
// {
|
||||||
|
// "embeddings": [
|
||||||
|
// [1.1, 2.3, 3.2]
|
||||||
|
// ],
|
||||||
|
// "documents": [
|
||||||
|
// "answer1"
|
||||||
|
// ],
|
||||||
|
// "ids": [
|
||||||
|
// "id1"
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(chromaInsertRequest{
|
||||||
|
Embeddings: []chromaEmbedding{queryEmb},
|
||||||
|
IDs: []string{queryString}, // queryString 指的是用户查询的问题
|
||||||
|
Documents: []string{queryAnswer}, // queryAnswer 指的是用户查询的问题的答案
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.client.Post(
|
||||||
|
fmt.Sprintf("/api/v1/collections/%s/add", d.config.collectionID),
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||||
|
callback(ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// chromaEmbedding is a single embedding vector as sent to and received from Chroma.
type chromaEmbedding []float64
|
||||||
|
// chromaMetadataMap is the string-to-string metadata attached to one Chroma record.
type chromaMetadataMap map[string]string
|
||||||
|
type chromaInsertRequest struct {
|
||||||
|
Embeddings []chromaEmbedding `json:"embeddings"`
|
||||||
|
Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // 可选参数
|
||||||
|
Documents []string `json:"documents,omitempty"` // 可选参数
|
||||||
|
IDs []string `json:"ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type chromaQueryRequest struct {
|
||||||
|
Where map[string]string `json:"where,omitempty"` // 可选参数
|
||||||
|
WhereDocument map[string]string `json:"where_document,omitempty"` // 可选参数
|
||||||
|
QueryEmbeddings []chromaEmbedding `json:"query_embeddings"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Include []string `json:"include"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type chromaQueryResponse struct {
|
||||||
|
Ids [][]string `json:"ids"` // 第一维是 batch query,第二维是查询到的多个 ids
|
||||||
|
Distances [][]float64 `json:"distances,omitempty"` // 与 Ids 一一对应
|
||||||
|
Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // 可选参数
|
||||||
|
Embeddings []chromaEmbedding `json:"embeddings,omitempty"` // 可选参数
|
||||||
|
Documents [][]string `json:"documents,omitempty"` // 与 Ids 一一对应
|
||||||
|
Uris []string `json:"uris,omitempty"` // 可选参数
|
||||||
|
Data []interface{} `json:"data,omitempty"` // 可选参数
|
||||||
|
Included []string `json:"included"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||||
|
var queryResp chromaQueryResponse
|
||||||
|
err := json.Unmarshal(responseBody, &queryResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("[Chroma] queryResp Ids len: %d", len(queryResp.Ids))
|
||||||
|
if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 {
|
||||||
|
return nil, errors.New("no query results found in response")
|
||||||
|
}
|
||||||
|
results := make([]QueryResult, 0, len(queryResp.Ids[0]))
|
||||||
|
for i := range queryResp.Ids[0] {
|
||||||
|
result := QueryResult{
|
||||||
|
Text: queryResp.Ids[0][i],
|
||||||
|
Score: queryResp.Distances[0][i],
|
||||||
|
Answer: queryResp.Documents[0][i],
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
200
plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go
Normal file
200
plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// esProviderInitializer validates the configuration for, and constructs, the
// ElasticSearch vector store provider.
type esProviderInitializer struct{}
|
||||||
|
|
||||||
|
func (c *esProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if len(config.collectionID) == 0 {
|
||||||
|
return errors.New("[ES] collectionID is required")
|
||||||
|
}
|
||||||
|
if len(config.serviceName) == 0 {
|
||||||
|
return errors.New("[ES] serviceName is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *esProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
|
return &ESProvider{
|
||||||
|
config: config,
|
||||||
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: config.serviceName,
|
||||||
|
Host: config.serviceHost,
|
||||||
|
Port: int64(config.servicePort),
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ESProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client wrapper.HttpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ESProvider) GetProviderType() string {
|
||||||
|
return PROVIDER_TYPE_ES
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *ESProvider) QueryEmbedding(
|
||||||
|
emb []float64,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
|
||||||
|
requestBody, err := json.Marshal(esQueryRequest{
|
||||||
|
Source: Source{Excludes: []string{"embedding"}},
|
||||||
|
Knn: knn{
|
||||||
|
Field: "embedding",
|
||||||
|
QueryVector: emb,
|
||||||
|
K: d.config.topK,
|
||||||
|
},
|
||||||
|
Size: d.config.topK,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[ES] Failed to marshal query embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
fmt.Sprintf("/%s/_search", d.config.collectionID),
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Authorization", d.getCredentials()},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[ES] Query embedding response: %d, %s", statusCode, responseBody)
|
||||||
|
results, err := d.parseQueryResponse(responseBody, log)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("[ES] Failed to parse query response: %v", err)
|
||||||
|
}
|
||||||
|
callback(results, ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// base64 编码 ES 身份认证字符串或使用 Apikey
|
||||||
|
func (d *ESProvider) getCredentials() string {
|
||||||
|
if len(d.config.apiKey) != 0 {
|
||||||
|
return fmt.Sprintf("ApiKey %s", d.config.apiKey)
|
||||||
|
} else {
|
||||||
|
credentials := fmt.Sprintf("%s:%s", d.config.esUsername, d.config.esPassword)
|
||||||
|
encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||||
|
return fmt.Sprintf("Basic %s", encodedCredentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *ESProvider) UploadAnswerAndEmbedding(
|
||||||
|
queryString string,
|
||||||
|
queryEmb []float64,
|
||||||
|
queryAnswer string,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 index, embeddings 和 question
|
||||||
|
// 下面是一个例子
|
||||||
|
// POST /<index>/_doc
|
||||||
|
// {
|
||||||
|
// "embedding": [
|
||||||
|
// [1.1, 2.3, 3.2]
|
||||||
|
// ],
|
||||||
|
// "question": [
|
||||||
|
// "你吃了吗?"
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(esInsertRequest{
|
||||||
|
Embedding: queryEmb,
|
||||||
|
Question: queryString,
|
||||||
|
Answer: queryAnswer,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[ES] Failed to marshal upload embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
fmt.Sprintf("/%s/_doc", d.config.collectionID),
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Authorization", d.getCredentials()},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[ES] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||||
|
callback(ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// esInsertRequest is the JSON document indexed for each cached question.
type esInsertRequest struct {
	// Embedding is the vector of the question.
	Embedding []float64 `json:"embedding"`
	// Question is the user's query text.
	Question string `json:"question"`
	// Answer is the cached answer for the question.
	Answer string `json:"answer"`
}
|
||||||
|
|
||||||
|
// knn is the "knn" clause of an ElasticSearch search request.
type knn struct {
	// Field names the vector field to search.
	Field string `json:"field"`
	// QueryVector is the vector to compare against.
	QueryVector []float64 `json:"query_vector"`
	// K is serialized as the "k" parameter of the clause.
	K int `json:"k"`
}
|
||||||
|
|
||||||
|
// Source is the "_source" clause of an ElasticSearch search request.
type Source struct {
	// Excludes lists the document fields to leave out of each hit.
	Excludes []string `json:"excludes"`
}
|
||||||
|
|
||||||
|
type esQueryRequest struct {
|
||||||
|
Source Source `json:"_source"`
|
||||||
|
Knn knn `json:"knn"`
|
||||||
|
Size int `json:"size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// esQueryResponse mirrors the parts of an ElasticSearch search response that
// are decoded by this provider.
type esQueryResponse struct {
	Took     int  `json:"took"`
	TimedOut bool `json:"timed_out"`
	// Hits is the top-level "hits" object of the response.
	Hits struct {
		// Total is the "total" object inside "hits".
		Total struct {
			Value    int    `json:"value"`
			Relation string `json:"relation"`
		} `json:"total"`
		// Hits holds the individual matches.
		Hits []struct {
			Index string  `json:"_index"`
			ID    string  `json:"_id"`
			Score float64 `json:"_score"`
			// Source is the stored document of the hit.
			Source map[string]interface{} `json:"_source"`
		} `json:"hits"`
	} `json:"hits"`
}
|
||||||
|
|
||||||
|
func (d *ESProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||||
|
log.Infof("[ES] responseBody: %s", string(responseBody))
|
||||||
|
var queryResp esQueryResponse
|
||||||
|
err := json.Unmarshal(responseBody, &queryResp)
|
||||||
|
if err != nil {
|
||||||
|
return []QueryResult{}, err
|
||||||
|
}
|
||||||
|
log.Debugf("[ES] queryResp Hits len: %d", len(queryResp.Hits.Hits))
|
||||||
|
if len(queryResp.Hits.Hits) == 0 {
|
||||||
|
return nil, errors.New("no query results found in response")
|
||||||
|
}
|
||||||
|
results := make([]QueryResult, 0, queryResp.Hits.Total.Value)
|
||||||
|
for _, hit := range queryResp.Hits.Hits {
|
||||||
|
result := QueryResult{
|
||||||
|
Text: hit.Source["question"].(string),
|
||||||
|
Score: hit.Score,
|
||||||
|
Answer: hit.Source["answer"].(string),
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
206
plugins/wasm-go/extensions/ai-cache/vector/milvus.go
Normal file
206
plugins/wasm-go/extensions/ai-cache/vector/milvus.go
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// milvusProviderInitializer validates the configuration for, and constructs,
// the Milvus vector store provider.
type milvusProviderInitializer struct{}
|
||||||
|
|
||||||
|
func (c *milvusProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if len(config.serviceName) == 0 {
|
||||||
|
return errors.New("[Milvus] serviceName is required")
|
||||||
|
}
|
||||||
|
if len(config.collectionID) == 0 {
|
||||||
|
return errors.New("[Milvus] collectionID is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *milvusProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
|
return &milvusProvider{
|
||||||
|
config: config,
|
||||||
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: config.serviceName,
|
||||||
|
Host: config.serviceHost,
|
||||||
|
Port: int64(config.servicePort),
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type milvusProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client wrapper.HttpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *milvusProvider) GetProviderType() string {
|
||||||
|
return PROVIDER_TYPE_MILVUS
|
||||||
|
}
|
||||||
|
|
||||||
|
// milvusData is one entity of a Milvus insert request.
type milvusData struct {
	// Vector is the embedding of the question.
	Vector []float64 `json:"vector"`
	// Question is omitted from the JSON when empty.
	Question string `json:"question,omitempty"`
	// Answer is omitted from the JSON when empty.
	Answer string `json:"answer,omitempty"`
}
|
||||||
|
|
||||||
|
type milvusInsertRequest struct {
|
||||||
|
CollectionName string `json:"collectionName"`
|
||||||
|
Data []milvusData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *milvusProvider) UploadAnswerAndEmbedding(
|
||||||
|
queryString string,
|
||||||
|
queryEmb []float64,
|
||||||
|
queryAnswer string,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 collectionName, data 和 Authorization. question, answer 可选
|
||||||
|
// 需要填写 id,否则 v2.4.13-hotfix 提示 invalid syntax: invalid parameter[expected=Int64][actual=]
|
||||||
|
// 如果不填写 id,要在创建 collection 的时候设置 autoId 为 true
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "collectionName": "higress",
|
||||||
|
// "data": [
|
||||||
|
// {
|
||||||
|
// "question": "这里是问题",
|
||||||
|
// "answer": "这里是答案"
|
||||||
|
// "vector": [
|
||||||
|
// 0.9,
|
||||||
|
// 0.1,
|
||||||
|
// 0.1
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(milvusInsertRequest{
|
||||||
|
CollectionName: d.config.collectionID,
|
||||||
|
Data: []milvusData{
|
||||||
|
{
|
||||||
|
Question: queryString,
|
||||||
|
Answer: queryAnswer,
|
||||||
|
Vector: queryEmb,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Milvus] Failed to marshal upload embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
"/v2/vectordb/entities/insert",
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Milvus] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||||
|
callback(ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// milvusQueryRequest is the JSON body posted to the Milvus entity search
// endpoint.
type milvusQueryRequest struct {
	// CollectionName names the collection to search.
	CollectionName string `json:"collectionName"`
	// Data holds the query vectors.
	Data [][]float64 `json:"data"`
	// AnnsField names the vector field to search.
	AnnsField string `json:"annsField"`
	// Limit is the number of results requested.
	Limit int `json:"limit"`
	// OutputFields lists the scalar fields to return with each result.
	OutputFields []string `json:"outputFields"`
}
|
||||||
|
|
||||||
|
func (d *milvusProvider) QueryEmbedding(
|
||||||
|
emb []float64,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 collectionName, data, annsField. outputFields 为可选参数
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "collectionName": "quick_setup",
|
||||||
|
// "data": [
|
||||||
|
// [
|
||||||
|
// 0.3580376395471989,
|
||||||
|
// "Unknown type",
|
||||||
|
// 0.18414012509913835,
|
||||||
|
// "Unknown type",
|
||||||
|
// 0.9029438446296592
|
||||||
|
// ]
|
||||||
|
// ],
|
||||||
|
// "annsField": "vector",
|
||||||
|
// "limit": 3,
|
||||||
|
// "outputFields": [
|
||||||
|
// "color"
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(milvusQueryRequest{
|
||||||
|
CollectionName: d.config.collectionID,
|
||||||
|
Data: [][]float64{emb},
|
||||||
|
AnnsField: "vector",
|
||||||
|
Limit: d.config.topK,
|
||||||
|
OutputFields: []string{
|
||||||
|
"question",
|
||||||
|
"answer",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Milvus] Failed to marshal query embedding: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
"/v2/vectordb/entities/search",
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Milvus] Query embedding response: %d, %s", statusCode, responseBody)
|
||||||
|
results, err := d.parseQueryResponse(responseBody, log)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("[Milvus] Failed to parse query response: %v", err)
|
||||||
|
}
|
||||||
|
callback(results, ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *milvusProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||||
|
if !gjson.GetBytes(responseBody, "data.0.distance").Exists() {
|
||||||
|
log.Errorf("[Milvus] No distance found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Milvus] No distance found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, "data.0.question").Exists() {
|
||||||
|
log.Errorf("[Milvus] No question found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Milvus] No question found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, "data.0.answer").Exists() {
|
||||||
|
log.Errorf("[Milvus] No answer found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Milvus] No answer found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultNum := gjson.GetBytes(responseBody, "data.#").Int()
|
||||||
|
results := make([]QueryResult, 0, resultNum)
|
||||||
|
for i := 0; i < int(resultNum); i++ {
|
||||||
|
result := QueryResult{
|
||||||
|
Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.question", i)).String(),
|
||||||
|
Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.distance", i)).Float(),
|
||||||
|
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.answer", i)).String(),
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
194
plugins/wasm-go/extensions/ai-cache/vector/pinecone.go
Normal file
194
plugins/wasm-go/extensions/ai-cache/vector/pinecone.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pineconeProviderInitializer validates the configuration for, and constructs,
// the Pinecone vector store provider.
type pineconeProviderInitializer struct{}
|
||||||
|
|
||||||
|
func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if len(config.serviceHost) == 0 {
|
||||||
|
return errors.New("[Pinecone] serviceHost is required")
|
||||||
|
}
|
||||||
|
if len(config.serviceName) == 0 {
|
||||||
|
return errors.New("[Pinecone] serviceName is required")
|
||||||
|
}
|
||||||
|
if len(config.apiKey) == 0 {
|
||||||
|
return errors.New("[Pinecone] apiKey is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pineconeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
|
return &pineconeProvider{
|
||||||
|
config: config,
|
||||||
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: config.serviceName,
|
||||||
|
Host: config.serviceHost,
|
||||||
|
Port: int64(config.servicePort),
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type pineconeProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client wrapper.HttpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pineconeProvider) GetProviderType() string {
|
||||||
|
return PROVIDER_TYPE_PINECONE
|
||||||
|
}
|
||||||
|
|
||||||
|
// pineconeMetadata is the metadata stored alongside each Pinecone vector.
type pineconeMetadata struct {
	// Question is the user's query text.
	Question string `json:"question"`
	// Answer is the cached answer for the question.
	Answer string `json:"answer"`
}
|
||||||
|
|
||||||
|
type pineconeVector struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Values []float64 `json:"values"`
|
||||||
|
Properties pineconeMetadata `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type pineconeInsertRequest struct {
|
||||||
|
Vectors []pineconeVector `json:"vectors"`
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *pineconeProvider) UploadAnswerAndEmbedding(
|
||||||
|
queryString string,
|
||||||
|
queryEmb []float64,
|
||||||
|
queryAnswer string,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 vector 和 question
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "vectors": [
|
||||||
|
// {
|
||||||
|
// "id": "A",
|
||||||
|
// "values": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
|
||||||
|
// "metadata": {"question": "你好", "answer": "你也好"}
|
||||||
|
// }
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(pineconeInsertRequest{
|
||||||
|
Vectors: []pineconeVector{
|
||||||
|
{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Values: queryEmb,
|
||||||
|
Properties: pineconeMetadata{Question: queryString, Answer: queryAnswer},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Namespace: d.config.collectionID,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Pinecone] Failed to marshal upload embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
"/vectors/upsert",
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Api-Key", d.config.apiKey},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Pinecone] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||||
|
callback(ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pineconeQueryRequest is the JSON body sent to /query.
type pineconeQueryRequest struct {
	// Namespace selects the namespace to search in.
	Namespace string `json:"namespace"`
	// Vector is the query embedding.
	Vector []float64 `json:"vector"`
	// TopK is the number of matches requested.
	TopK int `json:"topK"`
	// IncludeMetadata asks for the stored metadata of each match.
	IncludeMetadata bool `json:"includeMetadata"`
	// IncludeValues asks for the stored vector values of each match.
	IncludeValues bool `json:"includeValues"`
}
|
||||||
|
|
||||||
|
func (d *pineconeProvider) QueryEmbedding(
|
||||||
|
emb []float64,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 vector
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "namespace": "higress",
|
||||||
|
// "vector": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
|
||||||
|
// "topK": 1,
|
||||||
|
// "includeMetadata": false
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(pineconeQueryRequest{
|
||||||
|
Namespace: d.config.collectionID,
|
||||||
|
Vector: emb,
|
||||||
|
TopK: d.config.topK,
|
||||||
|
IncludeMetadata: true,
|
||||||
|
IncludeValues: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Pinecone] Failed to marshal query embedding: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
"/query",
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Api-Key", d.config.apiKey},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Pinecone] Query embedding response: %d, %s", statusCode, responseBody)
|
||||||
|
results, err := d.parseQueryResponse(responseBody, log)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("[Pinecone] Failed to parse query response: %v", err)
|
||||||
|
}
|
||||||
|
callback(results, ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *pineconeProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||||
|
if !gjson.GetBytes(responseBody, "matches.0.score").Exists() {
|
||||||
|
log.Errorf("[Pinecone] No distance found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Pinecone] No distance found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, "matches.0.metadata.question").Exists() {
|
||||||
|
log.Errorf("[Pinecone] No question found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Pinecone] No question found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, "matches.0.metadata.answer").Exists() {
|
||||||
|
log.Errorf("[Pinecone] No answer found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Pinecone] No answer found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultNum := gjson.GetBytes(responseBody, "matches.#").Int()
|
||||||
|
results := make([]QueryResult, 0, resultNum)
|
||||||
|
for i := 0; i < int(resultNum); i++ {
|
||||||
|
result := QueryResult{
|
||||||
|
Text: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.metadata.question", i)).String(),
|
||||||
|
Score: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.score", i)).Float(),
|
||||||
|
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.metadata.answer", i)).String(),
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
@@ -10,6 +10,11 @@ import (
|
|||||||
const (
|
const (
|
||||||
PROVIDER_TYPE_DASH_VECTOR = "dashvector"
|
PROVIDER_TYPE_DASH_VECTOR = "dashvector"
|
||||||
PROVIDER_TYPE_CHROMA = "chroma"
|
PROVIDER_TYPE_CHROMA = "chroma"
|
||||||
|
PROVIDER_TYPE_ES = "elasticsearch"
|
||||||
|
PROVIDER_TYPE_WEAVIATE = "weaviate"
|
||||||
|
PROVIDER_TYPE_PINECONE = "pinecone"
|
||||||
|
PROVIDER_TYPE_QDRANT = "qdrant"
|
||||||
|
PROVIDER_TYPE_MILVUS = "milvus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type providerInitializer interface {
|
type providerInitializer interface {
|
||||||
@@ -20,7 +25,12 @@ type providerInitializer interface {
|
|||||||
var (
|
var (
|
||||||
providerInitializers = map[string]providerInitializer{
|
providerInitializers = map[string]providerInitializer{
|
||||||
PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{},
|
PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{},
|
||||||
// PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{},
|
PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{},
|
||||||
|
PROVIDER_TYPE_ES: &esProviderInitializer{},
|
||||||
|
PROVIDER_TYPE_WEAVIATE: &weaviateProviderInitializer{},
|
||||||
|
PROVIDER_TYPE_PINECONE: &pineconeProviderInitializer{},
|
||||||
|
PROVIDER_TYPE_QDRANT: &qdrantProviderInitializer{},
|
||||||
|
PROVIDER_TYPE_MILVUS: &milvusProviderInitializer{},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -71,10 +81,6 @@ type StringQuerier interface {
|
|||||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type SimilarityThresholdProvider interface {
|
|
||||||
GetSimilarityThreshold() float64
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
// @Title zh-CN 向量存储服务提供者类型
|
// @Title zh-CN 向量存储服务提供者类型
|
||||||
// @Description zh-CN 向量存储服务提供者类型,例如 dashvector、chroma
|
// @Description zh-CN 向量存储服务提供者类型,例如 dashvector、chroma
|
||||||
@@ -97,8 +103,8 @@ type ProviderConfig struct {
|
|||||||
// @Title zh-CN 请求超时
|
// @Title zh-CN 请求超时
|
||||||
// @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒
|
// @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒
|
||||||
timeout uint32
|
timeout uint32
|
||||||
// @Title zh-CN DashVector 向量存储服务 Collection ID
|
// @Title zh-CN 向量存储服务 Collection ID
|
||||||
// @Description zh-CN DashVector 向量存储服务 Collection ID
|
// @Description zh-CN 向量存储服务的 Collection ID
|
||||||
collectionID string
|
collectionID string
|
||||||
// @Title zh-CN 相似度度量阈值
|
// @Title zh-CN 相似度度量阈值
|
||||||
// @Description zh-CN 默认相似度度量阈值,默认为 1000。
|
// @Description zh-CN 默认相似度度量阈值,默认为 1000。
|
||||||
@@ -109,6 +115,14 @@ type ProviderConfig struct {
|
|||||||
// 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。
|
// 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。
|
||||||
// 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于)
|
// 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于)
|
||||||
ThresholdRelation string
|
ThresholdRelation string
|
||||||
|
|
||||||
|
// ES 配置
|
||||||
|
// @Title zh-CN ES 用户名
|
||||||
|
// @Description zh-CN ES 用户名
|
||||||
|
esUsername string
|
||||||
|
// @Title zh-CN ES 密码
|
||||||
|
// @Description zh-CN ES 密码
|
||||||
|
esPassword string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetProviderType() string {
|
func (c *ProviderConfig) GetProviderType() string {
|
||||||
@@ -117,7 +131,6 @@ func (c *ProviderConfig) GetProviderType() string {
|
|||||||
|
|
||||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||||
c.typ = json.Get("type").String()
|
c.typ = json.Get("type").String()
|
||||||
// DashVector
|
|
||||||
c.serviceName = json.Get("serviceName").String()
|
c.serviceName = json.Get("serviceName").String()
|
||||||
c.serviceHost = json.Get("serviceHost").String()
|
c.serviceHost = json.Get("serviceHost").String()
|
||||||
c.servicePort = int64(json.Get("servicePort").Int())
|
c.servicePort = int64(json.Get("servicePort").Int())
|
||||||
@@ -142,6 +155,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
if c.ThresholdRelation == "" {
|
if c.ThresholdRelation == "" {
|
||||||
c.ThresholdRelation = "lt"
|
c.ThresholdRelation = "lt"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ES
|
||||||
|
c.esUsername = json.Get("esUsername").String()
|
||||||
|
c.esPassword = json.Get("esPassword").String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) Validate() error {
|
func (c *ProviderConfig) Validate() error {
|
||||||
@@ -152,6 +169,9 @@ func (c *ProviderConfig) Validate() error {
|
|||||||
if !has {
|
if !has {
|
||||||
return errors.New("unknown vector database service provider type: " + c.typ)
|
return errors.New("unknown vector database service provider type: " + c.typ)
|
||||||
}
|
}
|
||||||
|
if !isRelationValid(c.ThresholdRelation) {
|
||||||
|
return errors.New("invalid thresholdRelation: " + c.ThresholdRelation)
|
||||||
|
}
|
||||||
if err := initializer.ValidateConfig(*c); err != nil {
|
if err := initializer.ValidateConfig(*c); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -165,3 +185,12 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
|
|||||||
}
|
}
|
||||||
return initializer.CreateProvider(pc)
|
return initializer.CreateProvider(pc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isRelationValid reports whether relation is one of the supported threshold
// comparison operators: "lt", "lte", "gt" or "gte". The match is exact and
// case-sensitive.
func isRelationValid(relation string) bool {
	switch relation {
	case "lt", "lte", "gt", "gte":
		return true
	default:
		return false
	}
}
|
||||||
|
|||||||
208
plugins/wasm-go/extensions/ai-cache/vector/qdrant.go
Normal file
208
plugins/wasm-go/extensions/ai-cache/vector/qdrant.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// qdrantProviderInitializer validates configuration for, and creates, the
// Qdrant vector-store provider.
type qdrantProviderInitializer struct{}
|
||||||
|
|
||||||
|
func (c *qdrantProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if len(config.serviceName) == 0 {
|
||||||
|
return errors.New("[Qdrant] serviceName is required")
|
||||||
|
}
|
||||||
|
if len(config.collectionID) == 0 {
|
||||||
|
return errors.New("[Qdrant] collectionID is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *qdrantProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
|
return &qdrantProvider{
|
||||||
|
config: config,
|
||||||
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: config.serviceName,
|
||||||
|
Host: config.serviceHost,
|
||||||
|
Port: int64(config.servicePort),
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type qdrantProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client wrapper.HttpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *qdrantProvider) GetProviderType() string {
|
||||||
|
return PROVIDER_TYPE_QDRANT
|
||||||
|
}
|
||||||
|
|
||||||
|
// qdrantPayload is the payload object stored alongside each point.
type qdrantPayload struct {
	// Question is the user query the embedding was computed from.
	Question string `json:"question"`
	// Answer is the cached answer for Question.
	Answer string `json:"answer"`
}
|
||||||
|
|
||||||
|
type qdrantPoint struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Vector []float64 `json:"vector"`
|
||||||
|
Payload qdrantPayload `json:"payload"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type qdrantInsertRequest struct {
|
||||||
|
Points []qdrantPoint `json:"points"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *qdrantProvider) UploadAnswerAndEmbedding(
|
||||||
|
queryString string,
|
||||||
|
queryEmb []float64,
|
||||||
|
queryAnswer string,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 id 和 vector. payload 可选
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "points": [
|
||||||
|
// {
|
||||||
|
// "id": "76874cce-1fb9-4e16-9b0b-f085ac06ed6f",
|
||||||
|
// "payload": {
|
||||||
|
// "question": "这里是问题",
|
||||||
|
// "answer": "这里是答案"
|
||||||
|
// },
|
||||||
|
// "vector": [
|
||||||
|
// 0.9,
|
||||||
|
// 0.1,
|
||||||
|
// 0.1
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(qdrantInsertRequest{
|
||||||
|
Points: []qdrantPoint{
|
||||||
|
{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Vector: queryEmb,
|
||||||
|
Payload: qdrantPayload{Question: queryString, Answer: queryAnswer},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Qdrant] Failed to marshal upload embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Put(
|
||||||
|
fmt.Sprintf("/collections/%s/points", d.config.collectionID),
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"api-key", d.config.apiKey},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Qdrant] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
||||||
|
callback(ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// qdrantQueryRequest is the JSON body sent to the points search endpoint.
type qdrantQueryRequest struct {
	// Vector is the query embedding.
	Vector []float64 `json:"vector"`
	// Limit is the number of results requested.
	Limit int `json:"limit"`
	// WithPayload asks for the stored payload of each result.
	WithPayload bool `json:"with_payload"`
}
|
||||||
|
|
||||||
|
func (d *qdrantProvider) QueryEmbedding(
|
||||||
|
emb []float64,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 vector 和 limit. with_payload 可选,为了直接得到问题答案,所以这里需要
|
||||||
|
// 下面是一个例子
|
||||||
|
// {
|
||||||
|
// "vector": [
|
||||||
|
// 0.2,
|
||||||
|
// 0.1,
|
||||||
|
// 0.9,
|
||||||
|
// 0.7
|
||||||
|
// ],
|
||||||
|
// "limit": 1
|
||||||
|
// }
|
||||||
|
requestBody, err := json.Marshal(qdrantQueryRequest{
|
||||||
|
Vector: emb,
|
||||||
|
Limit: d.config.topK,
|
||||||
|
WithPayload: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Qdrant] Failed to marshal query embedding: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
fmt.Sprintf("/collections/%s/points/search", d.config.collectionID),
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"api-key", d.config.apiKey},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Qdrant] Query embedding response: %d, %s", statusCode, responseBody)
|
||||||
|
results, err := d.parseQueryResponse(responseBody, log)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("[Qdrant] Failed to parse query response: %v", err)
|
||||||
|
}
|
||||||
|
callback(results, ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *qdrantProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||||
|
// 返回的内容例子如下
|
||||||
|
// {
|
||||||
|
// "time": 0.002,
|
||||||
|
// "status": "ok",
|
||||||
|
// "result": [
|
||||||
|
// {
|
||||||
|
// "id": 42,
|
||||||
|
// "version": 3,
|
||||||
|
// "score": 0.75,
|
||||||
|
// "payload": {
|
||||||
|
// "question": "London",
|
||||||
|
// "answer": "green"
|
||||||
|
// },
|
||||||
|
// "shard_key": "region_1",
|
||||||
|
// "order_value": 42
|
||||||
|
// }
|
||||||
|
// ]
|
||||||
|
// }
|
||||||
|
if !gjson.GetBytes(responseBody, "result.0.score").Exists() {
|
||||||
|
log.Errorf("[Qdrant] No distance found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Qdrant] No distance found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, "result.0.payload.answer").Exists() {
|
||||||
|
log.Errorf("[Qdrant] No answer found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Qdrant] No answer found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultNum := gjson.GetBytes(responseBody, "result.#").Int()
|
||||||
|
results := make([]QueryResult, 0, resultNum)
|
||||||
|
for i := 0; i < int(resultNum); i++ {
|
||||||
|
result := QueryResult{
|
||||||
|
Text: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.question", i)).String(),
|
||||||
|
Score: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.score", i)).Float(),
|
||||||
|
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.answer", i)).String(),
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
188
plugins/wasm-go/extensions/ai-cache/vector/weaviate.go
Normal file
188
plugins/wasm-go/extensions/ai-cache/vector/weaviate.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package vector
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// weaviateProviderInitializer validates configuration for, and creates, the
// Weaviate vector-store provider.
type weaviateProviderInitializer struct{}
|
||||||
|
|
||||||
|
func (c *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if len(config.collectionID) == 0 {
|
||||||
|
return errors.New("[Weaviate] collectionID is required")
|
||||||
|
}
|
||||||
|
if len(config.serviceName) == 0 {
|
||||||
|
return errors.New("[Weaviate] serviceName is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
|
return &WeaviateProvider{
|
||||||
|
config: config,
|
||||||
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: config.serviceName,
|
||||||
|
Host: config.serviceHost,
|
||||||
|
Port: int64(config.servicePort),
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type WeaviateProvider struct {
|
||||||
|
config ProviderConfig
|
||||||
|
client wrapper.HttpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WeaviateProvider) GetProviderType() string {
|
||||||
|
return PROVIDER_TYPE_WEAVIATE
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *WeaviateProvider) QueryEmbedding(
|
||||||
|
emb []float64,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 class, vector
|
||||||
|
// 下面是一个例子
|
||||||
|
// {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"}
|
||||||
|
embString, err := json.Marshal(emb)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Weaviate] Failed to marshal query embedding: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 这里默认按照 distance 进行升序,所以不用再次排序
|
||||||
|
graphql := fmt.Sprintf(`
|
||||||
|
{
|
||||||
|
Get {
|
||||||
|
%s (
|
||||||
|
limit: %d
|
||||||
|
nearVector: {
|
||||||
|
vector: %s
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
question
|
||||||
|
answer
|
||||||
|
_additional {
|
||||||
|
distance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`, d.config.collectionID, d.config.topK, embString)
|
||||||
|
|
||||||
|
requestBody, err := json.Marshal(weaviateQueryRequest{
|
||||||
|
Query: graphql,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Weaviate] Failed to marshal query embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.client.Post(
|
||||||
|
"/v1/graphql",
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Weaviate] Query embedding response: %d, %s", statusCode, responseBody)
|
||||||
|
results, err := d.parseQueryResponse(responseBody, log)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("[Weaviate] Failed to parse query response: %v", err)
|
||||||
|
}
|
||||||
|
callback(results, ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *WeaviateProvider) UploadAnswerAndEmbedding(
|
||||||
|
queryString string,
|
||||||
|
queryEmb []float64,
|
||||||
|
queryAnswer string,
|
||||||
|
ctx wrapper.HttpContext,
|
||||||
|
log wrapper.Log,
|
||||||
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||||
|
// 最少需要填写的参数为 class, vector 和 question 和 answer
|
||||||
|
// 下面是一个例子
|
||||||
|
// {"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题", "answer": "这里是答案"}}
|
||||||
|
requestBody, err := json.Marshal(weaviateInsertRequest{
|
||||||
|
Class: d.config.collectionID,
|
||||||
|
Vector: queryEmb,
|
||||||
|
Properties: weaviateProperties{Question: queryString, Answer: queryAnswer}, // queryString 指的是用户查询的问题
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("[Weaviate] Failed to marshal upload embedding request body: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.client.Post(
|
||||||
|
"/v1/objects",
|
||||||
|
[][2]string{
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
|
||||||
|
},
|
||||||
|
requestBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
log.Debugf("[Weaviate] statusCode: %d, responseBody: %s", statusCode, string(responseBody))
|
||||||
|
callback(ctx, log, err)
|
||||||
|
},
|
||||||
|
d.config.timeout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// weaviateProperties is the properties object stored with each object.
type weaviateProperties struct {
	// Question is the user query the embedding was computed from.
	Question string `json:"question"`
	// Answer is the cached answer for Question.
	Answer string `json:"answer"`
}
|
||||||
|
|
||||||
|
type weaviateInsertRequest struct {
|
||||||
|
Class string `json:"class"`
|
||||||
|
Vector []float64 `json:"vector"`
|
||||||
|
Properties weaviateProperties `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// weaviateQueryRequest is the JSON body sent to /v1/graphql.
type weaviateQueryRequest struct {
	// Query is the GraphQL query text.
	Query string `json:"query"`
}
|
||||||
|
|
||||||
|
func (d *WeaviateProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||||
|
log.Infof("[Weaviate] queryResp: %s", string(responseBody))
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.collectionID)).Exists() {
|
||||||
|
log.Errorf("[Weaviate] No distance found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Weaviate] No distance found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.question", d.config.collectionID)).Exists() {
|
||||||
|
log.Errorf("[Weaviate] No question found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Weaviate] No question found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.answer", d.config.collectionID)).Exists() {
|
||||||
|
log.Errorf("[Weaviate] No answer found in response body: %s", responseBody)
|
||||||
|
return nil, errors.New("[Weaviate] No answer found in response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultNum := gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.#", d.config.collectionID)).Int()
|
||||||
|
results := make([]QueryResult, 0, resultNum)
|
||||||
|
for i := 0; i < int(resultNum); i++ {
|
||||||
|
result := QueryResult{
|
||||||
|
Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.question", d.config.collectionID, i)).String(),
|
||||||
|
Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d._additional.distance", d.config.collectionID, i)).Float(),
|
||||||
|
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.answer", d.config.collectionID, i)).String(),
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user