mirror of
https://github.com/alibaba/higress.git
synced 2026-04-21 12:07:26 +08:00
165 lines
4.3 KiB
Go
165 lines
4.3 KiB
Go
package elasticsearch
|
|
|
|
import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
	"github.com/tidwall/gjson"

	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
)
|
|
|
|
type ElasticsearchSearch struct {
|
|
client wrapper.HttpClient
|
|
index string
|
|
contentField string
|
|
semanticTextField string
|
|
linkField string
|
|
titleField string
|
|
start int
|
|
count int
|
|
timeoutMillisecond uint32
|
|
username string
|
|
password string
|
|
}
|
|
|
|
func NewElasticsearchSearch(config *gjson.Result, needReference bool) (*ElasticsearchSearch, error) {
|
|
engine := &ElasticsearchSearch{}
|
|
serviceName := config.Get("serviceName").String()
|
|
if serviceName == "" {
|
|
return nil, errors.New("serviceName not found")
|
|
}
|
|
servicePort := config.Get("servicePort").Int()
|
|
if servicePort == 0 {
|
|
if strings.HasSuffix(serviceName, ".static") {
|
|
servicePort = 80
|
|
} else if strings.HasSuffix(serviceName, ".dns") {
|
|
servicePort = 443
|
|
} else {
|
|
return nil, errors.New("servicePort not found")
|
|
}
|
|
}
|
|
engine.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
FQDN: serviceName,
|
|
Port: servicePort,
|
|
})
|
|
engine.index = config.Get("index").String()
|
|
if engine.index == "" {
|
|
return nil, errors.New("index not found")
|
|
}
|
|
|
|
engine.contentField = config.Get("contentField").String()
|
|
if engine.contentField == "" {
|
|
return nil, errors.New("contentField not found")
|
|
}
|
|
engine.semanticTextField = config.Get("semanticTextField").String()
|
|
if engine.semanticTextField == "" {
|
|
return nil, errors.New("semanticTextField not found")
|
|
}
|
|
|
|
if needReference {
|
|
engine.linkField = config.Get("linkField").String()
|
|
if engine.linkField == "" {
|
|
return nil, errors.New("linkField not found")
|
|
}
|
|
engine.titleField = config.Get("titleField").String()
|
|
if engine.titleField == "" {
|
|
return nil, errors.New("titleField not found")
|
|
}
|
|
}
|
|
|
|
engine.timeoutMillisecond = uint32(config.Get("timeoutMillisecond").Uint())
|
|
if engine.timeoutMillisecond == 0 {
|
|
engine.timeoutMillisecond = 5000
|
|
}
|
|
engine.start = int(config.Get("start").Uint())
|
|
engine.count = int(config.Get("count").Uint())
|
|
if engine.count == 0 {
|
|
engine.count = 10
|
|
}
|
|
|
|
engine.username = config.Get("username").String()
|
|
engine.password = config.Get("password").String()
|
|
|
|
return engine, nil
|
|
}
|
|
|
|
func (e ElasticsearchSearch) NeedExectue(ctx engine.SearchContext) bool {
|
|
return ctx.EngineType == "private" || ctx.EngineType == ""
|
|
}
|
|
|
|
func (e ElasticsearchSearch) Client() wrapper.HttpClient {
|
|
return e.client
|
|
}
|
|
|
|
func (e ElasticsearchSearch) generateAuthorizationHeader() string {
|
|
return fmt.Sprintf(`Basic %s`, base64.StdEncoding.EncodeToString([]byte(e.username+":"+e.password)))
|
|
}
|
|
|
|
func (e ElasticsearchSearch) generateQueryBody(ctx engine.SearchContext) string {
|
|
queryText := strings.Join(ctx.Querys, " ")
|
|
return fmt.Sprintf(`{
|
|
"_source":{
|
|
"excludes": "%s"
|
|
},
|
|
"retriever": {
|
|
"rrf": {
|
|
"retrievers": [
|
|
{
|
|
"standard": {
|
|
"query": {
|
|
"match": {
|
|
"%s": "%s"
|
|
}
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"standard": {
|
|
"query": {
|
|
"semantic": {
|
|
"field": "%s",
|
|
"query": "%s"
|
|
}
|
|
}
|
|
}
|
|
}
|
|
]
|
|
}
|
|
}
|
|
}`, e.semanticTextField, e.contentField, queryText, e.semanticTextField, queryText)
|
|
}
|
|
|
|
func (e ElasticsearchSearch) CallArgs(ctx engine.SearchContext) engine.CallArgs {
|
|
queryBody := e.generateQueryBody(ctx)
|
|
return engine.CallArgs{
|
|
Method: http.MethodPost,
|
|
Url: fmt.Sprintf("/%s/_search?from=%d&size=%d", e.index, e.start, e.count),
|
|
Headers: [][2]string{
|
|
{"Content-Type", "application/json"},
|
|
{"Authorization", e.generateAuthorizationHeader()},
|
|
},
|
|
Body: []byte(queryBody),
|
|
TimeoutMillisecond: e.timeoutMillisecond,
|
|
}
|
|
}
|
|
|
|
func (e ElasticsearchSearch) ParseResult(ctx engine.SearchContext, response []byte) []engine.SearchResult {
|
|
jsonObj := gjson.ParseBytes(response)
|
|
var results []engine.SearchResult
|
|
for _, hit := range jsonObj.Get("hits.hits").Array() {
|
|
source := hit.Get("_source")
|
|
result := engine.SearchResult{
|
|
Title: source.Get(e.titleField).String(),
|
|
Link: source.Get(e.linkField).String(),
|
|
Content: source.Get(e.contentField).String(),
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
return results
|
|
}
|