mirror of
https://github.com/alibaba/higress.git
synced 2026-02-25 05:01:19 +08:00
189 lines
5.7 KiB
Go
189 lines
5.7 KiB
Go
package vector
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type weaviateProviderInitializer struct{}
|
|
|
|
func (c *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
|
if len(config.collectionID) == 0 {
|
|
return errors.New("[Weaviate] collectionID is required")
|
|
}
|
|
if len(config.serviceName) == 0 {
|
|
return errors.New("[Weaviate] serviceName is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
|
return &WeaviateProvider{
|
|
config: config,
|
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
FQDN: config.serviceName,
|
|
Host: config.serviceHost,
|
|
Port: int64(config.servicePort),
|
|
}),
|
|
}, nil
|
|
}
|
|
|
|
type WeaviateProvider struct {
|
|
config ProviderConfig
|
|
client wrapper.HttpClient
|
|
}
|
|
|
|
func (c *WeaviateProvider) GetProviderType() string {
|
|
return PROVIDER_TYPE_WEAVIATE
|
|
}
|
|
|
|
func (d *WeaviateProvider) QueryEmbedding(
|
|
emb []float64,
|
|
ctx wrapper.HttpContext,
|
|
log wrapper.Log,
|
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
// 最少需要填写的参数为 class, vector
|
|
// 下面是一个例子
|
|
// {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"}
|
|
embString, err := json.Marshal(emb)
|
|
if err != nil {
|
|
log.Errorf("[Weaviate] Failed to marshal query embedding: %v", err)
|
|
return err
|
|
}
|
|
// 这里默认按照 distance 进行升序,所以不用再次排序
|
|
graphql := fmt.Sprintf(`
|
|
{
|
|
Get {
|
|
%s (
|
|
limit: %d
|
|
nearVector: {
|
|
vector: %s
|
|
}
|
|
) {
|
|
question
|
|
answer
|
|
_additional {
|
|
distance
|
|
}
|
|
}
|
|
}
|
|
}
|
|
`, d.config.collectionID, d.config.topK, embString)
|
|
|
|
requestBody, err := json.Marshal(weaviateQueryRequest{
|
|
Query: graphql,
|
|
})
|
|
|
|
if err != nil {
|
|
log.Errorf("[Weaviate] Failed to marshal query embedding request body: %v", err)
|
|
return err
|
|
}
|
|
|
|
err = d.client.Post(
|
|
"/v1/graphql",
|
|
[][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
|
|
},
|
|
requestBody,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Debugf("[Weaviate] Query embedding response: %d, %s", statusCode, responseBody)
|
|
results, err := d.parseQueryResponse(responseBody, log)
|
|
if err != nil {
|
|
err = fmt.Errorf("[Weaviate] Failed to parse query response: %v", err)
|
|
}
|
|
callback(results, ctx, log, err)
|
|
},
|
|
d.config.timeout,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (d *WeaviateProvider) UploadAnswerAndEmbedding(
|
|
queryString string,
|
|
queryEmb []float64,
|
|
queryAnswer string,
|
|
ctx wrapper.HttpContext,
|
|
log wrapper.Log,
|
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
// 最少需要填写的参数为 class, vector 和 question 和 answer
|
|
// 下面是一个例子
|
|
// {"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题", "answer": "这里是答案"}}
|
|
requestBody, err := json.Marshal(weaviateInsertRequest{
|
|
Class: d.config.collectionID,
|
|
Vector: queryEmb,
|
|
Properties: weaviateProperties{Question: queryString, Answer: queryAnswer}, // queryString 指的是用户查询的问题
|
|
})
|
|
|
|
if err != nil {
|
|
log.Errorf("[Weaviate] Failed to marshal upload embedding request body: %v", err)
|
|
return err
|
|
}
|
|
|
|
return d.client.Post(
|
|
"/v1/objects",
|
|
[][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
|
|
},
|
|
requestBody,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Debugf("[Weaviate] statusCode: %d, responseBody: %s", statusCode, string(responseBody))
|
|
callback(ctx, log, err)
|
|
},
|
|
d.config.timeout,
|
|
)
|
|
}
|
|
|
|
// JSON request payloads sent to Weaviate. Field order is significant only for
// the byte layout of the encoded JSON, which follows declaration order.
type (
	// weaviateProperties holds the custom properties stored with each object:
	// the question text and its answer.
	weaviateProperties struct {
		Question string `json:"question"`
		Answer   string `json:"answer"`
	}

	// weaviateInsertRequest is an object to create: the target class, the
	// object's vector and its properties.
	weaviateInsertRequest struct {
		Class      string             `json:"class"`
		Vector     []float64          `json:"vector"`
		Properties weaviateProperties `json:"properties"`
	}

	// weaviateQueryRequest wraps a GraphQL document in the {"query": ...}
	// envelope.
	weaviateQueryRequest struct {
		Query string `json:"query"`
	}
)
|
|
|
|
func (d *WeaviateProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
|
log.Infof("[Weaviate] queryResp: %s", string(responseBody))
|
|
|
|
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.collectionID)).Exists() {
|
|
log.Errorf("[Weaviate] No distance found in response body: %s", responseBody)
|
|
return nil, errors.New("[Weaviate] No distance found in response body")
|
|
}
|
|
|
|
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.question", d.config.collectionID)).Exists() {
|
|
log.Errorf("[Weaviate] No question found in response body: %s", responseBody)
|
|
return nil, errors.New("[Weaviate] No question found in response body")
|
|
}
|
|
|
|
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.answer", d.config.collectionID)).Exists() {
|
|
log.Errorf("[Weaviate] No answer found in response body: %s", responseBody)
|
|
return nil, errors.New("[Weaviate] No answer found in response body")
|
|
}
|
|
|
|
resultNum := gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.#", d.config.collectionID)).Int()
|
|
results := make([]QueryResult, 0, resultNum)
|
|
for i := 0; i < int(resultNum); i++ {
|
|
result := QueryResult{
|
|
Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.question", d.config.collectionID, i)).String(),
|
|
Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d._additional.distance", d.config.collectionID, i)).Float(),
|
|
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.answer", d.config.collectionID, i)).String(),
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
|
|
return results, nil
|
|
}
|