diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index dcc01f21..45a0d2b6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -52,7 +52,7 @@ description: AI 代理插件配置参考 | `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | | `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | | `customSettings` | array of customSetting | 非必填 | - | 为 AI 请求指定覆盖或者填充参数 | -| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | +| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过或冷却时间到期后重新添加回 apiToken 列表 | | `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 | | `reasoningContentMode` | string | 非必填 | - | 如何处理大模型服务返回的推理内容。目前支持以下取值:passthrough(正常输出推理内容)、ignore(不输出推理内容)、concat(将推理内容拼接在常规输出内容之前)。默认为 passthrough。仅支持通义千问服务。 | | `capabilities` | map of string | 非必填 | - | 部分 provider 的部分 ai 能力原生兼容 openai/v1 格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key 表示的是采用的厂商协议能力,values 表示的真实的厂商该能力的 api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank | @@ -92,15 +92,18 @@ custom-setting 会遵循如下表格,根据`name`和协议来替换对应的 `failover` 的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| ------------------- | --------------- | -------------------- | -------------- | -------------------------------------------------------- | -| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | -| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | -| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | -| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | -| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | -| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 | -| failoverOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行 failover 
的原始请求的状态码,支持正则表达式匹配 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ------------------- | --------------- | ----------------------------------------- | -------------- | -------------------------------------------------------------------- | +| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | +| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | +| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | +| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | +| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | +| healthCheckModel | string | 启用 failover 时与 cooldownDuration 二选一 | - | 健康检测使用的模型。配置后会通过健康检测恢复不可用的 apiToken | +| cooldownDuration | int | 启用 failover 时与 healthCheckModel 二选一 | 0 | apiToken 不可用后的冷却恢复时间,单位毫秒。大于 0 时冷却到期自动恢复 | +| failoverOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行 failover 的原始请求的状态码,支持正则表达式匹配 | + +`healthCheckModel` 和 `cooldownDuration` 至少需要配置一个。当两者同时配置时,apiToken 可通过健康检测提前恢复,也会在冷却时间到期后自动恢复。 `retryOnFailure` 的配置字段说明如下: diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index cd141ab9..93924b13 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -51,6 +51,7 @@ Plugin execution priority: `100` | `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider). **Note: Auto protocol detection is now supported, no need to configure this field to support both OpenAI and Claude protocols** | | `context` | object | Optional | - | Configuration for AI conversation context information | | `customSettings` | array of customSetting | Optional | - | Specifies overrides or fills parameters for AI requests | +| `failover` | object | Optional | - | Configures apiToken failover. 
When an apiToken becomes unavailable, it is removed from the available token list and restored after a successful health check or after the cooldown period expires. | | `subPath` | string | Optional | - | If subPath is configured, the prefix will be removed from the request path before further processing. | | `contextCleanupCommands` | array of string | Optional | - | List of context cleanup commands. When a user message in the request exactly matches any of the configured commands, that message and all non-system messages before it will be removed, keeping only system messages and messages after the command. This enables users to actively clear conversation history. | @@ -84,6 +85,21 @@ The `custom-setting` adheres to the following table, replacing the corresponding If raw mode is enabled, `custom-setting` will directly alter the JSON content using the input `name` and `value`, without any restrictions or modifications to the parameter names. For most protocols, `custom-setting` modifies or fills parameters at the root path of the JSON content. For the `qwen` protocol, ai-proxy configures under the `parameters` subpath. For the `gemini` protocol, it configures under the `generation_config` subpath. +**Details for the `failover` configuration fields:** + +| Name | Data Type | Requirement | Default | Description | +| --------------------- | --------------- | ------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------------------- | +| `enabled` | bool | Optional | false | Whether to enable apiToken failover. | +| `failureThreshold` | int | Optional | 3 | Number of consecutive request failures required before triggering failover. | +| `successThreshold` | int | Optional | 1 | Number of successful health checks required before restoring an unavailable apiToken. | +| `healthCheckInterval` | int | Optional | 5000 | Health check interval in milliseconds. 
| +| `healthCheckTimeout` | int | Optional | 5000 | Health check timeout in milliseconds. | +| `healthCheckModel` | string | Required when failover is enabled unless `cooldownDuration` is configured | - | Model used for health checks. When configured, unavailable apiTokens can be restored after passing health checks. | +| `cooldownDuration` | int | Required when failover is enabled unless `healthCheckModel` is configured | 0 | Cooldown duration in milliseconds after an apiToken becomes unavailable. When greater than 0, the apiToken is restored automatically after the cooldown expires. | +| `failoverOnStatus` | array of string | Optional | ["4.*", "5.*"] | Response status codes that trigger failover for original requests. Regular expressions are supported. | + +At least one of `healthCheckModel` and `cooldownDuration` must be configured when failover is enabled. If both are configured, an apiToken can be restored either by a successful health check or after the cooldown period expires. + ### Provider-Specific Configurations #### OpenAI diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 515c82d9..bb1bbd75 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -368,6 +368,12 @@ func TestZhipuAI(t *testing.T) { test.RunZhipuAIClaudeAutoConversionTests(t) } +func TestCooldown(t *testing.T) { + test.RunCooldownParseConfigTests(t) + test.RunCooldownOnHttpResponseHeadersTests(t) + test.RunCooldownRecoveryTests(t) +} + func TestDeepSeek(t *testing.T) { test.RunDeepSeekParseConfigTests(t) test.RunDeepSeekOnHttpRequestHeadersTests(t) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 4397c31e..fa3076f3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -31,6 +31,8 @@ type 
failover struct { healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` // @Title zh-CN 健康检测使用的模型 healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"` + // @Title zh-CN apiToken 不可用后的冷却恢复时间,单位毫秒,配置后无需健康检测即可自动恢复 + cooldownDuration int64 `required:"false" yaml:"cooldownDuration" json:"cooldownDuration"` // @Title zh-CN 需要进行 failover 的原始请求的状态码,支持正则表达式匹配 failoverOnStatus []string `required:"false" yaml:"failoverOnStatus" json:"failoverOnStatus"` // @Title zh-CN 本次请求使用的 apiToken @@ -49,6 +51,8 @@ type failover struct { ctxHealthCheckEndpoint string // @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测 ctxVmLease string + // @Title zh-CN 记录 apiToken 被标记为不可用的时间戳,用于冷却恢复 + ctxApiTokenUnavailableSince string } type Lease struct { @@ -96,6 +100,7 @@ func (f *failover) FromJson(json gjson.Result) { f.healthCheckTimeout = 5000 } f.healthCheckModel = json.Get("healthCheckModel").String() + f.cooldownDuration = json.Get("cooldownDuration").Int() for _, status := range json.Get("failoverOnStatus").Array() { f.failoverOnStatus = append(f.failoverOnStatus, status.String()) @@ -107,8 +112,8 @@ func (f *failover) FromJson(json gjson.Result) { } func (f *failover) Validate() error { - if f.healthCheckModel == "" { - return errors.New("missing healthCheckModel in failover config") + if f.healthCheckModel == "" && f.cooldownDuration <= 0 { + return errors.New("either healthCheckModel or cooldownDuration must be configured in failover config") } return nil } @@ -124,6 +129,7 @@ func (c *ProviderConfig) initVariable() { c.failover.ctxUnavailableApiTokens = provider + "-" + id + "-unavailableApiTokens" c.failover.ctxHealthCheckEndpoint = provider + "-" + id + "-requestHostAndPath" c.failover.ctxVmLease = provider + "-" + id + "-vmLease" + c.failover.ctxApiTokenUnavailableSince = provider + "-" + id + "-apiTokenUnavailableSince" } func parseConfig(json gjson.Result, config *any) error { @@ -132,7 +138,8 @@ func 
parseConfig(json gjson.Result, config *any) error { func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error { c.initVariable() - // Reset shared data in case plugin configuration is updated + // Reset failover shared data on config updates so stale cooldown/health-check + // state from the previous config does not leak into the new one. log.Debugf("ai-proxy plugin configuration is updated, reset shared data") c.resetSharedData() @@ -156,29 +163,57 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error { return } if len(unavailableTokens) > 0 { - for _, apiToken := range unavailableTokens { - log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) - healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody() - healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{ - Cluster: healthCheckEndpoint.Cluster, - }) - - ctx := createHttpContext() - ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) - - modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body) + // Cooldown recovery: restore tokens whose cooldown period has elapsed + if c.failover.cooldownDuration > 0 { + timestamps, _, err := getApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince) if err != nil { - log.Errorf("Failed to transform request headers and body: %v", err) - } - - // The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion - err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode == 200 { - c.handleAvailableApiToken(apiToken) + log.Errorf("Failed to get apiToken unavailable timestamps: %v", err) + } else { + now := time.Now().UnixMilli() + var recoveredTokens []string + for _, apiToken := range unavailableTokens { + if since, 
ok := timestamps[apiToken]; ok && now-since >= c.failover.cooldownDuration { + log.Infof("cooldown recovery: apiToken %s has cooled down for %dms, restoring to available list", apiToken, now-since) + removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken) + addApiToken(c.failover.ctxApiTokens, apiToken) + removeApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken) + recoveredTokens = append(recoveredTokens, apiToken) + } + } + // Remove recovered tokens from the list to skip health check for them + for _, token := range recoveredTokens { + unavailableTokens = removeElement(unavailableTokens, token) + } + } + } + + // Health check: probe remaining unavailable tokens with a real request + if c.failover.healthCheckModel != "" && len(unavailableTokens) > 0 { + for _, apiToken := range unavailableTokens { + log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) + healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody() + healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{ + Cluster: healthCheckEndpoint.Cluster, + }) + + ctx := createHttpContext() + ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) + + modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body) + if err != nil { + log.Errorf("Failed to transform request headers and body: %v", err) + } + + // The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion + err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode == 200 { + c.handleAvailableApiToken(apiToken) + } + }, uint32(c.failover.healthCheckTimeout)) + if err != nil { + log.Errorf("Failed to perform health check request: %v", err) 
} - }, uint32(c.failover.healthCheckTimeout)) - if err != nil { - log.Errorf("Failed to perform health check request: %v", err) } } } @@ -355,6 +390,10 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT removeApiToken(c.failover.ctxApiTokens, apiToken) addApiToken(c.failover.ctxUnavailableApiTokens, apiToken) resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken) + // Record the time when the apiToken becomes unavailable for cooldown recovery + if c.failover.cooldownDuration > 0 { + setApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken, time.Now().UnixMilli()) + } // Set the request host and path to shared data in case they are needed in apiToken health check c.setHealthCheckEndpoint(ctx) } else { @@ -527,7 +566,76 @@ func modifyApiTokenRequestCount(key, apiToken string, op string) { } func (c *ProviderConfig) initApiTokens() error { - return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0) + _, cas, _ := getApiTokens(c.failover.ctxApiTokens) + return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, cas) +} + +func getApiTokenUnavailableSince(key string) (map[string]int64, uint32, error) { + data, cas, err := proxywasm.GetSharedData(key) + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return make(map[string]int64), cas, nil + } + return nil, 0, err + } + if data == nil { + return make(map[string]int64), cas, nil + } + + var timestamps map[string]int64 + if err = json.Unmarshal(data, &timestamps); err != nil { + return nil, 0, fmt.Errorf("failed to unmarshal unavailableSince: %v", err) + } + return timestamps, cas, nil +} + +func setApiTokenUnavailableSince(key, apiToken string, timestamp int64) { + for attempt := 1; attempt <= casMaxRetries; attempt++ { + timestamps, cas, err := getApiTokenUnavailableSince(key) + if err != nil { + log.Errorf("Failed to get %s: %v", key, err) + continue + } + timestamps[apiToken] = timestamp + data, err := json.Marshal(timestamps) 
+ if err != nil { + log.Errorf("Failed to marshal unavailableSince: %v", err) + return + } + if err := proxywasm.SetSharedData(key, data, cas); err == nil { + return + } else if !errors.Is(err, types.ErrorStatusCasMismatch) { + log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err) + return + } + log.Errorf("CAS mismatch when setting %s, retrying...", key) + } +} + +func removeApiTokenUnavailableSince(key, apiToken string) { + for attempt := 1; attempt <= casMaxRetries; attempt++ { + timestamps, cas, err := getApiTokenUnavailableSince(key) + if err != nil { + log.Errorf("Failed to get %s: %v", key, err) + continue + } + if _, ok := timestamps[apiToken]; !ok { + return + } + delete(timestamps, apiToken) + data, err := json.Marshal(timestamps) + if err != nil { + log.Errorf("Failed to marshal unavailableSince: %v", err) + return + } + if err := proxywasm.SetSharedData(key, data, cas); err == nil { + return + } else if !errors.Is(err, types.ErrorStatusCasMismatch) { + log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err) + return + } + log.Errorf("CAS mismatch when setting %s, retrying...", key) + } } func (c *ProviderConfig) GetGlobalRandomToken() string { @@ -571,11 +679,16 @@ func (c *ProviderConfig) isFailoverEnabled() bool { } func (c *ProviderConfig) resetSharedData() { + // In the real proxy-wasm host, cas=0 means "ignore CAS and overwrite" + // instead of "match CAS=0". We rely on that behavior here so config updates + // can unconditionally clear previous shared data state. 
_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0) _ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0) _ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0) _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0) _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokenUnavailableSince, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, nil, 0) } func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go index ecf0b979..a022a722 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go @@ -538,6 +538,50 @@ func TestFailover_FromJson_Defaults(t *testing.T) { assert.Equal(t, int64(8000), f.healthCheckTimeout) assert.Equal(t, "test-model", f.healthCheckModel) }) + + t.Run("cooldown_duration_default", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true}` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, int64(0), f.cooldownDuration) + }) + + t.Run("cooldown_duration_custom", func(t *testing.T) { + f := &failover{} + jsonStr := `{"enabled": true, "cooldownDuration": 60000}` + f.FromJson(gjson.Parse(jsonStr)) + assert.Equal(t, int64(60000), f.cooldownDuration) + }) +} + +func TestFailover_Validate(t *testing.T) { + t.Run("only_healthCheckModel", func(t *testing.T) { + f := &failover{healthCheckModel: "gpt-3.5-turbo"} + assert.NoError(t, f.Validate()) + }) + + t.Run("only_cooldownDuration", func(t *testing.T) { + f := &failover{cooldownDuration: 60000} + assert.NoError(t, f.Validate()) + }) + + t.Run("both_healthCheckModel_and_cooldownDuration", func(t *testing.T) { + f := 
&failover{healthCheckModel: "gpt-3.5-turbo", cooldownDuration: 60000} + assert.NoError(t, f.Validate()) + }) + + t.Run("neither_healthCheckModel_nor_cooldownDuration", func(t *testing.T) { + f := &failover{} + err := f.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "either healthCheckModel or cooldownDuration") + }) + + t.Run("negative_cooldownDuration", func(t *testing.T) { + f := &failover{cooldownDuration: -1} + err := f.Validate() + assert.Error(t, err) + }) } func TestFailover_FromJson_FailoverOnStatus(t *testing.T) { @@ -565,19 +609,6 @@ func TestFailover_FromJson_FailoverOnStatus(t *testing.T) { }) } -func TestFailover_Validate(t *testing.T) { - t.Run("missing_healthCheckModel", func(t *testing.T) { - f := &failover{} - f.FromJson(gjson.Parse(`{"enabled":true}`)) - assert.Error(t, f.Validate()) - }) - t.Run("ok_with_healthCheckModel", func(t *testing.T) { - f := &failover{} - f.FromJson(gjson.Parse(`{"enabled":true,"healthCheckModel":"gpt-4o-mini"}`)) - assert.NoError(t, f.Validate()) - }) -} - func TestHealthCheckEndpoint_Struct(t *testing.T) { t.Run("health_check_endpoint_fields", func(t *testing.T) { endpoint := HealthCheckEndpoint{ diff --git a/plugins/wasm-go/extensions/ai-proxy/test/cooldown.go b/plugins/wasm-go/extensions/ai-proxy/test/cooldown.go new file mode 100644 index 00000000..8da97e84 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/cooldown.go @@ -0,0 +1,673 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +// setUpstreamResponse marks the response as coming from upstream so onHttpResponseHeaders processes it +func setUpstreamResponse(host test.TestHost) { + _ = host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) +} + +// 测试配置:cooldown only(无 healthCheckModel) +var cooldownOnlyConfig = func() 
json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-a", "sk-token-b"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 1, + "cooldownDuration": 100, + "failoverOnStatus": []string{"429"}, + }, + }, + }) + return data +}() + +// 测试配置:cooldown + healthCheck 同时配置 +var cooldownWithHealthCheckConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-a", "sk-token-b"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 1, + "cooldownDuration": 100, + "healthCheckModel": "gpt-3.5-turbo", + "healthCheckTimeout": 5000, + "failoverOnStatus": []string{"429"}, + }, + }, + }) + return data +}() + +// 测试配置:failover 启用但既没有 cooldown 也没有 healthCheckModel +var failoverNoRecoveryConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-a"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 1, + "failoverOnStatus": []string{"429"}, + }, + }, + }) + return data +}() + +// 测试配置:cooldown 较长,用于测试冷却未到期的场景 +var cooldownLongDurationConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-a", "sk-token-b"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 1, + "cooldownDuration": 600000, + "failoverOnStatus": []string{"429"}, + }, + }, + }) + return data +}() + 
+// 测试配置:三个 token,failureThreshold=2 +var cooldownThreeTokensConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-a", "sk-token-b", "sk-token-c"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 2, + "cooldownDuration": 100, + "failoverOnStatus": []string{"429"}, + }, + }, + }) + return data +}() + +// 测试配置:单个 token + cooldown +var cooldownSingleTokenConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-only-token"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 1, + "cooldownDuration": 100, + "failoverOnStatus": []string{"429"}, + }, + }, + }) + return data +}() + +// 测试配置:cooldown + 多种 failoverOnStatus +var cooldownMultiStatusConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-a", "sk-token-b"}, + "modelMapping": map[string]string{ + "*": "gpt-3.5-turbo", + }, + "failover": map[string]interface{}{ + "enabled": true, + "failureThreshold": 1, + "cooldownDuration": 100, + "failoverOnStatus": []string{"429", "5.*"}, + }, + }, + }) + return data +}() + +// ============ Parse Config Tests ============ + +func RunCooldownParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + // cooldown only 配置应正常启动 + t.Run("cooldown only config starts ok", func(t *testing.T) { + host, status := test.NewTestHost(cooldownOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // 
cooldown + healthCheck 同时配置应正常启动 + t.Run("cooldown with healthCheck config starts ok", func(t *testing.T) { + host, status := test.NewTestHost(cooldownWithHealthCheckConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // failover 启用但既没有 cooldown 也没有 healthCheckModel 应启动失败 + t.Run("failover without recovery config fails", func(t *testing.T) { + host, status := test.NewTestHost(failoverNoRecoveryConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + + // 单 token + cooldown 配置应正常启动 + t.Run("single token with cooldown config starts ok", func(t *testing.T) { + host, status := test.NewTestHost(cooldownSingleTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) +} + +// ============ Failover on 429 Tests ============ + +func RunCooldownOnHttpResponseHeadersTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + // 429 响应应触发 failover 日志 + t.Run("429 triggers failover", func(t *testing.T) { + host, status := test.NewTestHost(cooldownOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`)) + + setUpstreamResponse(host) + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "429"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + // 验证 failover 日志 + warnLogs := host.GetWarnLogs() + hasFailoverLog := false + for _, log := range warnLogs { + if strings.Contains(log, "need 
failover") && strings.Contains(log, "429") { + hasFailoverLog = true + break + } + } + require.True(t, hasFailoverLog, "Should have failover warning log on 429") + }) + + // 200 响应不应触发 failover + t.Run("200 does not trigger failover", func(t *testing.T) { + host, status := test.NewTestHost(cooldownOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`)) + + setUpstreamResponse(host) + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + warnLogs := host.GetWarnLogs() + for _, log := range warnLogs { + require.False(t, strings.Contains(log, "need failover"), "Should not have failover log on 200") + } + }) + + // 非 failoverOnStatus 的错误码不应触发 failover + t.Run("500 does not trigger failover when only 429 configured", func(t *testing.T) { + host, status := test.NewTestHost(cooldownOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`)) + + setUpstreamResponse(host) + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "500"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + warnLogs := host.GetWarnLogs() + for _, log := range warnLogs { + require.False(t, strings.Contains(log, "need failover"), "Should not have failover log on 500 when only 429 
configured") + } + }) + + // 多种 failoverOnStatus 匹配测试:500 应触发 failover + t.Run("500 triggers failover when 5xx configured", func(t *testing.T) { + host, status := test.NewTestHost(cooldownMultiStatusConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`)) + + setUpstreamResponse(host) + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "500"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + warnLogs := host.GetWarnLogs() + hasFailoverLog := false + for _, log := range warnLogs { + if strings.Contains(log, "need failover") { + hasFailoverLog = true + break + } + } + require.True(t, hasFailoverLog, "Should have failover log on 500 when 5.* configured") + }) + }) +} + +// ============ Cooldown Recovery Tests ============ + +func RunCooldownRecoveryTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + // token 被摘除后,冷却到期后 tick 应恢复 + t.Run("token recovered after cooldown expires", func(t *testing.T) { + host, status := test.NewTestHost(cooldownOnlyConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 触发一次 429 使 token 被摘除(failureThreshold=1) + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`)) + setUpstreamResponse(host) + host.CallOnHttpResponseHeaders([][2]string{ + {":status", "429"}, + {"Content-Type", "application/json"}, + }) + + // 验证 token 被标记为不可用 + infoLogs := host.GetInfoLogs() + 
+			hasUnavailableLog := false
+			for _, log := range infoLogs {
+				if strings.Contains(log, "is unavailable now") {
+					hasUnavailableLog = true
+					break
+				}
+			}
+			require.True(t, hasUnavailableLog, "Token should be marked as unavailable after 429")
+
+			// Wait for the cooldown to expire (cooldownDuration=100ms)
+			time.Sleep(150 * time.Millisecond)
+
+			// Fire a tick to run cooldown recovery
+			host.Tick()
+
+			// Verify the token is restored
+			infoLogs = host.GetInfoLogs()
+			hasRecoveryLog := false
+			for _, log := range infoLogs {
+				if strings.Contains(log, "cooldown recovery") && strings.Contains(log, "restoring to available list") {
+					hasRecoveryLog = true
+					break
+				}
+			}
+			require.True(t, hasRecoveryLog, "Token should be recovered after cooldown expires and tick fires")
+		})
+
+		// A tick before the cooldown expires should not restore the token
+		t.Run("token not recovered before cooldown expires", func(t *testing.T) {
+			host, status := test.NewTestHost(cooldownLongDurationConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			// Trigger a 429 so the token is removed
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+
+			// Tick immediately; the cooldown has not expired (cooldownDuration=600000ms)
+			host.Tick()
+
+			// Verify the token is not restored
+			infoLogs := host.GetInfoLogs()
+			for _, log := range infoLogs {
+				require.False(t, strings.Contains(log, "cooldown recovery"),
+					"Token should NOT be recovered before cooldown expires")
+			}
+		})
+
+		// When failureThreshold > 1, a single failure should not remove the token
+		t.Run("single failure does not remove token when threshold is 2", func(t *testing.T) {
+			host, status := test.NewTestHost(cooldownThreeTokensConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			// Trigger only one 429 (failureThreshold=2)
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+
+			// Verify the token is not removed (2 consecutive failures are required)
+			infoLogs := host.GetInfoLogs()
+			for _, log := range infoLogs {
+				require.False(t, strings.Contains(log, "is unavailable now"),
+					"Token should NOT be removed after single failure when threshold is 2")
+			}
+
+			// Verify a debug log records the failure count
+			debugLogs := host.GetDebugLogs()
+			hasThresholdLog := false
+			for _, log := range debugLogs {
+				if strings.Contains(log, "has not reached the failure threshold") {
+					hasThresholdLog = true
+					break
+				}
+			}
+			require.True(t, hasThresholdLog, "Should log that failure threshold not reached")
+		})
+
+		// After the only token is removed, fall back to the unavailable token
+		t.Run("single token fallback to unavailable when all removed", func(t *testing.T) {
+			host, status := test.NewTestHost(cooldownSingleTokenConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			// Trigger a 429 so the only token is removed
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+			host.CompleteHttp()
+
+			// Send a new request; it should fall back to the unavailable token
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+
+			requestHeaders := host.GetRequestHeaders()
+			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
+			require.True(t, hasAuth, "Authorization header should exist even when all tokens unavailable")
+			require.Contains(t, authValue, "sk-only-token", "Should fallback to the unavailable token")
+
+			// Verify a warn log is emitted
+			warnLogs := host.GetWarnLogs()
+			hasAllUnavailableLog := false
+			for _, log := range warnLogs {
+				if strings.Contains(log, "all tokens are unavailable") {
+					hasAllUnavailableLog = true
+					break
+				}
+			}
+			require.True(t, hasAllUnavailableLog, "Should warn that all tokens are unavailable")
+		})
+
+		// After the only token is removed and then restored once the cooldown expires, a new request should use it normally
+		t.Run("single token recovered after cooldown and used in new request", func(t *testing.T) {
+			host, status := test.NewTestHost(cooldownSingleTokenConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			// Trigger a 429 so the token is removed
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+			host.CompleteHttp()
+
+			// Wait for the cooldown to expire
+			time.Sleep(150 * time.Millisecond)
+			host.Tick()
+
+			// Verify the recovery log
+			infoLogs := host.GetInfoLogs()
+			hasRecoveryLog := false
+			for _, log := range infoLogs {
+				if strings.Contains(log, "cooldown recovery") {
+					hasRecoveryLog = true
+					break
+				}
+			}
+			require.True(t, hasRecoveryLog, "Token should be recovered after cooldown")
+
+			// Send a new request; it should use the restored token (no more "all tokens unavailable" warning)
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+
+			requestHeaders := host.GetRequestHeaders()
+			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
+			require.True(t, hasAuth, "Authorization header should exist")
+			require.Contains(t, authValue, "sk-only-token", "Should use the recovered token")
+		})
+
+		// Two tokens: after one is removed, the other keeps being used
+		t.Run("second token used after first is removed", func(t *testing.T) {
+			host, status := test.NewTestHost(cooldownOnlyConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			// Trigger a 429 so one token is removed
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+			host.CompleteHttp()
+
+			// Verify the token is removed
+			infoLogs := host.GetInfoLogs()
+			hasUnavailableLog := false
+			for _, log := range infoLogs {
+				if strings.Contains(log, "is unavailable now") {
+					hasUnavailableLog = true
+					break
+				}
+			}
+			require.True(t, hasUnavailableLog, "One token should be marked unavailable")
+
+			// Send a new request; it should use the remaining available token
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+
+			requestHeaders := host.GetRequestHeaders()
+			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
+			require.True(t, hasAuth, "Authorization header should exist")
+			require.True(t, strings.Contains(authValue, "sk-token-"), "Should use one of the configured tokens")
+		})
+
+		// A successful request should reset the failure count
+		t.Run("successful request resets failure count", func(t *testing.T) {
+			resetCountConfig := func() json.RawMessage {
+				data, _ := json.Marshal(map[string]interface{}{
+					"provider": map[string]interface{}{
+						"type":      "openai",
+						"apiTokens": []string{"sk-reset-token"},
+						"modelMapping": map[string]string{
+							"*": "gpt-3.5-turbo",
+						},
+						"failover": map[string]interface{}{
+							"enabled":          true,
+							"failureThreshold": 2,
+							"cooldownDuration": 100,
+							"failoverOnStatus": []string{"429"},
+						},
+					},
+				})
+				return data
+			}()
+
+			host, status := test.NewTestHost(resetCountConfig)
+			defer host.Reset()
+			require.Equal(t, types.OnPluginStartStatusOK, status)
+
+			// First request returns 429 (failureThreshold=2, so the token is not removed)
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+			host.CompleteHttp()
+
+			// Second request succeeds (200) and should reset the failure count
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "200"},
+				{"Content-Type", "application/json"},
+			})
+			host.CompleteHttp()
+
+			// Third request returns 429 (the count was reset, so the token should not be removed)
+			host.CallOnHttpRequestHeaders([][2]string{
+				{":authority", "example.com"},
+				{":path", "/v1/chat/completions"},
+				{":method", "POST"},
+				{"Content-Type", "application/json"},
+			})
+			host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
+			setUpstreamResponse(host)
+			host.CallOnHttpResponseHeaders([][2]string{
+				{":status", "429"},
+				{"Content-Type", "application/json"},
+			})
+
+			// Verify the token is not removed
+			infoLogs := host.GetInfoLogs()
+			for _, log := range infoLogs {
+				require.False(t, strings.Contains(log, "is unavailable now"),
+					"Token should NOT be removed because failure count was reset by successful request")
+			}
+		})
+	})
+}