mirror of
https://github.com/alibaba/higress.git
synced 2026-02-24 04:30:51 +08:00
209 lines
5.4 KiB
Go
209 lines
5.4 KiB
Go
package vector
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
"github.com/google/uuid"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type qdrantProviderInitializer struct{}
|
|
|
|
func (c *qdrantProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
|
if len(config.serviceName) == 0 {
|
|
return errors.New("[Qdrant] serviceName is required")
|
|
}
|
|
if len(config.collectionID) == 0 {
|
|
return errors.New("[Qdrant] collectionID is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *qdrantProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
|
return &qdrantProvider{
|
|
config: config,
|
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
FQDN: config.serviceName,
|
|
Host: config.serviceHost,
|
|
Port: int64(config.servicePort),
|
|
}),
|
|
}, nil
|
|
}
|
|
|
|
type qdrantProvider struct {
|
|
config ProviderConfig
|
|
client wrapper.HttpClient
|
|
}
|
|
|
|
func (c *qdrantProvider) GetProviderType() string {
|
|
return PROVIDER_TYPE_QDRANT
|
|
}
|
|
|
|
type qdrantPayload struct {
|
|
Question string `json:"question"`
|
|
Answer string `json:"answer"`
|
|
}
|
|
|
|
type qdrantPoint struct {
|
|
ID string `json:"id"`
|
|
Vector []float64 `json:"vector"`
|
|
Payload qdrantPayload `json:"payload"`
|
|
}
|
|
|
|
type qdrantInsertRequest struct {
|
|
Points []qdrantPoint `json:"points"`
|
|
}
|
|
|
|
func (d *qdrantProvider) UploadAnswerAndEmbedding(
|
|
queryString string,
|
|
queryEmb []float64,
|
|
queryAnswer string,
|
|
ctx wrapper.HttpContext,
|
|
log wrapper.Log,
|
|
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
// 最少需要填写的参数为 id 和 vector. payload 可选
|
|
// 下面是一个例子
|
|
// {
|
|
// "points": [
|
|
// {
|
|
// "id": "76874cce-1fb9-4e16-9b0b-f085ac06ed6f",
|
|
// "payload": {
|
|
// "question": "这里是问题",
|
|
// "answer": "这里是答案"
|
|
// },
|
|
// "vector": [
|
|
// 0.9,
|
|
// 0.1,
|
|
// 0.1
|
|
// ]
|
|
// }
|
|
// ]
|
|
// }
|
|
requestBody, err := json.Marshal(qdrantInsertRequest{
|
|
Points: []qdrantPoint{
|
|
{
|
|
ID: uuid.New().String(),
|
|
Vector: queryEmb,
|
|
Payload: qdrantPayload{Question: queryString, Answer: queryAnswer},
|
|
},
|
|
},
|
|
})
|
|
|
|
if err != nil {
|
|
log.Errorf("[Qdrant] Failed to marshal upload embedding request body: %v", err)
|
|
return err
|
|
}
|
|
|
|
return d.client.Put(
|
|
fmt.Sprintf("/collections/%s/points", d.config.collectionID),
|
|
[][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"api-key", d.config.apiKey},
|
|
},
|
|
requestBody,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Debugf("[Qdrant] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
|
callback(ctx, log, err)
|
|
},
|
|
d.config.timeout,
|
|
)
|
|
}
|
|
|
|
type qdrantQueryRequest struct {
|
|
Vector []float64 `json:"vector"`
|
|
Limit int `json:"limit"`
|
|
WithPayload bool `json:"with_payload"`
|
|
}
|
|
|
|
func (d *qdrantProvider) QueryEmbedding(
|
|
emb []float64,
|
|
ctx wrapper.HttpContext,
|
|
log wrapper.Log,
|
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
// 最少需要填写的参数为 vector 和 limit. with_payload 可选,为了直接得到问题答案,所以这里需要
|
|
// 下面是一个例子
|
|
// {
|
|
// "vector": [
|
|
// 0.2,
|
|
// 0.1,
|
|
// 0.9,
|
|
// 0.7
|
|
// ],
|
|
// "limit": 1
|
|
// }
|
|
requestBody, err := json.Marshal(qdrantQueryRequest{
|
|
Vector: emb,
|
|
Limit: d.config.topK,
|
|
WithPayload: true,
|
|
})
|
|
if err != nil {
|
|
log.Errorf("[Qdrant] Failed to marshal query embedding: %v", err)
|
|
return err
|
|
}
|
|
|
|
return d.client.Post(
|
|
fmt.Sprintf("/collections/%s/points/search", d.config.collectionID),
|
|
[][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"api-key", d.config.apiKey},
|
|
},
|
|
requestBody,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Debugf("[Qdrant] Query embedding response: %d, %s", statusCode, responseBody)
|
|
results, err := d.parseQueryResponse(responseBody, log)
|
|
if err != nil {
|
|
err = fmt.Errorf("[Qdrant] Failed to parse query response: %v", err)
|
|
}
|
|
callback(results, ctx, log, err)
|
|
},
|
|
d.config.timeout,
|
|
)
|
|
}
|
|
|
|
func (d *qdrantProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
|
// 返回的内容例子如下
|
|
// {
|
|
// "time": 0.002,
|
|
// "status": "ok",
|
|
// "result": [
|
|
// {
|
|
// "id": 42,
|
|
// "version": 3,
|
|
// "score": 0.75,
|
|
// "payload": {
|
|
// "question": "London",
|
|
// "answer": "green"
|
|
// },
|
|
// "shard_key": "region_1",
|
|
// "order_value": 42
|
|
// }
|
|
// ]
|
|
// }
|
|
if !gjson.GetBytes(responseBody, "result.0.score").Exists() {
|
|
log.Errorf("[Qdrant] No distance found in response body: %s", responseBody)
|
|
return nil, errors.New("[Qdrant] No distance found in response body")
|
|
}
|
|
|
|
if !gjson.GetBytes(responseBody, "result.0.payload.answer").Exists() {
|
|
log.Errorf("[Qdrant] No answer found in response body: %s", responseBody)
|
|
return nil, errors.New("[Qdrant] No answer found in response body")
|
|
}
|
|
|
|
resultNum := gjson.GetBytes(responseBody, "result.#").Int()
|
|
results := make([]QueryResult, 0, resultNum)
|
|
for i := 0; i < int(resultNum); i++ {
|
|
result := QueryResult{
|
|
Text: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.question", i)).String(),
|
|
Score: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.score", i)).Float(),
|
|
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.answer", i)).String(),
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
|
|
return results, nil
|
|
}
|