mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat: implement apiToken failover mechanism (#1256)
This commit is contained in:
594
plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Normal file
594
plugins/wasm-go/extensions/ai-proxy/provider/failover.go
Normal file
@@ -0,0 +1,594 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/google/uuid"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type failover struct {
|
||||
// @Title zh-CN 是否启用 apiToken 的 failover 机制
|
||||
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
|
||||
// @Title zh-CN 触发 failover 连续请求失败的阈值
|
||||
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
|
||||
// @Title zh-CN 健康检测的成功阈值
|
||||
successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
|
||||
// @Title zh-CN 健康检测的间隔时间,单位毫秒
|
||||
healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
|
||||
// @Title zh-CN 健康检测的超时时间,单位毫秒
|
||||
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
||||
// @Title zh-CN 健康检测使用的模型
|
||||
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||
// @Title zh-CN 本次请求使用的 apiToken
|
||||
ctxApiTokenInUse string
|
||||
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
|
||||
ctxApiTokenRequestFailureCount string
|
||||
// @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数
|
||||
ctxApiTokenRequestSuccessCount string
|
||||
// @Title zh-CN 记录所有可用的 apiToken 列表
|
||||
ctxApiTokens string
|
||||
// @Title zh-CN 记录所有不可用的 apiToken 列表
|
||||
ctxUnavailableApiTokens string
|
||||
// @Title zh-CN 记录请求的 cluster, host 和 path,用于在健康检测时构建请求
|
||||
ctxHealthCheckEndpoint string
|
||||
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
|
||||
ctxVmLease string
|
||||
}
|
||||
|
||||
type Lease struct {
|
||||
VMID string `json:"vmID"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type HealthCheckEndpoint struct {
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
Cluster string `json:"cluster"`
|
||||
}
|
||||
|
||||
const (
|
||||
casMaxRetries = 10
|
||||
addApiTokenOperation = "addApiToken"
|
||||
removeApiTokenOperation = "removeApiToken"
|
||||
addApiTokenRequestCountOperation = "addApiTokenRequestCount"
|
||||
resetApiTokenRequestCountOperation = "resetApiTokenRequestCount"
|
||||
ctxRequestHost = "requestHost"
|
||||
ctxRequestPath = "requestPath"
|
||||
)
|
||||
|
||||
var (
|
||||
healthCheckClient wrapper.HttpClient
|
||||
)
|
||||
|
||||
func (f *failover) FromJson(json gjson.Result) {
|
||||
f.enabled = json.Get("enabled").Bool()
|
||||
f.failureThreshold = json.Get("failureThreshold").Int()
|
||||
if f.failureThreshold == 0 {
|
||||
f.failureThreshold = 3
|
||||
}
|
||||
f.successThreshold = json.Get("successThreshold").Int()
|
||||
if f.successThreshold == 0 {
|
||||
f.successThreshold = 1
|
||||
}
|
||||
f.healthCheckInterval = json.Get("healthCheckInterval").Int()
|
||||
if f.healthCheckInterval == 0 {
|
||||
f.healthCheckInterval = 5000
|
||||
}
|
||||
f.healthCheckTimeout = json.Get("healthCheckTimeout").Int()
|
||||
if f.healthCheckTimeout == 0 {
|
||||
f.healthCheckTimeout = 5000
|
||||
}
|
||||
f.healthCheckModel = json.Get("healthCheckModel").String()
|
||||
}
|
||||
|
||||
func (f *failover) Validate() error {
|
||||
if f.healthCheckModel == "" {
|
||||
return errors.New("missing healthCheckModel in failover config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) initVariable() {
|
||||
// Set provider name as prefix to differentiate shared data
|
||||
provider := c.GetType()
|
||||
c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse"
|
||||
c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount"
|
||||
c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount"
|
||||
c.failover.ctxApiTokens = provider + "-apiTokens"
|
||||
c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens"
|
||||
c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath"
|
||||
c.failover.ctxVmLease = provider + "-vmLease"
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
|
||||
c.initVariable()
|
||||
// Reset shared data in case plugin configuration is updated
|
||||
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
|
||||
c.resetSharedData()
|
||||
|
||||
if c.isFailoverEnabled() {
|
||||
log.Debugf("ai-proxy plugin failover is enabled")
|
||||
|
||||
vmID := generateVMID()
|
||||
err := c.initApiTokens()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init apiTokens: %v", err)
|
||||
}
|
||||
|
||||
wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
|
||||
// Only the Wasm VM that successfully acquires the lease will perform health check
|
||||
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
|
||||
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
|
||||
unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get unavailable tokens: %v", err)
|
||||
return
|
||||
}
|
||||
if len(unavailableTokens) > 0 {
|
||||
for _, apiToken := range unavailableTokens {
|
||||
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
|
||||
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
|
||||
Host: healthCheckEndpoint.Host,
|
||||
Cluster: healthCheckEndpoint.Cluster,
|
||||
})
|
||||
|
||||
ctx := createHttpContext()
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform request headers and body: %v", err)
|
||||
}
|
||||
|
||||
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
|
||||
err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 200 {
|
||||
c.handleAvailableApiToken(apiToken, log)
|
||||
}
|
||||
}, uint32(c.failover.healthCheckTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to perform health check request: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) {
|
||||
originalHeaders := util.SliceToHeader(headers)
|
||||
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log)
|
||||
}
|
||||
|
||||
var err error
|
||||
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetOriginalHttpHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
|
||||
util.ReplaceOriginalHttpHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
|
||||
}
|
||||
|
||||
modifiedHeaders := util.HeaderToSlice(originalHeaders)
|
||||
return modifiedHeaders, body, nil
|
||||
}
|
||||
|
||||
func createHttpContext() *wrapper.CommonHttpCtx[any] {
|
||||
setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
|
||||
vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
|
||||
pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
|
||||
ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
|
||||
data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get request host and path: %v", err)
|
||||
}
|
||||
var healthCheckEndpoint HealthCheckEndpoint
|
||||
err = json.Unmarshal(data, &healthCheckEndpoint)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to unmarshal request host and path: %v", err)
|
||||
}
|
||||
|
||||
headers := [][2]string{
|
||||
{"content-type", "application/json"},
|
||||
}
|
||||
body := []byte(fmt.Sprintf(`{
|
||||
"model": "%s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "who are you?"
|
||||
}
|
||||
]
|
||||
}`, c.failover.healthCheckModel))
|
||||
return healthCheckEndpoint, headers, body
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
|
||||
now := time.Now().Unix()
|
||||
|
||||
data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
} else {
|
||||
log.Errorf("Failed to get lease: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
}
|
||||
|
||||
var lease Lease
|
||||
err = json.Unmarshal(data, &lease)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to unmarshal lease data: %v", err)
|
||||
return false
|
||||
}
|
||||
// If vmID is itself, try to renew the lease directly
|
||||
// If the lease is expired (60s), try to acquire the lease
|
||||
if lease.VMID == vmID || now-lease.Timestamp > 60 {
|
||||
lease.VMID = vmID
|
||||
lease.Timestamp = now
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
|
||||
lease := Lease{
|
||||
VMID: vmID,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
leaseByte, err := json.Marshal(lease)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal lease data: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil {
|
||||
log.Errorf("Failed to set or renew lease: %v", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func generateVMID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// When number of request successes exceeds the threshold during health check,
|
||||
// add the apiToken back to the available list and remove it from the unavailable list
|
||||
func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
|
||||
successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
successCount := successApiTokenRequestCount[apiToken] + 1
|
||||
if successCount >= c.failover.successThreshold {
|
||||
log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken)
|
||||
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
|
||||
addApiToken(c.failover.ctxApiTokens, apiToken, log)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
|
||||
} else {
|
||||
log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
|
||||
addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
|
||||
}
|
||||
}
|
||||
|
||||
// When number of request failures exceeds the threshold,
|
||||
// remove the apiToken from the available list and add it to the unavailable list
|
||||
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
|
||||
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get available apiToken: %v", err)
|
||||
return
|
||||
}
|
||||
// unavailable apiToken has been removed from the available list
|
||||
if !containsElement(availableTokens, apiToken) {
|
||||
return
|
||||
}
|
||||
|
||||
failureCount := failureApiTokenRequestCount[apiToken] + 1
|
||||
if failureCount >= c.failover.failureThreshold {
|
||||
log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
|
||||
removeApiToken(c.failover.ctxApiTokens, apiToken, log)
|
||||
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
|
||||
// Set the request host and path to shared data in case they are needed in apiToken health check
|
||||
c.setHealthCheckEndpoint(ctx, log)
|
||||
} else {
|
||||
log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
|
||||
addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
|
||||
}
|
||||
}
|
||||
|
||||
func addApiToken(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiToken(key, apiToken, addApiTokenOperation, log)
|
||||
}
|
||||
|
||||
func removeApiToken(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiToken(key, apiToken, removeApiTokenOperation, log)
|
||||
}
|
||||
|
||||
func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
apiTokens, cas, err := getApiTokens(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
exists := containsElement(apiTokens, apiToken)
|
||||
if op == addApiTokenOperation && exists {
|
||||
log.Debugf("%s already exists in %s", apiToken, key)
|
||||
return
|
||||
} else if op == removeApiTokenOperation && !exists {
|
||||
log.Debugf("%s does not exist in %s", apiToken, key)
|
||||
return
|
||||
}
|
||||
|
||||
if op == addApiTokenOperation {
|
||||
apiTokens = append(apiTokens, apiToken)
|
||||
} else {
|
||||
apiTokens = removeElement(apiTokens, apiToken)
|
||||
}
|
||||
|
||||
if err := setApiTokens(key, apiTokens, cas); err == nil {
|
||||
log.Debugf("Successfully updated %s in %s", apiToken, key)
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func getApiTokens(key string) ([]string, uint32, error) {
|
||||
data, cas, err := proxywasm.GetSharedData(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return []string{}, cas, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
if data == nil {
|
||||
return []string{}, cas, nil
|
||||
}
|
||||
|
||||
var apiTokens []string
|
||||
if err = json.Unmarshal(data, &apiTokens); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err)
|
||||
}
|
||||
|
||||
return apiTokens, cas, nil
|
||||
}
|
||||
|
||||
func setApiTokens(key string, apiTokens []string, cas uint32) error {
|
||||
data, err := json.Marshal(apiTokens)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tokens: %v", err)
|
||||
}
|
||||
return proxywasm.SetSharedData(key, data, cas)
|
||||
}
|
||||
|
||||
func removeElement(slice []string, s string) []string {
|
||||
for i := 0; i < len(slice); i++ {
|
||||
if slice[i] == s {
|
||||
slice = append(slice[:i], slice[i+1:]...)
|
||||
i--
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func containsElement(slice []string, s string) bool {
|
||||
for _, item := range slice {
|
||||
if item == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
|
||||
data, cas, err := proxywasm.GetSharedData(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
|
||||
var apiTokens map[string]int64
|
||||
err = json.Unmarshal(data, &apiTokens)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return apiTokens, cas, nil
|
||||
}
|
||||
|
||||
func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
|
||||
}
|
||||
|
||||
func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
|
||||
if c.isFailoverEnabled() {
|
||||
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
|
||||
}
|
||||
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
|
||||
log.Infof("reset apiToken %s request failure count", apiTokenInUse)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if op == resetApiTokenRequestCountOperation {
|
||||
delete(apiTokenRequestCount, apiToken)
|
||||
} else {
|
||||
apiTokenRequestCount[apiToken]++
|
||||
}
|
||||
|
||||
apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
|
||||
}
|
||||
|
||||
if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
|
||||
log.Debugf("Successfully updated the count of %s in %s", apiToken, key)
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) initApiTokens() error {
|
||||
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
|
||||
apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
|
||||
unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
|
||||
log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)
|
||||
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
count := len(apiTokens)
|
||||
switch count {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return apiTokens[0]
|
||||
default:
|
||||
return apiTokens[rand.Intn(count)]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) isFailoverEnabled() bool {
|
||||
return c.failover.enabled
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) resetSharedData() {
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
|
||||
if c.isFailoverEnabled() {
|
||||
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||
return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
var apiToken string
|
||||
if c.isFailoverEnabled() {
|
||||
// if enable apiToken failover, only use available apiToken
|
||||
apiToken = c.GetGlobalRandomToken(log)
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
}
|
||||
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get cluster_name: %v", err)
|
||||
}
|
||||
|
||||
host := wrapper.GetRequestHost()
|
||||
if host == "" {
|
||||
host = ctx.GetContext(ctxRequestHost).(string)
|
||||
}
|
||||
path := wrapper.GetRequestPath()
|
||||
if path == "" {
|
||||
path = ctx.GetContext(ctxRequestPath).(string)
|
||||
}
|
||||
|
||||
healthCheckEndpoint := HealthCheckEndpoint{
|
||||
Host: host,
|
||||
Path: path,
|
||||
Cluster: string(cluster),
|
||||
}
|
||||
|
||||
healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal request host and path: %v", err)
|
||||
|
||||
}
|
||||
err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to set request host and path: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user