mirror of
https://github.com/alibaba/higress.git
synced 2026-02-22 19:51:03 +08:00
257 lines
6.9 KiB
Go
257 lines
6.9 KiB
Go
package vector
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
)
|
|
|
|
type dashVectorProviderInitializer struct {
|
|
}
|
|
|
|
func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
|
if len(config.apiKey) == 0 {
|
|
return errors.New("[DashVector] apiKey is required")
|
|
}
|
|
if len(config.collectionID) == 0 {
|
|
return errors.New("[DashVector] collectionID is required")
|
|
}
|
|
if len(config.serviceName) == 0 {
|
|
return errors.New("[DashVector] serviceName is required")
|
|
}
|
|
if len(config.serviceHost) == 0 {
|
|
return errors.New("[DashVector] serviceHost is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
|
return &DvProvider{
|
|
config: config,
|
|
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
FQDN: config.serviceName,
|
|
Host: config.serviceHost,
|
|
Port: int64(config.servicePort),
|
|
}),
|
|
}, nil
|
|
}
|
|
|
|
type DvProvider struct {
|
|
config ProviderConfig
|
|
client wrapper.HttpClient
|
|
}
|
|
|
|
func (d *DvProvider) GetProviderType() string {
|
|
return PROVIDER_TYPE_DASH_VECTOR
|
|
}
|
|
|
|
// type embeddingRequest struct {
|
|
// Model string `json:"model"`
|
|
// Input input `json:"input"`
|
|
// Parameters params `json:"parameters"`
|
|
// }
|
|
|
|
// type params struct {
|
|
// TextType string `json:"text_type"`
|
|
// }
|
|
|
|
// type input struct {
|
|
// Texts []string `json:"texts"`
|
|
// }
|
|
|
|
// queryResponse 定义查询响应的结构
|
|
type queryResponse struct {
|
|
Code int `json:"code"`
|
|
RequestID string `json:"request_id"`
|
|
Message string `json:"message"`
|
|
Output []result `json:"output"`
|
|
}
|
|
|
|
// queryRequest 定义查询请求的结构
|
|
type queryRequest struct {
|
|
Vector []float64 `json:"vector"`
|
|
TopK int `json:"topk"`
|
|
IncludeVector bool `json:"include_vector"`
|
|
}
|
|
|
|
// result 定义查询结果的结构
|
|
type result struct {
|
|
ID string `json:"id"`
|
|
Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化
|
|
Fields map[string]interface{} `json:"fields"`
|
|
Score float64 `json:"score"`
|
|
}
|
|
|
|
func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) {
|
|
url := fmt.Sprintf("/v1/collections/%s/query", d.config.collectionID)
|
|
|
|
requestData := queryRequest{
|
|
Vector: vector,
|
|
TopK: d.config.topK,
|
|
IncludeVector: false,
|
|
}
|
|
|
|
requestBody, err := json.Marshal(requestData)
|
|
if err != nil {
|
|
return "", nil, nil, err
|
|
}
|
|
|
|
header := [][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"dashvector-auth-token", d.config.apiKey},
|
|
}
|
|
|
|
return url, requestBody, header, nil
|
|
}
|
|
|
|
func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, error) {
|
|
var queryResp queryResponse
|
|
err := json.Unmarshal(responseBody, &queryResp)
|
|
if err != nil {
|
|
return queryResponse{}, err
|
|
}
|
|
return queryResp, nil
|
|
}
|
|
|
|
func (d *DvProvider) QueryEmbedding(
|
|
emb []float64,
|
|
ctx wrapper.HttpContext,
|
|
log wrapper.Log,
|
|
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
url, body, headers, err := d.constructEmbeddingQueryParameters(emb)
|
|
log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to construct embedding query parameters: %v", err)
|
|
return err
|
|
}
|
|
|
|
err = d.client.Post(url, headers, body,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
err = nil
|
|
if statusCode != http.StatusOK {
|
|
err = fmt.Errorf("failed to query embedding: %d", statusCode)
|
|
callback(nil, ctx, log, err)
|
|
return
|
|
}
|
|
log.Debugf("query embedding response: %d, %s", statusCode, responseBody)
|
|
results, err := d.ParseQueryResponse(responseBody, ctx, log)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to parse query response: %v", err)
|
|
}
|
|
callback(results, ctx, log, err)
|
|
},
|
|
d.config.timeout)
|
|
if err != nil {
|
|
err = fmt.Errorf("failed to query embedding: %v", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func getStringValue(fields map[string]interface{}, key string) string {
|
|
if val, ok := fields[key]; ok {
|
|
return val.(string)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) {
|
|
resp, err := d.parseQueryResponse(responseBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(resp.Output) == 0 {
|
|
return nil, errors.New("no query results found in response")
|
|
}
|
|
|
|
results := make([]QueryResult, 0, len(resp.Output))
|
|
|
|
for _, output := range resp.Output {
|
|
result := QueryResult{
|
|
Text: getStringValue(output.Fields, "query"),
|
|
Embedding: output.Vector,
|
|
Score: output.Score,
|
|
Answer: getStringValue(output.Fields, "answer"),
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
type document struct {
|
|
Vector []float64 `json:"vector"`
|
|
Fields map[string]string `json:"fields"`
|
|
}
|
|
|
|
type insertRequest struct {
|
|
Docs []document `json:"docs"`
|
|
}
|
|
|
|
func (d *DvProvider) constructUploadParameters(emb []float64, queryString string, answer string) (string, []byte, [][2]string, error) {
|
|
url := "/v1/collections/" + d.config.collectionID + "/docs"
|
|
|
|
doc := document{
|
|
Vector: emb,
|
|
Fields: map[string]string{
|
|
"query": queryString,
|
|
"answer": answer,
|
|
},
|
|
}
|
|
|
|
requestBody, err := json.Marshal(insertRequest{Docs: []document{doc}})
|
|
if err != nil {
|
|
return "", nil, nil, err
|
|
}
|
|
|
|
header := [][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"dashvector-auth-token", d.config.apiKey},
|
|
}
|
|
|
|
return url, requestBody, header, err
|
|
}
|
|
|
|
func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = d.client.Post(
|
|
url,
|
|
headers,
|
|
body,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
|
if statusCode != http.StatusOK {
|
|
err = fmt.Errorf("failed to upload embedding: %d", statusCode)
|
|
}
|
|
callback(ctx, log, err)
|
|
},
|
|
d.config.timeout)
|
|
return err
|
|
}
|
|
|
|
func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
|
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = d.client.Post(
|
|
url,
|
|
headers,
|
|
body,
|
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody))
|
|
if statusCode != http.StatusOK {
|
|
err = fmt.Errorf("failed to upload embedding: %d", statusCode)
|
|
}
|
|
callback(ctx, log, err)
|
|
},
|
|
d.config.timeout)
|
|
return err
|
|
}
|