mirror of
https://github.com/alibaba/higress.git
synced 2026-03-06 17:40:51 +08:00
294 lines
9.1 KiB
Go
294 lines
9.1 KiB
Go
package provider
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
)
|
|
|
|
// baiduProvider is the provider for baidu service.
|
|
const (
|
|
baiduDomain = "qianfan.baidubce.com"
|
|
baiduChatCompletionPath = "/v2/chat/completions"
|
|
baiduApiTokenDomain = "iam.bj.baidubce.com"
|
|
baiduApiTokenPort = 443
|
|
baiduApiTokenPath = "/v1/BCE-BEARER/token"
|
|
// refresh apiToken every 1 hour
|
|
baiduApiTokenRefreshInterval = 3600
|
|
// authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
|
|
// the default expiration time of apiToken is 24 hours
|
|
baiduAuthorizationStringExpirationSeconds = 1800
|
|
bce_prefix = "x-bce-"
|
|
)
|
|
|
|
type baiduProviderInitializer struct{}
|
|
|
|
func (g *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
|
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
|
|
return errors.New("no baiduAccessKeyAndSecret found in provider config")
|
|
}
|
|
if config.baiduApiTokenServiceName == "" {
|
|
return errors.New("no baiduApiTokenServiceName found in provider config")
|
|
}
|
|
if !config.failover.enabled {
|
|
config.useGlobalApiToken = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
|
return &baiduProvider{
|
|
config: config,
|
|
contextCache: createContextCache(&config),
|
|
}, nil
|
|
}
|
|
|
|
type baiduProvider struct {
|
|
config ProviderConfig
|
|
contextCache *contextCache
|
|
}
|
|
|
|
func (g *baiduProvider) GetProviderType() string {
|
|
return providerTypeBaidu
|
|
}
|
|
|
|
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
|
if apiName != ApiNameChatCompletion {
|
|
return types.ActionContinue, errUnsupportedApiName
|
|
}
|
|
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
|
return types.ActionContinue, nil
|
|
}
|
|
|
|
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
|
if apiName != ApiNameChatCompletion {
|
|
return types.ActionContinue, errUnsupportedApiName
|
|
}
|
|
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
|
}
|
|
|
|
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
|
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
|
|
util.OverwriteRequestHostHeader(headers, baiduDomain)
|
|
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
|
headers.Del("Content-Length")
|
|
}
|
|
|
|
func (g *baiduProvider) GetApiName(path string) ApiName {
|
|
if strings.Contains(path, baiduChatCompletionPath) {
|
|
return ApiNameChatCompletion
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string {
|
|
c := strings.Split(accessKeyAndSecret, ":")
|
|
credentials := BceCredentials{
|
|
AccessKeyId: c[0],
|
|
SecretAccessKey: c[1],
|
|
}
|
|
httpMethod := "GET"
|
|
path := baiduApiTokenPath
|
|
headers := map[string]string{"host": baiduApiTokenDomain}
|
|
timestamp := time.Now().Unix()
|
|
|
|
headersToSign := make([]string, 0, len(headers))
|
|
for k := range headers {
|
|
headersToSign = append(headersToSign, k)
|
|
}
|
|
|
|
return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign)
|
|
}
|
|
|
|
// BceCredentials holds the access key and secret key
|
|
type BceCredentials struct {
|
|
AccessKeyId string
|
|
SecretAccessKey string
|
|
}
|
|
|
|
// normalizeString performs URI encoding according to RFC 3986
|
|
func normalizeString(inStr string, encodingSlash bool) string {
|
|
if inStr == "" {
|
|
return ""
|
|
}
|
|
|
|
var result strings.Builder
|
|
for _, ch := range []byte(inStr) {
|
|
if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') ||
|
|
(ch >= '0' && ch <= '9') || ch == '.' || ch == '-' ||
|
|
ch == '_' || ch == '~' || (!encodingSlash && ch == '/') {
|
|
result.WriteByte(ch)
|
|
} else {
|
|
result.WriteString(fmt.Sprintf("%%%02X", ch))
|
|
}
|
|
}
|
|
return result.String()
|
|
}
|
|
|
|
// getCanonicalTime generates a timestamp in UTC format
|
|
func getCanonicalTime(timestamp int64) string {
|
|
if timestamp == 0 {
|
|
timestamp = time.Now().Unix()
|
|
}
|
|
t := time.Unix(timestamp, 0).UTC()
|
|
return t.Format("2006-01-02T15:04:05Z")
|
|
}
|
|
|
|
// getCanonicalUri generates a canonical URI
|
|
func getCanonicalUri(path string) string {
|
|
return normalizeString(path, false)
|
|
}
|
|
|
|
// getCanonicalHeaders generates canonical headers
|
|
func getCanonicalHeaders(headers map[string]string, headersToSign []string) string {
|
|
if len(headers) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// If headersToSign is not specified, use default headers
|
|
if len(headersToSign) == 0 {
|
|
headersToSign = []string{"host", "content-md5", "content-length", "content-type"}
|
|
}
|
|
|
|
// Convert headersToSign to a map for easier lookup
|
|
headerMap := make(map[string]bool)
|
|
for _, header := range headersToSign {
|
|
headerMap[strings.ToLower(strings.TrimSpace(header))] = true
|
|
}
|
|
|
|
// Create a slice to hold the canonical headers
|
|
var canonicalHeaders []string
|
|
for k, v := range headers {
|
|
k = strings.ToLower(strings.TrimSpace(k))
|
|
v = strings.TrimSpace(v)
|
|
|
|
// Add headers that start with x-bce- or are in headersToSign
|
|
if strings.HasPrefix(k, bce_prefix) || headerMap[k] {
|
|
canonicalHeaders = append(canonicalHeaders,
|
|
fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true)))
|
|
}
|
|
}
|
|
|
|
// Sort the canonical headers
|
|
sort.Strings(canonicalHeaders)
|
|
|
|
return strings.Join(canonicalHeaders, "\n")
|
|
}
|
|
|
|
// sign generates the authorization string
|
|
func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string,
|
|
timestamp int64, expirationInSeconds int,
|
|
headersToSign []string) string {
|
|
|
|
// Generate sign key
|
|
signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d",
|
|
credentials.AccessKeyId,
|
|
getCanonicalTime(timestamp),
|
|
expirationInSeconds)
|
|
|
|
// Generate sign key using HMAC-SHA256
|
|
h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey))
|
|
h.Write([]byte(signKeyInfo))
|
|
signKey := hex.EncodeToString(h.Sum(nil))
|
|
|
|
// Generate canonical URI
|
|
canonicalUri := getCanonicalUri(path)
|
|
|
|
// Generate canonical headers
|
|
canonicalHeaders := getCanonicalHeaders(headers, headersToSign)
|
|
|
|
// Generate string to sign
|
|
stringToSign := strings.Join([]string{
|
|
httpMethod,
|
|
canonicalUri,
|
|
"",
|
|
canonicalHeaders,
|
|
}, "\n")
|
|
|
|
// Calculate final signature
|
|
h = hmac.New(sha256.New, []byte(signKey))
|
|
h.Write([]byte(stringToSign))
|
|
signature := hex.EncodeToString(h.Sum(nil))
|
|
|
|
// Generate final authorization string
|
|
if len(headersToSign) > 0 {
|
|
return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature)
|
|
}
|
|
return fmt.Sprintf("%s//%s", signKeyInfo, signature)
|
|
}
|
|
|
|
// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
|
|
func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) {
|
|
vmID := generateVMID()
|
|
|
|
return baiduApiTokenRefreshInterval * 1000, func() {
|
|
// Only the Wasm VM that successfully acquires the lease will refresh the apiToken
|
|
if g.config.tryAcquireOrRenewLease(vmID, log) {
|
|
log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID)
|
|
// Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
|
|
oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens)
|
|
if err != nil {
|
|
log.Errorf("Get old apiToken failed: %v", err)
|
|
return
|
|
}
|
|
log.Debugf("Old apiTokens: %v", oldApiTokens)
|
|
|
|
for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret {
|
|
authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds)
|
|
log.Debugf("Generate authorizationString: %v", authorizationString)
|
|
g.generateNewApiToken(authorizationString, log)
|
|
}
|
|
|
|
// remove old old apiToken
|
|
for _, token := range oldApiTokens {
|
|
log.Debugf("Remove old apiToken: %v", token)
|
|
removeApiToken(g.config.failover.ctxApiTokens, token, log)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) {
|
|
client := wrapper.NewClusterClient(wrapper.FQDNCluster{
|
|
FQDN: g.config.baiduApiTokenServiceName,
|
|
Host: g.config.baiduApiTokenServiceHost,
|
|
Port: g.config.baiduApiTokenServicePort,
|
|
})
|
|
|
|
headers := [][2]string{
|
|
{"content-type", "application/json"},
|
|
{"Authorization", authorizationString},
|
|
}
|
|
|
|
var apiToken string
|
|
err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
if statusCode == 201 {
|
|
var response map[string]interface{}
|
|
err := json.Unmarshal(responseBody, &response)
|
|
if err != nil {
|
|
log.Errorf("Unmarshal response failed: %v", err)
|
|
} else {
|
|
apiToken = response["token"].(string)
|
|
addApiToken(g.config.failover.ctxApiTokens, apiToken, log)
|
|
}
|
|
} else {
|
|
log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody))
|
|
}
|
|
}, 30000)
|
|
|
|
if err != nil {
|
|
log.Errorf("Get apiToken failed: %v", err)
|
|
}
|
|
}
|