mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
feat: migrate baidu provider to v2 api (#1527)
This commit is contained in:
@@ -148,7 +148,15 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
|
|||||||
|
|
||||||
#### 文心一言(Baidu)
|
#### 文心一言(Baidu)
|
||||||
|
|
||||||
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
|
文心一言所对应的 `type` 为 `baidu`。它特有的配置字段如下:
|
||||||
|
|
||||||
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|
|--------------------|-----------------|------|-----|-----------------------------------------------------------|
|
||||||
|
| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key,中间用 `:` 分隔,用于申请 apiToken。 |
|
||||||
|
| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 |
|
||||||
|
| `baiduApiTokenServiceHost` | string | 非必填 | - | 请求刷新百度 apiToken 服务域名,默认是 iam.bj.baidubce.com。 |
|
||||||
|
| `baiduApiTokenServicePort` | int64 | 非必填 | - | 请求刷新百度 apiToken 服务端口,默认是 443。 |
|
||||||
|
|
||||||
|
|
||||||
#### 360智脑
|
#### 360智脑
|
||||||
|
|
||||||
|
|||||||
@@ -86,6 +86,11 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
|
|||||||
providerConfig := c.GetProviderConfig()
|
providerConfig := c.GetProviderConfig()
|
||||||
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
|
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
|
||||||
|
|
||||||
|
if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok {
|
||||||
|
tickPeriod, tickFunc := handler.GetTickFunc(log)
|
||||||
|
wrapper.RegisteTickFunc(tickPeriod, tickFunc)
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,48 +1,53 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// baiduProvider is the provider for baidu ernie bot service.
|
// baiduProvider is the provider for baidu service.
|
||||||
|
|
||||||
const (
|
const (
|
||||||
baiduDomain = "aip.baidubce.com"
|
baiduDomain = "qianfan.baidubce.com"
|
||||||
baiduChatCompletionPath = "/chat"
|
baiduChatCompletionPath = "/v2/chat/completions"
|
||||||
|
baiduApiTokenDomain = "iam.bj.baidubce.com"
|
||||||
|
baiduApiTokenPort = 443
|
||||||
|
baiduApiTokenPath = "/v1/BCE-BEARER/token"
|
||||||
|
// refresh apiToken every 1 hour
|
||||||
|
baiduApiTokenRefreshInterval = 3600
|
||||||
|
// authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
|
||||||
|
// the default expiration time of apiToken is 24 hours
|
||||||
|
baiduAuthorizationStringExpirationSeconds = 1800
|
||||||
|
bce_prefix = "x-bce-"
|
||||||
)
|
)
|
||||||
|
|
||||||
var baiduModelToPathSuffixMap = map[string]string{
|
type baiduProviderInitializer struct{}
|
||||||
"ERNIE-4.0-8K": "completions_pro",
|
|
||||||
"ERNIE-3.5-8K": "completions",
|
|
||||||
"ERNIE-3.5-128K": "ernie-3.5-128k",
|
|
||||||
"ERNIE-Speed-8K": "ernie_speed",
|
|
||||||
"ERNIE-Speed-128K": "ernie-speed-128k",
|
|
||||||
"ERNIE-Tiny-8K": "ernie-tiny-8k",
|
|
||||||
"ERNIE-Bot-8K": "ernie_bot_8k",
|
|
||||||
"BLOOMZ-7B": "bloomz_7b1",
|
|
||||||
}
|
|
||||||
|
|
||||||
type baiduProviderInitializer struct {
|
func (g *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||||
|
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
|
||||||
|
return errors.New("no baiduAccessKeyAndSecret found in provider config")
|
||||||
}
|
}
|
||||||
|
if config.baiduApiTokenServiceName == "" {
|
||||||
func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
return errors.New("no baiduApiTokenServiceName found in provider config")
|
||||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
}
|
||||||
return errors.New("no apiToken found in provider config")
|
if !config.failover.enabled {
|
||||||
|
config.useGlobalApiToken = true
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||||
return &baiduProvider{
|
return &baiduProvider{
|
||||||
config: config,
|
config: config,
|
||||||
contextCache: createContextCache(&config),
|
contextCache: createContextCache(&config),
|
||||||
@@ -54,234 +59,235 @@ type baiduProvider struct {
|
|||||||
contextCache *contextCache
|
contextCache *contextCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baiduProvider) GetProviderType() string {
|
func (g *baiduProvider) GetProviderType() string {
|
||||||
return providerTypeBaidu
|
return providerTypeBaidu
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||||
if apiName != ApiNameChatCompletion {
|
if apiName != ApiNameChatCompletion {
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
return types.ActionContinue, errUnsupportedApiName
|
||||||
}
|
}
|
||||||
b.config.handleRequestHeaders(b, ctx, apiName, log)
|
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
return types.ActionContinue, nil
|
||||||
return types.HeaderStopIteration, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||||
|
if apiName != ApiNameChatCompletion {
|
||||||
|
return types.ActionContinue, errUnsupportedApiName
|
||||||
|
}
|
||||||
|
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||||
|
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
|
||||||
util.OverwriteRequestHostHeader(headers, baiduDomain)
|
util.OverwriteRequestHostHeader(headers, baiduDomain)
|
||||||
headers.Del("Accept-Encoding")
|
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Content-Length")
|
headers.Del("Content-Length")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
func (g *baiduProvider) GetApiName(path string) ApiName {
|
||||||
if apiName != ApiNameChatCompletion {
|
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
|
||||||
}
|
|
||||||
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
|
||||||
request := &chatCompletionRequest{}
|
|
||||||
err := b.config.parseRequestAndMapModel(ctx, request, body, log)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
path := b.getRequestPath(ctx, request.Model)
|
|
||||||
util.OverwriteRequestPathHeader(headers, path)
|
|
||||||
|
|
||||||
baiduRequest := b.baiduTextGenRequest(request)
|
|
||||||
return json.Marshal(baiduRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
|
||||||
// 使用文心一言接口协议,跳过OnStreamingResponseBody()和OnResponseBody()
|
|
||||||
if b.config.protocol == protocolOriginal {
|
|
||||||
ctx.DontReadResponseBody()
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
|
||||||
return types.ActionContinue, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
|
||||||
if isLastChunk || len(chunk) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
// sample event response:
|
|
||||||
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}
|
|
||||||
|
|
||||||
// sample end event response:
|
|
||||||
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}}
|
|
||||||
responseBuilder := &strings.Builder{}
|
|
||||||
lines := strings.Split(string(chunk), "\n")
|
|
||||||
for _, data := range lines {
|
|
||||||
if len(data) < 6 {
|
|
||||||
// ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[6:]
|
|
||||||
var baiduResponse baiduTextGenStreamResponse
|
|
||||||
if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil {
|
|
||||||
log.Errorf("unable to unmarshal baidu response: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse)
|
|
||||||
responseBody, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to marshal response: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
b.appendResponse(responseBuilder, string(responseBody))
|
|
||||||
}
|
|
||||||
modifiedResponseChunk := responseBuilder.String()
|
|
||||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
|
||||||
return []byte(modifiedResponseChunk), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
|
||||||
baiduResponse := &baiduTextGenResponse{}
|
|
||||||
if err := json.Unmarshal(body, baiduResponse); err != nil {
|
|
||||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal baidu response: %v", err)
|
|
||||||
}
|
|
||||||
if baiduResponse.ErrorMsg != "" {
|
|
||||||
return types.ActionContinue, fmt.Errorf("baidu response error, error_code: %d, error_message: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg)
|
|
||||||
}
|
|
||||||
response := b.responseBaidu2OpenAI(ctx, baiduResponse)
|
|
||||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
|
||||||
}
|
|
||||||
|
|
||||||
type baiduTextGenRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []chatMessage `json:"messages"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
System string `json:"system,omitempty"`
|
|
||||||
DisableSearch bool `json:"disable_search,omitempty"`
|
|
||||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
|
||||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
|
||||||
UserId string `json:"user_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
|
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
|
||||||
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
|
|
||||||
if !ok {
|
|
||||||
suffix = baiduModel
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
|
|
||||||
request.System = content
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *baiduTextGenRequest {
|
|
||||||
baiduRequest := baiduTextGenRequest{
|
|
||||||
Messages: make([]chatMessage, 0, len(request.Messages)),
|
|
||||||
Temperature: request.Temperature,
|
|
||||||
TopP: request.TopP,
|
|
||||||
PenaltyScore: request.FrequencyPenalty,
|
|
||||||
Stream: request.Stream,
|
|
||||||
DisableSearch: false,
|
|
||||||
EnableCitation: false,
|
|
||||||
MaxOutputTokens: request.MaxTokens,
|
|
||||||
UserId: request.User,
|
|
||||||
}
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == roleSystem {
|
|
||||||
baiduRequest.System = message.StringContent()
|
|
||||||
} else {
|
|
||||||
baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &baiduRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
type baiduTextGenResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
IsTruncated bool `json:"is_truncated"`
|
|
||||||
NeedClearHistory bool `json:"need_clear_history"`
|
|
||||||
Usage baiduTextGenResponseUsage `json:"usage"`
|
|
||||||
baiduTextGenResponseError
|
|
||||||
}
|
|
||||||
|
|
||||||
type baiduTextGenResponseError struct {
|
|
||||||
ErrorCode int `json:"error_code"`
|
|
||||||
ErrorMsg string `json:"error_msg"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type baiduTextGenStreamResponse struct {
|
|
||||||
baiduTextGenResponse
|
|
||||||
SentenceId int `json:"sentence_id"`
|
|
||||||
IsEnd bool `json:"is_end"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type baiduTextGenResponseUsage struct {
|
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenResponse) *chatCompletionResponse {
|
|
||||||
choice := chatCompletionChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
|
|
||||||
FinishReason: finishReasonStop,
|
|
||||||
}
|
|
||||||
return &chatCompletionResponse{
|
|
||||||
Id: response.Id,
|
|
||||||
Created: time.Now().UnixMilli() / 1000,
|
|
||||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
|
||||||
SystemFingerprint: "",
|
|
||||||
Object: objectChatCompletion,
|
|
||||||
Choices: []chatCompletionChoice{choice},
|
|
||||||
Usage: usage{
|
|
||||||
PromptTokens: response.Usage.PromptTokens,
|
|
||||||
CompletionTokens: response.Usage.CompletionTokens,
|
|
||||||
TotalTokens: response.Usage.TotalTokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse {
|
|
||||||
choice := chatCompletionChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
|
|
||||||
}
|
|
||||||
if response.IsEnd {
|
|
||||||
choice.FinishReason = finishReasonStop
|
|
||||||
}
|
|
||||||
return &chatCompletionResponse{
|
|
||||||
Id: response.Id,
|
|
||||||
Created: time.Now().UnixMilli() / 1000,
|
|
||||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
|
||||||
SystemFingerprint: "",
|
|
||||||
Object: objectChatCompletionChunk,
|
|
||||||
Choices: []chatCompletionChoice{choice},
|
|
||||||
Usage: usage{
|
|
||||||
PromptTokens: response.Usage.PromptTokens,
|
|
||||||
CompletionTokens: response.Usage.CompletionTokens,
|
|
||||||
TotalTokens: response.Usage.TotalTokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
|
||||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *baiduProvider) GetApiName(path string) ApiName {
|
|
||||||
if strings.Contains(path, baiduChatCompletionPath) {
|
if strings.Contains(path, baiduChatCompletionPath) {
|
||||||
return ApiNameChatCompletion
|
return ApiNameChatCompletion
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string {
|
||||||
|
c := strings.Split(accessKeyAndSecret, ":")
|
||||||
|
credentials := BceCredentials{
|
||||||
|
AccessKeyId: c[0],
|
||||||
|
SecretAccessKey: c[1],
|
||||||
|
}
|
||||||
|
httpMethod := "GET"
|
||||||
|
path := baiduApiTokenPath
|
||||||
|
headers := map[string]string{"host": baiduApiTokenDomain}
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
|
|
||||||
|
headersToSign := make([]string, 0, len(headers))
|
||||||
|
for k := range headers {
|
||||||
|
headersToSign = append(headersToSign, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BceCredentials holds the access key and secret key
|
||||||
|
type BceCredentials struct {
|
||||||
|
AccessKeyId string
|
||||||
|
SecretAccessKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeString performs URI encoding according to RFC 3986
|
||||||
|
func normalizeString(inStr string, encodingSlash bool) string {
|
||||||
|
if inStr == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var result strings.Builder
|
||||||
|
for _, ch := range []byte(inStr) {
|
||||||
|
if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') ||
|
||||||
|
(ch >= '0' && ch <= '9') || ch == '.' || ch == '-' ||
|
||||||
|
ch == '_' || ch == '~' || (!encodingSlash && ch == '/') {
|
||||||
|
result.WriteByte(ch)
|
||||||
|
} else {
|
||||||
|
result.WriteString(fmt.Sprintf("%%%02X", ch))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCanonicalTime generates a timestamp in UTC format
|
||||||
|
func getCanonicalTime(timestamp int64) string {
|
||||||
|
if timestamp == 0 {
|
||||||
|
timestamp = time.Now().Unix()
|
||||||
|
}
|
||||||
|
t := time.Unix(timestamp, 0).UTC()
|
||||||
|
return t.Format("2006-01-02T15:04:05Z")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCanonicalUri generates a canonical URI
|
||||||
|
func getCanonicalUri(path string) string {
|
||||||
|
return normalizeString(path, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCanonicalHeaders generates canonical headers
|
||||||
|
func getCanonicalHeaders(headers map[string]string, headersToSign []string) string {
|
||||||
|
if len(headers) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// If headersToSign is not specified, use default headers
|
||||||
|
if len(headersToSign) == 0 {
|
||||||
|
headersToSign = []string{"host", "content-md5", "content-length", "content-type"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert headersToSign to a map for easier lookup
|
||||||
|
headerMap := make(map[string]bool)
|
||||||
|
for _, header := range headersToSign {
|
||||||
|
headerMap[strings.ToLower(strings.TrimSpace(header))] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a slice to hold the canonical headers
|
||||||
|
var canonicalHeaders []string
|
||||||
|
for k, v := range headers {
|
||||||
|
k = strings.ToLower(strings.TrimSpace(k))
|
||||||
|
v = strings.TrimSpace(v)
|
||||||
|
|
||||||
|
// Add headers that start with x-bce- or are in headersToSign
|
||||||
|
if strings.HasPrefix(k, bce_prefix) || headerMap[k] {
|
||||||
|
canonicalHeaders = append(canonicalHeaders,
|
||||||
|
fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the canonical headers
|
||||||
|
sort.Strings(canonicalHeaders)
|
||||||
|
|
||||||
|
return strings.Join(canonicalHeaders, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// sign generates the authorization string
|
||||||
|
func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string,
|
||||||
|
timestamp int64, expirationInSeconds int,
|
||||||
|
headersToSign []string) string {
|
||||||
|
|
||||||
|
// Generate sign key
|
||||||
|
signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d",
|
||||||
|
credentials.AccessKeyId,
|
||||||
|
getCanonicalTime(timestamp),
|
||||||
|
expirationInSeconds)
|
||||||
|
|
||||||
|
// Generate sign key using HMAC-SHA256
|
||||||
|
h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey))
|
||||||
|
h.Write([]byte(signKeyInfo))
|
||||||
|
signKey := hex.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
// Generate canonical URI
|
||||||
|
canonicalUri := getCanonicalUri(path)
|
||||||
|
|
||||||
|
// Generate canonical headers
|
||||||
|
canonicalHeaders := getCanonicalHeaders(headers, headersToSign)
|
||||||
|
|
||||||
|
// Generate string to sign
|
||||||
|
stringToSign := strings.Join([]string{
|
||||||
|
httpMethod,
|
||||||
|
canonicalUri,
|
||||||
|
"",
|
||||||
|
canonicalHeaders,
|
||||||
|
}, "\n")
|
||||||
|
|
||||||
|
// Calculate final signature
|
||||||
|
h = hmac.New(sha256.New, []byte(signKey))
|
||||||
|
h.Write([]byte(stringToSign))
|
||||||
|
signature := hex.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
// Generate final authorization string
|
||||||
|
if len(headersToSign) > 0 {
|
||||||
|
return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s//%s", signKeyInfo, signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
|
||||||
|
func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) {
|
||||||
|
vmID := generateVMID()
|
||||||
|
|
||||||
|
return baiduApiTokenRefreshInterval * 1000, func() {
|
||||||
|
// Only the Wasm VM that successfully acquires the lease will refresh the apiToken
|
||||||
|
if g.config.tryAcquireOrRenewLease(vmID, log) {
|
||||||
|
log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID)
|
||||||
|
// Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
|
||||||
|
oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Get old apiToken failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("Old apiTokens: %v", oldApiTokens)
|
||||||
|
|
||||||
|
for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret {
|
||||||
|
authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds)
|
||||||
|
log.Debugf("Generate authorizationString: %v", authorizationString)
|
||||||
|
g.generateNewApiToken(authorizationString, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove old old apiToken
|
||||||
|
for _, token := range oldApiTokens {
|
||||||
|
log.Debugf("Remove old apiToken: %v", token)
|
||||||
|
removeApiToken(g.config.failover.ctxApiTokens, token, log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) {
|
||||||
|
client := wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
|
FQDN: g.config.baiduApiTokenServiceName,
|
||||||
|
Host: g.config.baiduApiTokenServiceHost,
|
||||||
|
Port: g.config.baiduApiTokenServicePort,
|
||||||
|
})
|
||||||
|
|
||||||
|
headers := [][2]string{
|
||||||
|
{"content-type", "application/json"},
|
||||||
|
{"Authorization", authorizationString},
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiToken string
|
||||||
|
err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
if statusCode == 201 {
|
||||||
|
var response map[string]interface{}
|
||||||
|
err := json.Unmarshal(responseBody, &response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unmarshal response failed: %v", err)
|
||||||
|
} else {
|
||||||
|
apiToken = response["token"].(string)
|
||||||
|
addApiToken(g.config.failover.ctxApiTokens, apiToken, log)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
}, 30000)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Get apiToken failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -467,7 +467,7 @@ func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string,
|
|||||||
log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
|
log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
|
||||||
}
|
}
|
||||||
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
|
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
|
||||||
log.Infof("reset apiToken %s request failure count", apiTokenInUse)
|
log.Infof("Reset apiToken %s request failure count", apiTokenInUse)
|
||||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
|
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -489,7 +489,7 @@ func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log
|
|||||||
|
|
||||||
apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
|
apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
|
log.Errorf("Failed to marshal apiTokenRequestCount: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
|
if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
|
||||||
@@ -551,7 +551,7 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
|||||||
|
|
||||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||||
var apiToken string
|
var apiToken string
|
||||||
if c.isFailoverEnabled() {
|
if c.isFailoverEnabled() || c.useGlobalApiToken {
|
||||||
// if enable apiToken failover, only use available apiToken
|
// if enable apiToken failover, only use available apiToken
|
||||||
apiToken = c.GetGlobalRandomToken(log)
|
apiToken = c.GetGlobalRandomToken(log)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -151,6 +151,12 @@ type ResponseBodyHandler interface {
|
|||||||
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TickFuncHandler allows the provider to execute a function periodically
|
||||||
|
// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically
|
||||||
|
type TickFuncHandler interface {
|
||||||
|
GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func())
|
||||||
|
}
|
||||||
|
|
||||||
type ProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
// @Title zh-CN ID
|
// @Title zh-CN ID
|
||||||
// @Description zh-CN AI服务提供商标识
|
// @Description zh-CN AI服务提供商标识
|
||||||
@@ -227,6 +233,17 @@ type ProviderConfig struct {
|
|||||||
// @Title zh-CN 自定义大模型参数配置
|
// @Title zh-CN 自定义大模型参数配置
|
||||||
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
|
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
|
||||||
customSettings []CustomSetting
|
customSettings []CustomSetting
|
||||||
|
// @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken
|
||||||
|
baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"`
|
||||||
|
// @Title zh-CN 请求刷新百度 apiToken 服务名称
|
||||||
|
baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"`
|
||||||
|
// @Title zh-CN 请求刷新百度 apiToken 服务域名
|
||||||
|
baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"`
|
||||||
|
// @Title zh-CN 请求刷新百度 apiToken 服务端口
|
||||||
|
baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"`
|
||||||
|
// @Title zh-CN 是否使用全局的 apiToken
|
||||||
|
// @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新
|
||||||
|
useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetId() string {
|
func (c *ProviderConfig) GetId() string {
|
||||||
@@ -321,6 +338,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
if failoverJson.Exists() {
|
if failoverJson.Exists() {
|
||||||
c.failover.FromJson(failoverJson)
|
c.failover.FromJson(failoverJson)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
|
||||||
|
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
|
||||||
|
}
|
||||||
|
c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String()
|
||||||
|
c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String()
|
||||||
|
if c.baiduApiTokenServiceHost == "" {
|
||||||
|
c.baiduApiTokenServiceHost = baiduApiTokenDomain
|
||||||
|
}
|
||||||
|
c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int()
|
||||||
|
if c.baiduApiTokenServicePort == 0 {
|
||||||
|
c.baiduApiTokenServicePort = baiduApiTokenPort
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) Validate() error {
|
func (c *ProviderConfig) Validate() error {
|
||||||
|
|||||||
Reference in New Issue
Block a user