mirror of
https://github.com/alibaba/higress.git
synced 2026-05-25 21:28:17 +08:00
feat(ai-proxy): add cooldownDuration support for failover token recovery (#3700)
Signed-off-by: wydream <yaodiwu618@gmail.com> Signed-off-by: woody <yaodiwu618@gmail.com>
This commit is contained in:
@@ -31,6 +31,8 @@ type failover struct {
|
||||
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
||||
// @Title zh-CN 健康检测使用的模型
|
||||
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||
// @Title zh-CN apiToken 不可用后的冷却恢复时间,单位毫秒,配置后无需健康检测即可自动恢复
|
||||
cooldownDuration int64 `required:"false" yaml:"cooldownDuration" json:"cooldownDuration"`
|
||||
// @Title zh-CN 需要进行 failover 的原始请求的状态码,支持正则表达式匹配
|
||||
failoverOnStatus []string `required:"false" yaml:"failoverOnStatus" json:"failoverOnStatus"`
|
||||
// @Title zh-CN 本次请求使用的 apiToken
|
||||
@@ -49,6 +51,8 @@ type failover struct {
|
||||
ctxHealthCheckEndpoint string
|
||||
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
|
||||
ctxVmLease string
|
||||
// @Title zh-CN 记录 apiToken 被标记为不可用的时间戳,用于冷却恢复
|
||||
ctxApiTokenUnavailableSince string
|
||||
}
|
||||
|
||||
type Lease struct {
|
||||
@@ -96,6 +100,7 @@ func (f *failover) FromJson(json gjson.Result) {
|
||||
f.healthCheckTimeout = 5000
|
||||
}
|
||||
f.healthCheckModel = json.Get("healthCheckModel").String()
|
||||
f.cooldownDuration = json.Get("cooldownDuration").Int()
|
||||
|
||||
for _, status := range json.Get("failoverOnStatus").Array() {
|
||||
f.failoverOnStatus = append(f.failoverOnStatus, status.String())
|
||||
@@ -107,8 +112,8 @@ func (f *failover) FromJson(json gjson.Result) {
|
||||
}
|
||||
|
||||
func (f *failover) Validate() error {
|
||||
if f.healthCheckModel == "" {
|
||||
return errors.New("missing healthCheckModel in failover config")
|
||||
if f.healthCheckModel == "" && f.cooldownDuration <= 0 {
|
||||
return errors.New("either healthCheckModel or cooldownDuration must be configured in failover config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -124,6 +129,7 @@ func (c *ProviderConfig) initVariable() {
|
||||
c.failover.ctxUnavailableApiTokens = provider + "-" + id + "-unavailableApiTokens"
|
||||
c.failover.ctxHealthCheckEndpoint = provider + "-" + id + "-requestHostAndPath"
|
||||
c.failover.ctxVmLease = provider + "-" + id + "-vmLease"
|
||||
c.failover.ctxApiTokenUnavailableSince = provider + "-" + id + "-apiTokenUnavailableSince"
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *any) error {
|
||||
@@ -132,7 +138,8 @@ func parseConfig(json gjson.Result, config *any) error {
|
||||
|
||||
func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
|
||||
c.initVariable()
|
||||
// Reset shared data in case plugin configuration is updated
|
||||
// Reset failover shared data on config updates so stale cooldown/health-check
|
||||
// state from the previous config does not leak into the new one.
|
||||
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
|
||||
c.resetSharedData()
|
||||
|
||||
@@ -156,29 +163,57 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
|
||||
return
|
||||
}
|
||||
if len(unavailableTokens) > 0 {
|
||||
for _, apiToken := range unavailableTokens {
|
||||
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
|
||||
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
|
||||
Cluster: healthCheckEndpoint.Cluster,
|
||||
})
|
||||
|
||||
ctx := createHttpContext()
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
|
||||
// Cooldown recovery: restore tokens whose cooldown period has elapsed
|
||||
if c.failover.cooldownDuration > 0 {
|
||||
timestamps, _, err := getApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform request headers and body: %v", err)
|
||||
}
|
||||
|
||||
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
|
||||
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 200 {
|
||||
c.handleAvailableApiToken(apiToken)
|
||||
log.Errorf("Failed to get apiToken unavailable timestamps: %v", err)
|
||||
} else {
|
||||
now := time.Now().UnixMilli()
|
||||
var recoveredTokens []string
|
||||
for _, apiToken := range unavailableTokens {
|
||||
if since, ok := timestamps[apiToken]; ok && now-since >= c.failover.cooldownDuration {
|
||||
log.Infof("cooldown recovery: apiToken %s has cooled down for %dms, restoring to available list", apiToken, now-since)
|
||||
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
|
||||
addApiToken(c.failover.ctxApiTokens, apiToken)
|
||||
removeApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
|
||||
recoveredTokens = append(recoveredTokens, apiToken)
|
||||
}
|
||||
}
|
||||
// Remove recovered tokens from the list to skip health check for them
|
||||
for _, token := range recoveredTokens {
|
||||
unavailableTokens = removeElement(unavailableTokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Health check: probe remaining unavailable tokens with a real request
|
||||
if c.failover.healthCheckModel != "" && len(unavailableTokens) > 0 {
|
||||
for _, apiToken := range unavailableTokens {
|
||||
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
|
||||
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
|
||||
Cluster: healthCheckEndpoint.Cluster,
|
||||
})
|
||||
|
||||
ctx := createHttpContext()
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform request headers and body: %v", err)
|
||||
}
|
||||
|
||||
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
|
||||
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 200 {
|
||||
c.handleAvailableApiToken(apiToken)
|
||||
}
|
||||
}, uint32(c.failover.healthCheckTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to perform health check request: %v", err)
|
||||
}
|
||||
}, uint32(c.failover.healthCheckTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to perform health check request: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -355,6 +390,10 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT
|
||||
removeApiToken(c.failover.ctxApiTokens, apiToken)
|
||||
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
|
||||
// Record the time when the apiToken becomes unavailable for cooldown recovery
|
||||
if c.failover.cooldownDuration > 0 {
|
||||
setApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken, time.Now().UnixMilli())
|
||||
}
|
||||
// Set the request host and path to shared data in case they are needed in apiToken health check
|
||||
c.setHealthCheckEndpoint(ctx)
|
||||
} else {
|
||||
@@ -527,7 +566,76 @@ func modifyApiTokenRequestCount(key, apiToken string, op string) {
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) initApiTokens() error {
|
||||
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
|
||||
_, cas, _ := getApiTokens(c.failover.ctxApiTokens)
|
||||
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, cas)
|
||||
}
|
||||
|
||||
func getApiTokenUnavailableSince(key string) (map[string]int64, uint32, error) {
|
||||
data, cas, err := proxywasm.GetSharedData(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
if data == nil {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
|
||||
var timestamps map[string]int64
|
||||
if err = json.Unmarshal(data, ×tamps); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to unmarshal unavailableSince: %v", err)
|
||||
}
|
||||
return timestamps, cas, nil
|
||||
}
|
||||
|
||||
func setApiTokenUnavailableSince(key, apiToken string, timestamp int64) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
timestamps, cas, err := getApiTokenUnavailableSince(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
timestamps[apiToken] = timestamp
|
||||
data, err := json.Marshal(timestamps)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal unavailableSince: %v", err)
|
||||
return
|
||||
}
|
||||
if err := proxywasm.SetSharedData(key, data, cas); err == nil {
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func removeApiTokenUnavailableSince(key, apiToken string) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
timestamps, cas, err := getApiTokenUnavailableSince(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
if _, ok := timestamps[apiToken]; !ok {
|
||||
return
|
||||
}
|
||||
delete(timestamps, apiToken)
|
||||
data, err := json.Marshal(timestamps)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal unavailableSince: %v", err)
|
||||
return
|
||||
}
|
||||
if err := proxywasm.SetSharedData(key, data, cas); err == nil {
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetGlobalRandomToken() string {
|
||||
@@ -571,11 +679,16 @@ func (c *ProviderConfig) isFailoverEnabled() bool {
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) resetSharedData() {
|
||||
// In the real proxy-wasm host, cas=0 means "ignore CAS and overwrite"
|
||||
// instead of "match CAS=0". We rely on that behavior here so config updates
|
||||
// can unconditionally clear previous shared data state.
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenUnavailableSince, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, nil, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {
|
||||
|
||||
@@ -538,6 +538,50 @@ func TestFailover_FromJson_Defaults(t *testing.T) {
|
||||
assert.Equal(t, int64(8000), f.healthCheckTimeout)
|
||||
assert.Equal(t, "test-model", f.healthCheckModel)
|
||||
})
|
||||
|
||||
t.Run("cooldown_duration_default", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(0), f.cooldownDuration)
|
||||
})
|
||||
|
||||
t.Run("cooldown_duration_custom", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true, "cooldownDuration": 60000}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(60000), f.cooldownDuration)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_Validate(t *testing.T) {
|
||||
t.Run("only_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{healthCheckModel: "gpt-3.5-turbo"}
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
|
||||
t.Run("only_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{cooldownDuration: 60000}
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
|
||||
t.Run("both_healthCheckModel_and_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{healthCheckModel: "gpt-3.5-turbo", cooldownDuration: 60000}
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
|
||||
t.Run("neither_healthCheckModel_nor_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
err := f.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "either healthCheckModel or cooldownDuration")
|
||||
})
|
||||
|
||||
t.Run("negative_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{cooldownDuration: -1}
|
||||
err := f.Validate()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
|
||||
@@ -565,19 +609,6 @@ func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_Validate(t *testing.T) {
|
||||
t.Run("missing_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
f.FromJson(gjson.Parse(`{"enabled":true}`))
|
||||
assert.Error(t, f.Validate())
|
||||
})
|
||||
t.Run("ok_with_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
f.FromJson(gjson.Parse(`{"enabled":true,"healthCheckModel":"gpt-4o-mini"}`))
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHealthCheckEndpoint_Struct(t *testing.T) {
|
||||
t.Run("health_check_endpoint_fields", func(t *testing.T) {
|
||||
endpoint := HealthCheckEndpoint{
|
||||
|
||||
Reference in New Issue
Block a user