From b86e9fc9382e46853b2ea553cfaa32ae2fc2fd2c Mon Sep 17 00:00:00 2001 From: Yiiong <1783172311@qq.com> Date: Sat, 29 Mar 2025 18:08:37 +0800 Subject: [PATCH] feat: add azure embedding to ai-cache (#1975) --- plugins/wasm-go/extensions/ai-cache/core.go | 2 +- .../extensions/ai-cache/embedding/azure.go | 172 ++++++++++++++++++ .../extensions/ai-cache/embedding/cohere.go | 6 +- .../ai-cache/embedding/dashscope.go | 6 +- .../ai-cache/embedding/huggingface.go | 29 +-- .../extensions/ai-cache/embedding/ollama.go | 11 +- .../extensions/ai-cache/embedding/openai.go | 6 +- .../extensions/ai-cache/embedding/provider.go | 7 +- .../extensions/ai-cache/embedding/textin.go | 6 +- .../extensions/ai-cache/embedding/xfyun.go | 39 ++-- 10 files changed, 230 insertions(+), 54 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/azure.go diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index b46fd28e8..44dea098d 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -130,7 +130,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] no embedding provider configured for similarity search")) } - return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) { + return activeEmbeddingProvider.GetEmbedding(key, ctx, func(textEmbedding []float64, err error) { log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err) if err != nil { handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error getting embedding for key: %s", PLUGIN_NAME, key), log) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/azure.go b/plugins/wasm-go/extensions/ai-cache/embedding/azure.go new file mode 100644 index 000000000..3fc33d23a --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/azure.go @@ -0,0 +1,172 @@ +package embedding + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + AZURE_PORT = 443 + AZURE_DEFAULT_MODEL_NAME = "text-embedding-ada-002" + AZURE_ENDPOINT = "/openai/deployments/{model}/embeddings" +) + +type azureProviderInitializer struct { +} + +var azureConfig azureProviderConfig + +type azureProviderConfig struct { + // @Title zh-CN 文本特征提取服务 API Key + // @Description zh-CN 文本特征提取服务 API Key + apiKey string + // @Title zh-CN 文本特征提取 api-version + // @Description zh-CN 文本特征提取服务 api-version + apiVersion string +} + +func (c *azureProviderInitializer) InitConfig(json gjson.Result) { + azureConfig.apiKey = json.Get("apiKey").String() + azureConfig.apiVersion = json.Get("apiVersion").String() +} + +func (c *azureProviderInitializer) ValidateConfig() error { + if azureConfig.apiKey == "" { + return errors.New("[Azure] apiKey is required") + } + if azureConfig.apiVersion == "" { + return errors.New("[Azure] apiVersion is required") + } + return nil +} + +func (t *azureProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { + if c.servicePort == 0 { + c.servicePort = AZURE_PORT + } + + if c.model == "" { + c.model = AZURE_DEFAULT_MODEL_NAME + } + + return &AzureProvider{ + config: c, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: c.serviceName, + Host: c.serviceHost, + Port: c.servicePort, + }), + }, nil +} + +func (t *AzureProvider) GetProviderType() string { + return PROVIDER_TYPE_AZURE +} + +type AzureProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +type AzureEmbeddingRequest struct { + Input string `json:"input"` +} + +func (t *AzureProvider) constructParameters(text string) (string, [][2]string, []byte, error) { + if text == "" { + err := errors.New("queryString text cannot be empty") + return "", nil, nil, err + } + + data := AzureEmbeddingRequest{ + Input: text, + } + + requestBody, err := json.Marshal(data) + if err != nil { + log.Errorf("failed to marshal request data: %v", err) + return "", nil, nil, err + } + + model := t.config.model + if model == "" { + model = AZURE_DEFAULT_MODEL_NAME + } + + // 拼接 endpoint + endpoint := strings.Replace(AZURE_ENDPOINT, "{model}", model, 1) + endpoint = endpoint + "?" + "api-version=" + azureConfig.apiVersion + + headers := [][2]string{ + {"api-key", azureConfig.apiKey}, + {"Content-Type", "application/json"}, + } + + return endpoint, headers, requestBody, err +} + +type AzureEmbeddingResponse struct { + Object string `json:"object"` + Model string `json:"model"` + Data []struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` +} + +func (t *AzureProvider) parseTextEmbedding(responseBody []byte) (*AzureEmbeddingResponse, error) { + var resp AzureEmbeddingResponse + if err := json.Unmarshal(responseBody, &resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &resp, nil +} + +func (t *AzureProvider) GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + callback func(emb []float64, err error)) error { + embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString) + if err != nil { + log.Errorf("failed to construct parameters: %v", err) + return err + } + + var resp *AzureEmbeddingResponse + err = t.client.Post(embUrl, embHeaders, embRequestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to get embedding due to status code: %d, resp: %s", statusCode, responseBody) + callback(nil, err) + return + } + + resp, err = t.parseTextEmbedding(responseBody) + if err != nil { + err = fmt.Errorf("failed to parse response: %v", err) + callback(nil, err) + return + } + + log.Debugf("get embedding response: %d, %s", statusCode, responseBody) + + if len(resp.Data) == 0 { + err = errors.New("no embedding found in response") + callback(nil, err) + return + } + + callback(resp.Data[0].Embedding, nil) + + }, t.config.timeout) + return err +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go b/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go index d952d2ad2..1e37d1d77 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go @@ -7,6 +7,7 @@ import ( "net/http" "strconv" + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -79,7 +80,7 @@ type CohereProvider struct { func (t *CohereProvider) GetProviderType() string { return PROVIDER_TYPE_COHERE } -func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (t *CohereProvider) constructParameters(texts []string) (string, [][2]string, []byte, error) { model := t.config.model if model == "" { @@ -118,9 +119,8 @@ func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereRespons func (t *CohereProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log) + embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index f31a8d17b..a577a024f 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -7,6 +7,7 @@ import ( "net/http" "strconv" + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -103,7 +104,7 @@ type DSProvider struct { client wrapper.HttpClient } -func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (d *DSProvider) constructParameters(texts []string) (string, [][2]string, []byte, error) { model := d.config.model @@ -159,9 +160,8 @@ func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) func (d *DSProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log) + embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/huggingface.go b/plugins/wasm-go/extensions/ai-cache/embedding/huggingface.go index 45d9325d3..8a6613916 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/huggingface.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/huggingface.go @@ -4,11 +4,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/tidwall/gjson" "net/http" "strconv" "strings" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" ) const ( @@ -18,29 +20,29 @@ const ( HUGGINGFACE_ENDPOINT = "/pipeline/feature-extraction/{modelId}" ) -type HuggingFaceProviderInitializer struct { +type huggingfaceProviderInitializer struct { } -var HuggingFaceConfig HuggingFaceProviderConfig +var huggingfaceConfig huggingfaceProviderConfig -type HuggingFaceProviderConfig struct { +type huggingfaceProviderConfig struct { // @Title zh-CN 文本特征提取服务 API Key // @Description zh-CN 文本特征提取服务 API Key。在HuggingFace定义为 hf_token apiKey string } -func (c *HuggingFaceProviderInitializer) InitConfig(json gjson.Result) { - HuggingFaceConfig.apiKey = json.Get("apiKey").String() +func (c *huggingfaceProviderInitializer) InitConfig(json gjson.Result) { + huggingfaceConfig.apiKey = json.Get("apiKey").String() } -func (c *HuggingFaceProviderInitializer) ValidateConfig() error { - if HuggingFaceConfig.apiKey == "" { +func (c *huggingfaceProviderInitializer) ValidateConfig() error { + if huggingfaceConfig.apiKey == "" { return errors.New("[HuggingFace] hfTokens is required") } return nil } -func (t *HuggingFaceProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { +func (t *huggingfaceProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { if c.servicePort == 0 { c.servicePort = HUGGINGFACE_PORT } @@ -78,7 +80,7 @@ type HuggingFaceEmbeddingRequest struct { } `json:"options"` } -func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (t *HuggingFaceProvider) constructParameters(text string) (string, [][2]string, []byte, error) { if text == "" { err := errors.New("queryString text cannot be empty") return "", nil, nil, err @@ -108,7 +110,7 @@ func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log) endpoint := strings.Replace(HUGGINGFACE_ENDPOINT, "{modelId}", modelId, 1) headers := [][2]string{ - {"Authorization", "Bearer " + HuggingFaceConfig.apiKey}, + {"Authorization", "Bearer " + huggingfaceConfig.apiKey}, {"Content-Type", "application/json"}, } @@ -127,9 +129,8 @@ func (t *HuggingFaceProvider) parseTextEmbedding(responseBody []byte) ([]float64 func (t *HuggingFaceProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log) + embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/ollama.go b/plugins/wasm-go/extensions/ai-cache/embedding/ollama.go index a61bf7782..49bc6e20b 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/ollama.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/ollama.go @@ -4,10 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/tidwall/gjson" "net/http" "strconv" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" ) const ( @@ -69,7 +71,7 @@ type ollamaEmbeddingRequest struct { Model string `json:"model"` } -func (t *ollamaProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (t *ollamaProvider) constructParameters(text string) (string, [][2]string, []byte, error) { if text == "" { err := errors.New("queryString text cannot be empty") return "", nil, nil, err @@ -105,9 +107,8 @@ func (t *ollamaProvider) parseTextEmbedding(responseBody []byte) (*ollamaRespons func (t *ollamaProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log) + embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/openai.go b/plugins/wasm-go/extensions/ai-cache/embedding/openai.go index 04c1d8cdd..6d504305a 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/openai.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/openai.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -93,7 +94,7 @@ type OpenAIProvider struct { client wrapper.HttpClient } -func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (t *OpenAIProvider) constructParameters(text string) (string, [][2]string, []byte, error) { if text == "" { err := errors.New("queryString text cannot be empty") return "", nil, nil, err @@ -130,9 +131,8 @@ func (t *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIRespons func (t *OpenAIProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log) + embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index 62c3970fe..be0c84716 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -15,6 +15,7 @@ const ( PROVIDER_TYPE_OLLAMA = "ollama" PROVIDER_TYPE_HUGGINGFACE = "huggingface" PROVIDER_TYPE_XFYUN = "xfyun" + PROVIDER_TYPE_AZURE = "azure" ) type providerInitializer interface { @@ -30,8 +31,9 @@ var ( PROVIDER_TYPE_COHERE: &cohereProviderInitializer{}, PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{}, PROVIDER_TYPE_OLLAMA: &ollamaProviderInitializer{}, - PROVIDER_TYPE_HUGGINGFACE: &HuggingFaceProviderInitializer{}, - PROVIDER_TYPE_XFYUN: &XfyunProviderInitializer{}, + PROVIDER_TYPE_HUGGINGFACE: &huggingfaceProviderInitializer{}, + PROVIDER_TYPE_XFYUN: &xfyunProviderInitializer{}, + PROVIDER_TYPE_AZURE: &azureProviderInitializer{}, } ) @@ -108,6 +110,5 @@ type Provider interface { GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/textin.go b/plugins/wasm-go/extensions/ai-cache/embedding/textin.go index 5ff29f1af..6cef96baa 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/textin.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/textin.go @@ -7,6 +7,7 @@ import ( "net/http" "strconv" + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -97,7 +98,7 @@ type TIProvider struct { client wrapper.HttpClient } -func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (t *TIProvider) constructParameters(texts []string) (string, [][2]string, []byte, error) { data := TextInEmbeddingRequest{ Input: texts, @@ -142,9 +143,8 @@ func (t *TIProvider) parseTextEmbedding(responseBody []byte) (*TextInResponse, e func (t *TIProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log) + embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/xfyun.go b/plugins/wasm-go/extensions/ai-cache/embedding/xfyun.go index a3410d3dd..128c950fa 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/xfyun.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/xfyun.go @@ -8,13 +8,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/tidwall/gjson" "math" "net/http" "net/url" "strconv" "time" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" ) const ( @@ -22,12 +24,12 @@ const ( XFYUN_PORT = 443 ) -type XfyunProviderInitializer struct { +type xfyunProviderInitializer struct { } -var XfyunConfig XfyunProviderConfig +var xfyunConfig xfyunProviderConfig -type XfyunProviderConfig struct { +type xfyunProviderConfig struct { // @Title zh-CN 文本特征提取服务 API Key // @Description zh-CN 文本特征提取服务 API Key。 apiKey string @@ -39,26 +41,26 @@ type XfyunProviderConfig struct { xfyunApiSecret string } -func (c *XfyunProviderInitializer) InitConfig(json gjson.Result) { - XfyunConfig.xfyunAppID = json.Get("appId").String() - XfyunConfig.xfyunApiSecret = json.Get("apiSecret").String() - XfyunConfig.apiKey = json.Get("apiKey").String() +func (c *xfyunProviderInitializer) InitConfig(json gjson.Result) { + xfyunConfig.xfyunAppID = json.Get("appId").String() + xfyunConfig.xfyunApiSecret = json.Get("apiSecret").String() + xfyunConfig.apiKey = json.Get("apiKey").String() } -func (c *XfyunProviderInitializer) ValidateConfig() error { - if XfyunConfig.apiKey == "" { +func (c *xfyunProviderInitializer) ValidateConfig() error { + if xfyunConfig.apiKey == "" { return errors.New("[Xfyun] apiKey is required") } - if XfyunConfig.xfyunAppID == "" { + if xfyunConfig.xfyunAppID == "" { return errors.New("[Xfyun] appId is required") } - if XfyunConfig.xfyunApiSecret == "" { + if xfyunConfig.xfyunApiSecret == "" { return errors.New("[Xfyun] apiSecret is required") } return nil } -func (t *XfyunProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { +func (t *xfyunProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { if c.servicePort == 0 { c.servicePort = XFYUN_PORT } @@ -160,14 +162,14 @@ func constructAuth(requestURL, method, apiKey, apiSecret string) (string, error) return "?" + params.Encode(), nil } -func (t *XfyunProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) { +func (t *XfyunProvider) constructParameters(text string) (string, [][2]string, []byte, error) { if text == "" { err := errors.New("queryString text cannot be empty") return "", nil, nil, err } host := "https://" + t.config.serviceHost + "/" - auth, err := constructAuth(host, "POST", XfyunConfig.apiKey, XfyunConfig.xfyunApiSecret) + auth, err := constructAuth(host, "POST", xfyunConfig.apiKey, xfyunConfig.xfyunApiSecret) if err != nil { return "", nil, nil, err } @@ -199,7 +201,7 @@ func (t *XfyunProvider) constructParameters(text string, log wrapper.Log) (strin // 构建请求体 data := XfyunReqBody{ Header: XfyunHeader{ - AppID: XfyunConfig.xfyunAppID, + AppID: xfyunConfig.xfyunAppID, Status: 3, }, Parameter: XfyunParameter{ @@ -265,9 +267,8 @@ func (t *XfyunProvider) parseTextEmbedding(responseBody []byte) ([]float32, erro func (t *XfyunProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, - log wrapper.Log, callback func(emb []float64, err error)) error { - embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log) + embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString) if err != nil { log.Errorf("failed to construct parameters: %v", err) return err