feat(ai-proxy): add cooldownDuration support for failover token recovery (#3700)

Signed-off-by: wydream <yaodiwu618@gmail.com>
Signed-off-by: woody <yaodiwu618@gmail.com>
This commit is contained in:
woody
2026-05-20 18:11:11 +08:00
committed by GitHub
parent e7651f3d3e
commit 739d47ba9c
6 changed files with 890 additions and 48 deletions

View File

@@ -31,6 +31,8 @@ type failover struct {
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
// @Title zh-CN apiToken 不可用后的冷却恢复时间,单位毫秒,配置后无需健康检测即可自动恢复
cooldownDuration int64 `required:"false" yaml:"cooldownDuration" json:"cooldownDuration"`
// @Title zh-CN 需要进行 failover 的原始请求的状态码,支持正则表达式匹配
failoverOnStatus []string `required:"false" yaml:"failoverOnStatus" json:"failoverOnStatus"`
// @Title zh-CN 本次请求使用的 apiToken
@@ -49,6 +51,8 @@ type failover struct {
ctxHealthCheckEndpoint string
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
ctxVmLease string
// @Title zh-CN 记录 apiToken 被标记为不可用的时间戳,用于冷却恢复
ctxApiTokenUnavailableSince string
}
type Lease struct {
@@ -96,6 +100,7 @@ func (f *failover) FromJson(json gjson.Result) {
f.healthCheckTimeout = 5000
}
f.healthCheckModel = json.Get("healthCheckModel").String()
f.cooldownDuration = json.Get("cooldownDuration").Int()
for _, status := range json.Get("failoverOnStatus").Array() {
f.failoverOnStatus = append(f.failoverOnStatus, status.String())
@@ -107,8 +112,8 @@ func (f *failover) FromJson(json gjson.Result) {
}
func (f *failover) Validate() error {
if f.healthCheckModel == "" {
return errors.New("missing healthCheckModel in failover config")
if f.healthCheckModel == "" && f.cooldownDuration <= 0 {
return errors.New("either healthCheckModel or cooldownDuration must be configured in failover config")
}
return nil
}
@@ -124,6 +129,7 @@ func (c *ProviderConfig) initVariable() {
c.failover.ctxUnavailableApiTokens = provider + "-" + id + "-unavailableApiTokens"
c.failover.ctxHealthCheckEndpoint = provider + "-" + id + "-requestHostAndPath"
c.failover.ctxVmLease = provider + "-" + id + "-vmLease"
c.failover.ctxApiTokenUnavailableSince = provider + "-" + id + "-apiTokenUnavailableSince"
}
func parseConfig(json gjson.Result, config *any) error {
@@ -132,7 +138,8 @@ func parseConfig(json gjson.Result, config *any) error {
func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
c.initVariable()
// Reset shared data in case plugin configuration is updated
// Reset failover shared data on config updates so stale cooldown/health-check
// state from the previous config does not leak into the new one.
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
c.resetSharedData()
@@ -156,29 +163,57 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
return
}
if len(unavailableTokens) > 0 {
for _, apiToken := range unavailableTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
Cluster: healthCheckEndpoint.Cluster,
})
ctx := createHttpContext()
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
// Cooldown recovery: restore tokens whose cooldown period has elapsed
if c.failover.cooldownDuration > 0 {
timestamps, _, err := getApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince)
if err != nil {
log.Errorf("Failed to transform request headers and body: %v", err)
}
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.handleAvailableApiToken(apiToken)
log.Errorf("Failed to get apiToken unavailable timestamps: %v", err)
} else {
now := time.Now().UnixMilli()
var recoveredTokens []string
for _, apiToken := range unavailableTokens {
if since, ok := timestamps[apiToken]; ok && now-since >= c.failover.cooldownDuration {
log.Infof("cooldown recovery: apiToken %s has cooled down for %dms, restoring to available list", apiToken, now-since)
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
addApiToken(c.failover.ctxApiTokens, apiToken)
removeApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
recoveredTokens = append(recoveredTokens, apiToken)
}
}
// Remove recovered tokens from the list to skip health check for them
for _, token := range recoveredTokens {
unavailableTokens = removeElement(unavailableTokens, token)
}
}
}
// Health check: probe remaining unavailable tokens with a real request
if c.failover.healthCheckModel != "" && len(unavailableTokens) > 0 {
for _, apiToken := range unavailableTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
Cluster: healthCheckEndpoint.Cluster,
})
ctx := createHttpContext()
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
if err != nil {
log.Errorf("Failed to transform request headers and body: %v", err)
}
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.handleAvailableApiToken(apiToken)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
log.Errorf("Failed to perform health check request: %v", err)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
log.Errorf("Failed to perform health check request: %v", err)
}
}
}
@@ -355,6 +390,10 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT
removeApiToken(c.failover.ctxApiTokens, apiToken)
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
// Record the time when the apiToken becomes unavailable for cooldown recovery
if c.failover.cooldownDuration > 0 {
setApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken, time.Now().UnixMilli())
}
// Set the request host and path to shared data in case they are needed in apiToken health check
c.setHealthCheckEndpoint(ctx)
} else {
@@ -527,7 +566,76 @@ func modifyApiTokenRequestCount(key, apiToken string, op string) {
}
func (c *ProviderConfig) initApiTokens() error {
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
_, cas, _ := getApiTokens(c.failover.ctxApiTokens)
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, cas)
}
func getApiTokenUnavailableSince(key string) (map[string]int64, uint32, error) {
data, cas, err := proxywasm.GetSharedData(key)
if err != nil {
if errors.Is(err, types.ErrorStatusNotFound) {
return make(map[string]int64), cas, nil
}
return nil, 0, err
}
if data == nil {
return make(map[string]int64), cas, nil
}
var timestamps map[string]int64
if err = json.Unmarshal(data, &timestamps); err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal unavailableSince: %v", err)
}
return timestamps, cas, nil
}
func setApiTokenUnavailableSince(key, apiToken string, timestamp int64) {
for attempt := 1; attempt <= casMaxRetries; attempt++ {
timestamps, cas, err := getApiTokenUnavailableSince(key)
if err != nil {
log.Errorf("Failed to get %s: %v", key, err)
continue
}
timestamps[apiToken] = timestamp
data, err := json.Marshal(timestamps)
if err != nil {
log.Errorf("Failed to marshal unavailableSince: %v", err)
return
}
if err := proxywasm.SetSharedData(key, data, cas); err == nil {
return
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
return
}
log.Errorf("CAS mismatch when setting %s, retrying...", key)
}
}
func removeApiTokenUnavailableSince(key, apiToken string) {
for attempt := 1; attempt <= casMaxRetries; attempt++ {
timestamps, cas, err := getApiTokenUnavailableSince(key)
if err != nil {
log.Errorf("Failed to get %s: %v", key, err)
continue
}
if _, ok := timestamps[apiToken]; !ok {
return
}
delete(timestamps, apiToken)
data, err := json.Marshal(timestamps)
if err != nil {
log.Errorf("Failed to marshal unavailableSince: %v", err)
return
}
if err := proxywasm.SetSharedData(key, data, cas); err == nil {
return
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
return
}
log.Errorf("CAS mismatch when setting %s, retrying...", key)
}
}
func (c *ProviderConfig) GetGlobalRandomToken() string {
@@ -571,11 +679,16 @@ func (c *ProviderConfig) isFailoverEnabled() bool {
}
func (c *ProviderConfig) resetSharedData() {
// In the real proxy-wasm host, cas=0 means "ignore CAS and overwrite"
// instead of "match CAS=0". We rely on that behavior here so config updates
// can unconditionally clear previous shared data state.
_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenUnavailableSince, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, nil, 0)
}
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {

View File

@@ -538,6 +538,50 @@ func TestFailover_FromJson_Defaults(t *testing.T) {
assert.Equal(t, int64(8000), f.healthCheckTimeout)
assert.Equal(t, "test-model", f.healthCheckModel)
})
t.Run("cooldown_duration_default", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(0), f.cooldownDuration)
})
t.Run("cooldown_duration_custom", func(t *testing.T) {
f := &failover{}
jsonStr := `{"enabled": true, "cooldownDuration": 60000}`
f.FromJson(gjson.Parse(jsonStr))
assert.Equal(t, int64(60000), f.cooldownDuration)
})
}
func TestFailover_Validate(t *testing.T) {
t.Run("only_healthCheckModel", func(t *testing.T) {
f := &failover{healthCheckModel: "gpt-3.5-turbo"}
assert.NoError(t, f.Validate())
})
t.Run("only_cooldownDuration", func(t *testing.T) {
f := &failover{cooldownDuration: 60000}
assert.NoError(t, f.Validate())
})
t.Run("both_healthCheckModel_and_cooldownDuration", func(t *testing.T) {
f := &failover{healthCheckModel: "gpt-3.5-turbo", cooldownDuration: 60000}
assert.NoError(t, f.Validate())
})
t.Run("neither_healthCheckModel_nor_cooldownDuration", func(t *testing.T) {
f := &failover{}
err := f.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "either healthCheckModel or cooldownDuration")
})
t.Run("negative_cooldownDuration", func(t *testing.T) {
f := &failover{cooldownDuration: -1}
err := f.Validate()
assert.Error(t, err)
})
}
func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
@@ -565,19 +609,6 @@ func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
})
}
func TestFailover_Validate(t *testing.T) {
t.Run("missing_healthCheckModel", func(t *testing.T) {
f := &failover{}
f.FromJson(gjson.Parse(`{"enabled":true}`))
assert.Error(t, f.Validate())
})
t.Run("ok_with_healthCheckModel", func(t *testing.T) {
f := &failover{}
f.FromJson(gjson.Parse(`{"enabled":true,"healthCheckModel":"gpt-4o-mini"}`))
assert.NoError(t, f.Validate())
})
}
func TestHealthCheckEndpoint_Struct(t *testing.T) {
t.Run("health_check_endpoint_fields", func(t *testing.T) {
endpoint := HealthCheckEndpoint{