mirror of
https://github.com/alibaba/higress.git
synced 2026-03-18 09:17:26 +08:00
feat: add azure embedding to ai-cache (#1975)
This commit is contained in:
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,29 +20,29 @@ const (
|
||||
HUGGINGFACE_ENDPOINT = "/pipeline/feature-extraction/{modelId}"
|
||||
)
|
||||
|
||||
type HuggingFaceProviderInitializer struct {
|
||||
type huggingfaceProviderInitializer struct {
|
||||
}
|
||||
|
||||
var HuggingFaceConfig HuggingFaceProviderConfig
|
||||
var huggingfaceConfig huggingfaceProviderConfig
|
||||
|
||||
type HuggingFaceProviderConfig struct {
|
||||
type huggingfaceProviderConfig struct {
|
||||
// @Title zh-CN 文本特征提取服务 API Key
|
||||
// @Description zh-CN 文本特征提取服务 API Key。在HuggingFace定义为 hf_token
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (c *HuggingFaceProviderInitializer) InitConfig(json gjson.Result) {
|
||||
HuggingFaceConfig.apiKey = json.Get("apiKey").String()
|
||||
func (c *huggingfaceProviderInitializer) InitConfig(json gjson.Result) {
|
||||
huggingfaceConfig.apiKey = json.Get("apiKey").String()
|
||||
}
|
||||
|
||||
func (c *HuggingFaceProviderInitializer) ValidateConfig() error {
|
||||
if HuggingFaceConfig.apiKey == "" {
|
||||
func (c *huggingfaceProviderInitializer) ValidateConfig() error {
|
||||
if huggingfaceConfig.apiKey == "" {
|
||||
return errors.New("[HuggingFace] hfTokens is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *HuggingFaceProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
|
||||
func (t *huggingfaceProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
|
||||
if c.servicePort == 0 {
|
||||
c.servicePort = HUGGINGFACE_PORT
|
||||
}
|
||||
@@ -78,7 +80,7 @@ type HuggingFaceEmbeddingRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (t *HuggingFaceProvider) constructParameters(text string) (string, [][2]string, []byte, error) {
|
||||
if text == "" {
|
||||
err := errors.New("queryString text cannot be empty")
|
||||
return "", nil, nil, err
|
||||
@@ -108,7 +110,7 @@ func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log)
|
||||
endpoint := strings.Replace(HUGGINGFACE_ENDPOINT, "{modelId}", modelId, 1)
|
||||
|
||||
headers := [][2]string{
|
||||
{"Authorization", "Bearer " + HuggingFaceConfig.apiKey},
|
||||
{"Authorization", "Bearer " + huggingfaceConfig.apiKey},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
|
||||
@@ -127,9 +129,8 @@ func (t *HuggingFaceProvider) parseTextEmbedding(responseBody []byte) ([]float64
|
||||
func (t *HuggingFaceProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to construct parameters: %v", err)
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user