mirror of
https://github.com/alibaba/higress.git
synced 2026-03-16 16:30:47 +08:00
feat: support baidu api key (#1687)
This commit is contained in:
@@ -157,15 +157,7 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
|
||||
|
||||
#### 文心一言(Baidu)
|
||||
|
||||
文心一言所对应的 `type` 为 `baidu`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------|-----|-----------------------------------------------------------|
|
||||
| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key,中间用 `:` 分隔,用于申请 apiToken。 |
|
||||
| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 |
|
||||
| `baiduApiTokenServiceHost` | string | 非必填 | - | 请求刷新百度 apiToken 服务域名,默认是 iam.bj.baidubce.com。 |
|
||||
| `baiduApiTokenServicePort` | int64 | 非必填 | - | 请求刷新百度 apiToken 服务端口,默认是 443。 |
|
||||
|
||||
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
|
||||
|
||||
#### 360智脑
|
||||
|
||||
|
||||
@@ -86,11 +86,6 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
|
||||
providerConfig := c.GetProviderConfig()
|
||||
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
|
||||
|
||||
if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok {
|
||||
tickPeriod, tickFunc := handler.GetTickFunc(log)
|
||||
wrapper.RegisteTickFunc(tickPeriod, tickFunc)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
@@ -21,28 +14,14 @@ import (
|
||||
const (
|
||||
baiduDomain = "qianfan.baidubce.com"
|
||||
baiduChatCompletionPath = "/v2/chat/completions"
|
||||
baiduApiTokenDomain = "iam.bj.baidubce.com"
|
||||
baiduApiTokenPort = 443
|
||||
baiduApiTokenPath = "/v1/BCE-BEARER/token"
|
||||
// refresh apiToken every 1 hour
|
||||
baiduApiTokenRefreshInterval = 3600
|
||||
// authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
|
||||
// the default expiration time of apiToken is 24 hours
|
||||
baiduAuthorizationStringExpirationSeconds = 1800
|
||||
bce_prefix = "x-bce-"
|
||||
)
|
||||
|
||||
type baiduProviderInitializer struct{}
|
||||
|
||||
func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
|
||||
return errors.New("no baiduAccessKeyAndSecret found in provider config")
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
if config.baiduApiTokenServiceName == "" {
|
||||
return errors.New("no baiduApiTokenServiceName found in provider config")
|
||||
}
|
||||
// baidu use access key and access secret to refresh apiToken regularly, the apiToken should be accessed globally (via all Wasm VMs)
|
||||
config.useGlobalApiToken = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,203 +69,3 @@ func (g *baiduProvider) GetApiName(path string) ApiName {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string {
|
||||
c := strings.Split(accessKeyAndSecret, ":")
|
||||
credentials := BceCredentials{
|
||||
AccessKeyId: c[0],
|
||||
SecretAccessKey: c[1],
|
||||
}
|
||||
httpMethod := "GET"
|
||||
path := baiduApiTokenPath
|
||||
headers := map[string]string{"host": baiduApiTokenDomain}
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
headersToSign := make([]string, 0, len(headers))
|
||||
for k := range headers {
|
||||
headersToSign = append(headersToSign, k)
|
||||
}
|
||||
|
||||
return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign)
|
||||
}
|
||||
|
||||
// BceCredentials holds the access key and secret key
|
||||
type BceCredentials struct {
|
||||
AccessKeyId string
|
||||
SecretAccessKey string
|
||||
}
|
||||
|
||||
// normalizeString performs URI encoding according to RFC 3986
|
||||
func normalizeString(inStr string, encodingSlash bool) string {
|
||||
if inStr == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, ch := range []byte(inStr) {
|
||||
if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') ||
|
||||
(ch >= '0' && ch <= '9') || ch == '.' || ch == '-' ||
|
||||
ch == '_' || ch == '~' || (!encodingSlash && ch == '/') {
|
||||
result.WriteByte(ch)
|
||||
} else {
|
||||
result.WriteString(fmt.Sprintf("%%%02X", ch))
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// getCanonicalTime generates a timestamp in UTC format
|
||||
func getCanonicalTime(timestamp int64) string {
|
||||
if timestamp == 0 {
|
||||
timestamp = time.Now().Unix()
|
||||
}
|
||||
t := time.Unix(timestamp, 0).UTC()
|
||||
return t.Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
|
||||
// getCanonicalUri generates a canonical URI
|
||||
func getCanonicalUri(path string) string {
|
||||
return normalizeString(path, false)
|
||||
}
|
||||
|
||||
// getCanonicalHeaders generates canonical headers
|
||||
func getCanonicalHeaders(headers map[string]string, headersToSign []string) string {
|
||||
if len(headers) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If headersToSign is not specified, use default headers
|
||||
if len(headersToSign) == 0 {
|
||||
headersToSign = []string{"host", "content-md5", "content-length", "content-type"}
|
||||
}
|
||||
|
||||
// Convert headersToSign to a map for easier lookup
|
||||
headerMap := make(map[string]bool)
|
||||
for _, header := range headersToSign {
|
||||
headerMap[strings.ToLower(strings.TrimSpace(header))] = true
|
||||
}
|
||||
|
||||
// Create a slice to hold the canonical headers
|
||||
var canonicalHeaders []string
|
||||
for k, v := range headers {
|
||||
k = strings.ToLower(strings.TrimSpace(k))
|
||||
v = strings.TrimSpace(v)
|
||||
|
||||
// Add headers that start with x-bce- or are in headersToSign
|
||||
if strings.HasPrefix(k, bce_prefix) || headerMap[k] {
|
||||
canonicalHeaders = append(canonicalHeaders,
|
||||
fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true)))
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the canonical headers
|
||||
sort.Strings(canonicalHeaders)
|
||||
|
||||
return strings.Join(canonicalHeaders, "\n")
|
||||
}
|
||||
|
||||
// sign generates the authorization string
|
||||
func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string,
|
||||
timestamp int64, expirationInSeconds int,
|
||||
headersToSign []string) string {
|
||||
|
||||
// Generate sign key
|
||||
signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d",
|
||||
credentials.AccessKeyId,
|
||||
getCanonicalTime(timestamp),
|
||||
expirationInSeconds)
|
||||
|
||||
// Generate sign key using HMAC-SHA256
|
||||
h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey))
|
||||
h.Write([]byte(signKeyInfo))
|
||||
signKey := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// Generate canonical URI
|
||||
canonicalUri := getCanonicalUri(path)
|
||||
|
||||
// Generate canonical headers
|
||||
canonicalHeaders := getCanonicalHeaders(headers, headersToSign)
|
||||
|
||||
// Generate string to sign
|
||||
stringToSign := strings.Join([]string{
|
||||
httpMethod,
|
||||
canonicalUri,
|
||||
"",
|
||||
canonicalHeaders,
|
||||
}, "\n")
|
||||
|
||||
// Calculate final signature
|
||||
h = hmac.New(sha256.New, []byte(signKey))
|
||||
h.Write([]byte(stringToSign))
|
||||
signature := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
// Generate final authorization string
|
||||
if len(headersToSign) > 0 {
|
||||
return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature)
|
||||
}
|
||||
return fmt.Sprintf("%s//%s", signKeyInfo, signature)
|
||||
}
|
||||
|
||||
// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
|
||||
func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) {
|
||||
vmID := generateVMID()
|
||||
|
||||
return baiduApiTokenRefreshInterval * 1000, func() {
|
||||
// Only the Wasm VM that successfully acquires the lease will refresh the apiToken
|
||||
if g.config.tryAcquireOrRenewLease(vmID, log) {
|
||||
log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID)
|
||||
// Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
|
||||
oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens)
|
||||
if err != nil {
|
||||
log.Errorf("Get old apiToken failed: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debugf("Old apiTokens: %v", oldApiTokens)
|
||||
|
||||
for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret {
|
||||
authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds)
|
||||
log.Debugf("Generate authorizationString: %v", authorizationString)
|
||||
g.generateNewApiToken(authorizationString, log)
|
||||
}
|
||||
|
||||
// remove old old apiToken
|
||||
for _, token := range oldApiTokens {
|
||||
log.Debugf("Remove old apiToken: %v", token)
|
||||
removeApiToken(g.config.failover.ctxApiTokens, token, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) {
|
||||
client := wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: g.config.baiduApiTokenServiceName,
|
||||
Host: g.config.baiduApiTokenServiceHost,
|
||||
Port: g.config.baiduApiTokenServicePort,
|
||||
})
|
||||
|
||||
headers := [][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{"Authorization", authorizationString},
|
||||
}
|
||||
|
||||
var apiToken string
|
||||
err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 201 {
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("Unmarshal response failed: %v", err)
|
||||
} else {
|
||||
apiToken = response["token"].(string)
|
||||
addApiToken(g.config.failover.ctxApiTokens, apiToken, log)
|
||||
}
|
||||
} else {
|
||||
log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody))
|
||||
}
|
||||
}, 30000)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Get apiToken failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -557,9 +557,8 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||
|
||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
var apiToken string
|
||||
if c.isFailoverEnabled() || c.useGlobalApiToken {
|
||||
// if enable apiToken failover, only use available apiToken from global apiTokens list
|
||||
// or the apiToken need to be accessed globally (via all Wasm VMs, e.g. baidu),
|
||||
// if enable apiToken failover, only use available apiToken from global apiTokens list
|
||||
if c.isFailoverEnabled() {
|
||||
apiToken = c.GetGlobalRandomToken(log)
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
|
||||
@@ -155,12 +155,6 @@ type TransformResponseBodyHandler interface {
|
||||
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
// TickFuncHandler allows the provider to execute a function periodically
|
||||
// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically
|
||||
type TickFuncHandler interface {
|
||||
GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func())
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
// @Title zh-CN ID
|
||||
// @Description zh-CN AI服务提供商标识
|
||||
@@ -246,17 +240,6 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 自定义大模型参数配置
|
||||
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
|
||||
customSettings []CustomSetting
|
||||
// @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken
|
||||
baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"`
|
||||
// @Title zh-CN 请求刷新百度 apiToken 服务名称
|
||||
baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"`
|
||||
// @Title zh-CN 请求刷新百度 apiToken 服务域名
|
||||
baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"`
|
||||
// @Title zh-CN 请求刷新百度 apiToken 服务端口
|
||||
baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"`
|
||||
// @Title zh-CN 是否使用全局的 apiToken
|
||||
// @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新
|
||||
useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -364,19 +347,6 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
if retryOnFailureJson.Exists() {
|
||||
c.retryOnFailure.FromJson(retryOnFailureJson)
|
||||
}
|
||||
|
||||
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
|
||||
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
|
||||
}
|
||||
c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String()
|
||||
c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String()
|
||||
if c.baiduApiTokenServiceHost == "" {
|
||||
c.baiduApiTokenServiceHost = baiduApiTokenDomain
|
||||
}
|
||||
c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int()
|
||||
if c.baiduApiTokenServicePort == 0 {
|
||||
c.baiduApiTokenServicePort = baiduApiTokenPort
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
|
||||
Reference in New Issue
Block a user