diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 1fc7a160a..0153abb82 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -157,15 +157,7 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。 #### 文心一言(Baidu) -文心一言所对应的 `type` 为 `baidu`。它特有的配置字段如下: - -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -|--------------------|-----------------|------|-----|-----------------------------------------------------------| -| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key,中间用 `:` 分隔,用于申请 apiToken。 | -| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 | -| `baiduApiTokenServiceHost` | string | 非必填 | - | 请求刷新百度 apiToken 服务域名,默认是 iam.bj.baidubce.com。 | -| `baiduApiTokenServicePort` | int64 | 非必填 | - | 请求刷新百度 apiToken 服务端口,默认是 443。 | - +文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。 #### 360智脑 diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index a510115f4..48f08dd9e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -86,11 +86,6 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { providerConfig := c.GetProviderConfig() err = providerConfig.SetApiTokensFailover(log, c.activeProvider) - if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok { - tickPeriod, tickFunc := handler.GetTickFunc(log) - wrapper.RegisteTickFunc(tickPeriod, tickFunc) - } - return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index a75a72547..f541d31fe 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -1,16 +1,9 @@ package provider import ( - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" "errors" - "fmt" "net/http" - "sort" "strings" - "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -21,28 +14,14 @@ import ( const ( baiduDomain = "qianfan.baidubce.com" baiduChatCompletionPath = "/v2/chat/completions" - baiduApiTokenDomain = "iam.bj.baidubce.com" - baiduApiTokenPort = 443 - baiduApiTokenPath = "/v1/BCE-BEARER/token" - // refresh apiToken every 1 hour - baiduApiTokenRefreshInterval = 3600 - // authorizationString expires in 30 minutes, authorizationString is used to generate apiToken - // the default expiration time of apiToken is 24 hours - baiduAuthorizationStringExpirationSeconds = 1800 - bce_prefix = "x-bce-" ) type baiduProviderInitializer struct{} func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error { - if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 { - return errors.New("no baiduAccessKeyAndSecret found in provider config") + if config.apiTokens == nil || len(config.apiTokens) == 0 { + return errors.New("no apiToken found in provider config") } - if config.baiduApiTokenServiceName == "" { - return errors.New("no baiduApiTokenServiceName found in provider config") - } - // baidu use access key and access secret to refresh apiToken regularly, the apiToken should be accessed globally (via all Wasm VMs) - config.useGlobalApiToken = true return nil } @@ -90,203 +69,3 @@ func (g *baiduProvider) GetApiName(path string) ApiName { } return "" } - -func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string { - c := strings.Split(accessKeyAndSecret, ":") - credentials := BceCredentials{ - AccessKeyId: c[0], - SecretAccessKey: c[1], - } - httpMethod := "GET" - path := baiduApiTokenPath - headers := map[string]string{"host": baiduApiTokenDomain} - timestamp := time.Now().Unix() - - headersToSign := make([]string, 0, len(headers)) - for k := range headers { - headersToSign = append(headersToSign, k) - } - - return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign) -} - -// BceCredentials holds the access key and secret key -type BceCredentials struct { - AccessKeyId string - SecretAccessKey string -} - -// normalizeString performs URI encoding according to RFC 3986 -func normalizeString(inStr string, encodingSlash bool) string { - if inStr == "" { - return "" - } - - var result strings.Builder - for _, ch := range []byte(inStr) { - if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || - (ch >= '0' && ch <= '9') || ch == '.' || ch == '-' || - ch == '_' || ch == '~' || (!encodingSlash && ch == '/') { - result.WriteByte(ch) - } else { - result.WriteString(fmt.Sprintf("%%%02X", ch)) - } - } - return result.String() -} - -// getCanonicalTime generates a timestamp in UTC format -func getCanonicalTime(timestamp int64) string { - if timestamp == 0 { - timestamp = time.Now().Unix() - } - t := time.Unix(timestamp, 0).UTC() - return t.Format("2006-01-02T15:04:05Z") -} - -// getCanonicalUri generates a canonical URI -func getCanonicalUri(path string) string { - return normalizeString(path, false) -} - -// getCanonicalHeaders generates canonical headers -func getCanonicalHeaders(headers map[string]string, headersToSign []string) string { - if len(headers) == 0 { - return "" - } - - // If headersToSign is not specified, use default headers - if len(headersToSign) == 0 { - headersToSign = []string{"host", "content-md5", "content-length", "content-type"} - } - - // Convert headersToSign to a map for easier lookup - headerMap := make(map[string]bool) - for _, header := range headersToSign { - headerMap[strings.ToLower(strings.TrimSpace(header))] = true - } - - // Create a slice to hold the canonical headers - var canonicalHeaders []string - for k, v := range headers { - k = strings.ToLower(strings.TrimSpace(k)) - v = strings.TrimSpace(v) - - // Add headers that start with x-bce- or are in headersToSign - if strings.HasPrefix(k, bce_prefix) || headerMap[k] { - canonicalHeaders = append(canonicalHeaders, - fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true))) - } - } - - // Sort the canonical headers - sort.Strings(canonicalHeaders) - - return strings.Join(canonicalHeaders, "\n") -} - -// sign generates the authorization string -func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string, - timestamp int64, expirationInSeconds int, - headersToSign []string) string { - - // Generate sign key - signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d", - credentials.AccessKeyId, - getCanonicalTime(timestamp), - expirationInSeconds) - - // Generate sign key using HMAC-SHA256 - h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey)) - h.Write([]byte(signKeyInfo)) - signKey := hex.EncodeToString(h.Sum(nil)) - - // Generate canonical URI - canonicalUri := getCanonicalUri(path) - - // Generate canonical headers - canonicalHeaders := getCanonicalHeaders(headers, headersToSign) - - // Generate string to sign - stringToSign := strings.Join([]string{ - httpMethod, - canonicalUri, - "", - canonicalHeaders, - }, "\n") - - // Calculate final signature - h = hmac.New(sha256.New, []byte(signKey)) - h.Write([]byte(stringToSign)) - signature := hex.EncodeToString(h.Sum(nil)) - - // Generate final authorization string - if len(headersToSign) > 0 { - return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature) - } - return fmt.Sprintf("%s//%s", signKeyInfo, signature) -} - -// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours -func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) { - vmID := generateVMID() - - return baiduApiTokenRefreshInterval * 1000, func() { - // Only the Wasm VM that successfully acquires the lease will refresh the apiToken - if g.config.tryAcquireOrRenewLease(vmID, log) { - log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID) - // Get the apiToken that is about to expire, will be removed after the new apiToken is obtained - oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens) - if err != nil { - log.Errorf("Get old apiToken failed: %v", err) - return - } - log.Debugf("Old apiTokens: %v", oldApiTokens) - - for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret { - authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds) - log.Debugf("Generate authorizationString: %v", authorizationString) - g.generateNewApiToken(authorizationString, log) - } - - // remove old old apiToken - for _, token := range oldApiTokens { - log.Debugf("Remove old apiToken: %v", token) - removeApiToken(g.config.failover.ctxApiTokens, token, log) - } - } - } -} - -func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) { - client := wrapper.NewClusterClient(wrapper.FQDNCluster{ - FQDN: g.config.baiduApiTokenServiceName, - Host: g.config.baiduApiTokenServiceHost, - Port: g.config.baiduApiTokenServicePort, - }) - - headers := [][2]string{ - {"content-type", "application/json"}, - {"Authorization", authorizationString}, - } - - var apiToken string - err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode == 201 { - var response map[string]interface{} - err := json.Unmarshal(responseBody, &response) - if err != nil { - log.Errorf("Unmarshal response failed: %v", err) - } else { - apiToken = response["token"].(string) - addApiToken(g.config.failover.ctxApiTokens, apiToken, log) - } - } else { - log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody)) - } - }, 30000) - - if err != nil { - log.Errorf("Get apiToken failed: %v", err) - } -} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 7574224be..6c8259949 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -557,9 +557,8 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { var apiToken string - if c.isFailoverEnabled() || c.useGlobalApiToken { - // if enable apiToken failover, only use available apiToken from global apiTokens list - // or the apiToken need to be accessed globally (via all Wasm VMs, e.g. baidu), + // if enable apiToken failover, only use available apiToken from global apiTokens list + if c.isFailoverEnabled() { apiToken = c.GetGlobalRandomToken(log) } else { apiToken = c.GetRandomToken() diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 6cb3492ef..6b6239828 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -155,12 +155,6 @@ type TransformResponseBodyHandler interface { TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) } -// TickFuncHandler allows the provider to execute a function periodically -// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically -type TickFuncHandler interface { - GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) -} - type ProviderConfig struct { // @Title zh-CN ID // @Description zh-CN AI服务提供商标识 @@ -246,17 +240,6 @@ type ProviderConfig struct { // @Title zh-CN 自定义大模型参数配置 // @Description zh-CN 用于填充或者覆盖大模型调用时的参数 customSettings []CustomSetting - // @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken - baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"` - // @Title zh-CN 请求刷新百度 apiToken 服务名称 - baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"` - // @Title zh-CN 请求刷新百度 apiToken 服务域名 - baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"` - // @Title zh-CN 请求刷新百度 apiToken 服务端口 - baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"` - // @Title zh-CN 是否使用全局的 apiToken - // @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新 - useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"` } func (c *ProviderConfig) GetId() string { @@ -364,19 +347,6 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if retryOnFailureJson.Exists() { c.retryOnFailure.FromJson(retryOnFailureJson) } - - for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() { - c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String()) - } - c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String() - c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String() - if c.baiduApiTokenServiceHost == "" { - c.baiduApiTokenServiceHost = baiduApiTokenDomain - } - c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int() - if c.baiduApiTokenServicePort == 0 { - c.baiduApiTokenServicePort = baiduApiTokenPort - } } func (c *ProviderConfig) Validate() error {