mirror of
https://github.com/alibaba/higress.git
synced 2026-05-21 19:27:28 +08:00
feat(ai-proxy): add cooldownDuration support for failover token recovery (#3700)
Signed-off-by: wydream <yaodiwu618@gmail.com> Signed-off-by: woody <yaodiwu618@gmail.com>
This commit is contained in:
@@ -52,7 +52,7 @@ description: AI 代理插件配置参考
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为 AI 请求指定覆盖或者填充参数 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过或冷却时间到期后重新添加回 apiToken 列表 |
|
||||
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
|
||||
| `reasoningContentMode` | string | 非必填 | - | 如何处理大模型服务返回的推理内容。目前支持以下取值:passthrough(正常输出推理内容)、ignore(不输出推理内容)、concat(将推理内容拼接在常规输出内容之前)。默认为 passthrough。仅支持通义千问服务。 |
|
||||
| `capabilities` | map of string | 非必填 | - | 部分 provider 的部分 ai 能力原生兼容 openai/v1 格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key 表示的是采用的厂商协议能力,values 表示的真实的厂商该能力的 api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
|
||||
@@ -92,15 +92,18 @@ custom-setting 会遵循如下表格,根据`name`和协议来替换对应的
|
||||
|
||||
`failover` 的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------------- | --------------- | -------------------- | -------------- | -------------------------------------------------------- |
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
|
||||
| failoverOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行 failover 的原始请求的状态码,支持正则表达式匹配 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------------- | --------------- | ----------------------------------------- | -------------- | -------------------------------------------------------------------- |
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 启用 failover 时与 cooldownDuration 二选一 | - | 健康检测使用的模型。配置后会通过健康检测恢复不可用的 apiToken |
|
||||
| cooldownDuration | int | 启用 failover 时与 healthCheckModel 二选一 | 0 | apiToken 不可用后的冷却恢复时间,单位毫秒。大于 0 时冷却到期自动恢复 |
|
||||
| failoverOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行 failover 的原始请求的状态码,支持正则表达式匹配 |
|
||||
|
||||
`healthCheckModel` 和 `cooldownDuration` 至少需要配置一个。当两者同时配置时,apiToken 可通过健康检测提前恢复,也会在冷却时间到期后自动恢复。
|
||||
|
||||
`retryOnFailure` 的配置字段说明如下:
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ Plugin execution priority: `100`
|
||||
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider). **Note: Auto protocol detection is now supported, no need to configure this field to support both OpenAI and Claude protocols** |
|
||||
| `context` | object | Optional | - | Configuration for AI conversation context information |
|
||||
| `customSettings` | array of customSetting | Optional | - | Specifies overrides or fills parameters for AI requests |
|
||||
| `failover` | object | Optional | - | Configures apiToken failover. When an apiToken becomes unavailable, it is removed from the available token list and restored after a successful health check or after the cooldown period expires. |
|
||||
| `subPath` | string | Optional | - | If subPath is configured, the prefix will be removed from the request path before further processing. |
|
||||
| `contextCleanupCommands` | array of string | Optional | - | List of context cleanup commands. When a user message in the request exactly matches any of the configured commands, that message and all non-system messages before it will be removed, keeping only system messages and messages after the command. This enables users to actively clear conversation history. |
|
||||
|
||||
@@ -84,6 +85,21 @@ The `custom-setting` adheres to the following table, replacing the corresponding
|
||||
If raw mode is enabled, `custom-setting` will directly alter the JSON content using the input `name` and `value`, without any restrictions or modifications to the parameter names.
|
||||
For most protocols, `custom-setting` modifies or fills parameters at the root path of the JSON content. For the `qwen` protocol, ai-proxy configures under the `parameters` subpath. For the `gemini` protocol, it configures under the `generation_config` subpath.
|
||||
|
||||
**Details for the `failover` configuration fields:**
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
| --------------------- | --------------- | ------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| `enabled` | bool | Optional | false | Whether to enable apiToken failover. |
|
||||
| `failureThreshold` | int | Optional | 3 | Number of consecutive request failures required before triggering failover. |
|
||||
| `successThreshold` | int | Optional | 1 | Number of successful health checks required before restoring an unavailable apiToken. |
|
||||
| `healthCheckInterval` | int | Optional | 5000 | Health check interval in milliseconds. |
|
||||
| `healthCheckTimeout` | int | Optional | 5000 | Health check timeout in milliseconds. |
|
||||
| `healthCheckModel` | string | Required when failover is enabled unless `cooldownDuration` is configured | - | Model used for health checks. When configured, unavailable apiTokens can be restored after passing health checks. |
|
||||
| `cooldownDuration` | int | Required when failover is enabled unless `healthCheckModel` is configured | 0 | Cooldown duration in milliseconds after an apiToken becomes unavailable. When greater than 0, the apiToken is restored automatically after the cooldown expires. |
|
||||
| `failoverOnStatus` | array of string | Optional | ["4.*", "5.*"] | Response status codes that trigger failover for original requests. Regular expressions are supported. |
|
||||
|
||||
At least one of `healthCheckModel` and `cooldownDuration` must be configured when failover is enabled. If both are configured, an apiToken can be restored either by a successful health check or after the cooldown period expires.
|
||||
|
||||
### Provider-Specific Configurations
|
||||
|
||||
#### OpenAI
|
||||
|
||||
@@ -368,6 +368,12 @@ func TestZhipuAI(t *testing.T) {
|
||||
test.RunZhipuAIClaudeAutoConversionTests(t)
|
||||
}
|
||||
|
||||
func TestCooldown(t *testing.T) {
|
||||
test.RunCooldownParseConfigTests(t)
|
||||
test.RunCooldownOnHttpResponseHeadersTests(t)
|
||||
test.RunCooldownRecoveryTests(t)
|
||||
}
|
||||
|
||||
func TestDeepSeek(t *testing.T) {
|
||||
test.RunDeepSeekParseConfigTests(t)
|
||||
test.RunDeepSeekOnHttpRequestHeadersTests(t)
|
||||
|
||||
@@ -31,6 +31,8 @@ type failover struct {
|
||||
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
||||
// @Title zh-CN 健康检测使用的模型
|
||||
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||
// @Title zh-CN apiToken 不可用后的冷却恢复时间,单位毫秒,配置后无需健康检测即可自动恢复
|
||||
cooldownDuration int64 `required:"false" yaml:"cooldownDuration" json:"cooldownDuration"`
|
||||
// @Title zh-CN 需要进行 failover 的原始请求的状态码,支持正则表达式匹配
|
||||
failoverOnStatus []string `required:"false" yaml:"failoverOnStatus" json:"failoverOnStatus"`
|
||||
// @Title zh-CN 本次请求使用的 apiToken
|
||||
@@ -49,6 +51,8 @@ type failover struct {
|
||||
ctxHealthCheckEndpoint string
|
||||
// @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测
|
||||
ctxVmLease string
|
||||
// @Title zh-CN 记录 apiToken 被标记为不可用的时间戳,用于冷却恢复
|
||||
ctxApiTokenUnavailableSince string
|
||||
}
|
||||
|
||||
type Lease struct {
|
||||
@@ -96,6 +100,7 @@ func (f *failover) FromJson(json gjson.Result) {
|
||||
f.healthCheckTimeout = 5000
|
||||
}
|
||||
f.healthCheckModel = json.Get("healthCheckModel").String()
|
||||
f.cooldownDuration = json.Get("cooldownDuration").Int()
|
||||
|
||||
for _, status := range json.Get("failoverOnStatus").Array() {
|
||||
f.failoverOnStatus = append(f.failoverOnStatus, status.String())
|
||||
@@ -107,8 +112,8 @@ func (f *failover) FromJson(json gjson.Result) {
|
||||
}
|
||||
|
||||
func (f *failover) Validate() error {
|
||||
if f.healthCheckModel == "" {
|
||||
return errors.New("missing healthCheckModel in failover config")
|
||||
if f.healthCheckModel == "" && f.cooldownDuration <= 0 {
|
||||
return errors.New("either healthCheckModel or cooldownDuration must be configured in failover config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -124,6 +129,7 @@ func (c *ProviderConfig) initVariable() {
|
||||
c.failover.ctxUnavailableApiTokens = provider + "-" + id + "-unavailableApiTokens"
|
||||
c.failover.ctxHealthCheckEndpoint = provider + "-" + id + "-requestHostAndPath"
|
||||
c.failover.ctxVmLease = provider + "-" + id + "-vmLease"
|
||||
c.failover.ctxApiTokenUnavailableSince = provider + "-" + id + "-apiTokenUnavailableSince"
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *any) error {
|
||||
@@ -132,7 +138,8 @@ func parseConfig(json gjson.Result, config *any) error {
|
||||
|
||||
func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
|
||||
c.initVariable()
|
||||
// Reset shared data in case plugin configuration is updated
|
||||
// Reset failover shared data on config updates so stale cooldown/health-check
|
||||
// state from the previous config does not leak into the new one.
|
||||
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
|
||||
c.resetSharedData()
|
||||
|
||||
@@ -156,29 +163,57 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
|
||||
return
|
||||
}
|
||||
if len(unavailableTokens) > 0 {
|
||||
for _, apiToken := range unavailableTokens {
|
||||
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
|
||||
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
|
||||
Cluster: healthCheckEndpoint.Cluster,
|
||||
})
|
||||
|
||||
ctx := createHttpContext()
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
|
||||
// Cooldown recovery: restore tokens whose cooldown period has elapsed
|
||||
if c.failover.cooldownDuration > 0 {
|
||||
timestamps, _, err := getApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform request headers and body: %v", err)
|
||||
}
|
||||
|
||||
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
|
||||
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 200 {
|
||||
c.handleAvailableApiToken(apiToken)
|
||||
log.Errorf("Failed to get apiToken unavailable timestamps: %v", err)
|
||||
} else {
|
||||
now := time.Now().UnixMilli()
|
||||
var recoveredTokens []string
|
||||
for _, apiToken := range unavailableTokens {
|
||||
if since, ok := timestamps[apiToken]; ok && now-since >= c.failover.cooldownDuration {
|
||||
log.Infof("cooldown recovery: apiToken %s has cooled down for %dms, restoring to available list", apiToken, now-since)
|
||||
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
|
||||
addApiToken(c.failover.ctxApiTokens, apiToken)
|
||||
removeApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
|
||||
recoveredTokens = append(recoveredTokens, apiToken)
|
||||
}
|
||||
}
|
||||
// Remove recovered tokens from the list to skip health check for them
|
||||
for _, token := range recoveredTokens {
|
||||
unavailableTokens = removeElement(unavailableTokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Health check: probe remaining unavailable tokens with a real request
|
||||
if c.failover.healthCheckModel != "" && len(unavailableTokens) > 0 {
|
||||
for _, apiToken := range unavailableTokens {
|
||||
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
|
||||
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
|
||||
Cluster: healthCheckEndpoint.Cluster,
|
||||
})
|
||||
|
||||
ctx := createHttpContext()
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform request headers and body: %v", err)
|
||||
}
|
||||
|
||||
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
|
||||
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 200 {
|
||||
c.handleAvailableApiToken(apiToken)
|
||||
}
|
||||
}, uint32(c.failover.healthCheckTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to perform health check request: %v", err)
|
||||
}
|
||||
}, uint32(c.failover.healthCheckTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to perform health check request: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -355,6 +390,10 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT
|
||||
removeApiToken(c.failover.ctxApiTokens, apiToken)
|
||||
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
|
||||
// Record the time when the apiToken becomes unavailable for cooldown recovery
|
||||
if c.failover.cooldownDuration > 0 {
|
||||
setApiTokenUnavailableSince(c.failover.ctxApiTokenUnavailableSince, apiToken, time.Now().UnixMilli())
|
||||
}
|
||||
// Set the request host and path to shared data in case they are needed in apiToken health check
|
||||
c.setHealthCheckEndpoint(ctx)
|
||||
} else {
|
||||
@@ -527,7 +566,76 @@ func modifyApiTokenRequestCount(key, apiToken string, op string) {
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) initApiTokens() error {
|
||||
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
|
||||
_, cas, _ := getApiTokens(c.failover.ctxApiTokens)
|
||||
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, cas)
|
||||
}
|
||||
|
||||
func getApiTokenUnavailableSince(key string) (map[string]int64, uint32, error) {
|
||||
data, cas, err := proxywasm.GetSharedData(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
if data == nil {
|
||||
return make(map[string]int64), cas, nil
|
||||
}
|
||||
|
||||
var timestamps map[string]int64
|
||||
if err = json.Unmarshal(data, ×tamps); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to unmarshal unavailableSince: %v", err)
|
||||
}
|
||||
return timestamps, cas, nil
|
||||
}
|
||||
|
||||
func setApiTokenUnavailableSince(key, apiToken string, timestamp int64) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
timestamps, cas, err := getApiTokenUnavailableSince(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
timestamps[apiToken] = timestamp
|
||||
data, err := json.Marshal(timestamps)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal unavailableSince: %v", err)
|
||||
return
|
||||
}
|
||||
if err := proxywasm.SetSharedData(key, data, cas); err == nil {
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func removeApiTokenUnavailableSince(key, apiToken string) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
timestamps, cas, err := getApiTokenUnavailableSince(key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
if _, ok := timestamps[apiToken]; !ok {
|
||||
return
|
||||
}
|
||||
delete(timestamps, apiToken)
|
||||
data, err := json.Marshal(timestamps)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal unavailableSince: %v", err)
|
||||
return
|
||||
}
|
||||
if err := proxywasm.SetSharedData(key, data, cas); err == nil {
|
||||
return
|
||||
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
|
||||
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
|
||||
return
|
||||
}
|
||||
log.Errorf("CAS mismatch when setting %s, retrying...", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetGlobalRandomToken() string {
|
||||
@@ -571,11 +679,16 @@ func (c *ProviderConfig) isFailoverEnabled() bool {
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) resetSharedData() {
|
||||
// In the real proxy-wasm host, cas=0 means "ignore CAS and overwrite"
|
||||
// instead of "match CAS=0". We rely on that behavior here so config updates
|
||||
// can unconditionally clear previous shared data state.
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenUnavailableSince, nil, 0)
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, nil, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {
|
||||
|
||||
@@ -538,6 +538,50 @@ func TestFailover_FromJson_Defaults(t *testing.T) {
|
||||
assert.Equal(t, int64(8000), f.healthCheckTimeout)
|
||||
assert.Equal(t, "test-model", f.healthCheckModel)
|
||||
})
|
||||
|
||||
t.Run("cooldown_duration_default", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(0), f.cooldownDuration)
|
||||
})
|
||||
|
||||
t.Run("cooldown_duration_custom", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
jsonStr := `{"enabled": true, "cooldownDuration": 60000}`
|
||||
f.FromJson(gjson.Parse(jsonStr))
|
||||
assert.Equal(t, int64(60000), f.cooldownDuration)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_Validate(t *testing.T) {
|
||||
t.Run("only_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{healthCheckModel: "gpt-3.5-turbo"}
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
|
||||
t.Run("only_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{cooldownDuration: 60000}
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
|
||||
t.Run("both_healthCheckModel_and_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{healthCheckModel: "gpt-3.5-turbo", cooldownDuration: 60000}
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
|
||||
t.Run("neither_healthCheckModel_nor_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
err := f.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "either healthCheckModel or cooldownDuration")
|
||||
})
|
||||
|
||||
t.Run("negative_cooldownDuration", func(t *testing.T) {
|
||||
f := &failover{cooldownDuration: -1}
|
||||
err := f.Validate()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
|
||||
@@ -565,19 +609,6 @@ func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_Validate(t *testing.T) {
|
||||
t.Run("missing_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
f.FromJson(gjson.Parse(`{"enabled":true}`))
|
||||
assert.Error(t, f.Validate())
|
||||
})
|
||||
t.Run("ok_with_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
f.FromJson(gjson.Parse(`{"enabled":true,"healthCheckModel":"gpt-4o-mini"}`))
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHealthCheckEndpoint_Struct(t *testing.T) {
|
||||
t.Run("health_check_endpoint_fields", func(t *testing.T) {
|
||||
endpoint := HealthCheckEndpoint{
|
||||
|
||||
673
plugins/wasm-go/extensions/ai-proxy/test/cooldown.go
Normal file
673
plugins/wasm-go/extensions/ai-proxy/test/cooldown.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setUpstreamResponse marks the response as coming from upstream so onHttpResponseHeaders processes it
|
||||
func setUpstreamResponse(host test.TestHost) {
|
||||
_ = host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
}
|
||||
|
||||
// 测试配置:cooldown only(无 healthCheckModel)
|
||||
var cooldownOnlyConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-a", "sk-token-b"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 1,
|
||||
"cooldownDuration": 100,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:cooldown + healthCheck 同时配置
|
||||
var cooldownWithHealthCheckConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-a", "sk-token-b"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 1,
|
||||
"cooldownDuration": 100,
|
||||
"healthCheckModel": "gpt-3.5-turbo",
|
||||
"healthCheckTimeout": 5000,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:failover 启用但既没有 cooldown 也没有 healthCheckModel
|
||||
var failoverNoRecoveryConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-a"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 1,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:cooldown 较长,用于测试冷却未到期的场景
|
||||
var cooldownLongDurationConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-a", "sk-token-b"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 1,
|
||||
"cooldownDuration": 600000,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:三个 token,failureThreshold=2
|
||||
var cooldownThreeTokensConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-a", "sk-token-b", "sk-token-c"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 2,
|
||||
"cooldownDuration": 100,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:单个 token + cooldown
|
||||
var cooldownSingleTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-only-token"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 1,
|
||||
"cooldownDuration": 100,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:cooldown + 多种 failoverOnStatus
|
||||
var cooldownMultiStatusConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-a", "sk-token-b"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 1,
|
||||
"cooldownDuration": 100,
|
||||
"failoverOnStatus": []string{"429", "5.*"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// ============ Parse Config Tests ============
|
||||
|
||||
func RunCooldownParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// cooldown only 配置应正常启动
|
||||
t.Run("cooldown only config starts ok", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// cooldown + healthCheck 同时配置应正常启动
|
||||
t.Run("cooldown with healthCheck config starts ok", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownWithHealthCheckConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// failover 启用但既没有 cooldown 也没有 healthCheckModel 应启动失败
|
||||
t.Run("failover without recovery config fails", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(failoverNoRecoveryConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// 单 token + cooldown 配置应正常启动
|
||||
t.Run("single token with cooldown config starts ok", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownSingleTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ============ Failover on 429 Tests ============
|
||||
|
||||
func RunCooldownOnHttpResponseHeadersTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 429 响应应触发 failover 日志
|
||||
t.Run("429 triggers failover", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
|
||||
setUpstreamResponse(host)
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证 failover 日志
|
||||
warnLogs := host.GetWarnLogs()
|
||||
hasFailoverLog := false
|
||||
for _, log := range warnLogs {
|
||||
if strings.Contains(log, "need failover") && strings.Contains(log, "429") {
|
||||
hasFailoverLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasFailoverLog, "Should have failover warning log on 429")
|
||||
})
|
||||
|
||||
// 200 响应不应触发 failover
|
||||
t.Run("200 does not trigger failover", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
|
||||
setUpstreamResponse(host)
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
warnLogs := host.GetWarnLogs()
|
||||
for _, log := range warnLogs {
|
||||
require.False(t, strings.Contains(log, "need failover"), "Should not have failover log on 200")
|
||||
}
|
||||
})
|
||||
|
||||
// 非 failoverOnStatus 的错误码不应触发 failover
|
||||
t.Run("500 does not trigger failover when only 429 configured", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
|
||||
setUpstreamResponse(host)
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "500"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
warnLogs := host.GetWarnLogs()
|
||||
for _, log := range warnLogs {
|
||||
require.False(t, strings.Contains(log, "need failover"), "Should not have failover log on 500 when only 429 configured")
|
||||
}
|
||||
})
|
||||
|
||||
// 多种 failoverOnStatus 匹配测试:500 应触发 failover
|
||||
t.Run("500 triggers failover when 5xx configured", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownMultiStatusConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
|
||||
setUpstreamResponse(host)
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "500"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
warnLogs := host.GetWarnLogs()
|
||||
hasFailoverLog := false
|
||||
for _, log := range warnLogs {
|
||||
if strings.Contains(log, "need failover") {
|
||||
hasFailoverLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasFailoverLog, "Should have failover log on 500 when 5.* configured")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ============ Cooldown Recovery Tests ============
|
||||
|
||||
func RunCooldownRecoveryTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// token 被摘除后,冷却到期后 tick 应恢复
|
||||
t.Run("token recovered after cooldown expires", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 触发一次 429 使 token 被摘除(failureThreshold=1)
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 验证 token 被标记为不可用
|
||||
infoLogs := host.GetInfoLogs()
|
||||
hasUnavailableLog := false
|
||||
for _, log := range infoLogs {
|
||||
if strings.Contains(log, "is unavailable now") {
|
||||
hasUnavailableLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasUnavailableLog, "Token should be marked as unavailable after 429")
|
||||
|
||||
// 等待冷却到期(cooldownDuration=100ms)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 触发 tick 执行冷却恢复
|
||||
host.Tick()
|
||||
|
||||
// 验证 token 被恢复
|
||||
infoLogs = host.GetInfoLogs()
|
||||
hasRecoveryLog := false
|
||||
for _, log := range infoLogs {
|
||||
if strings.Contains(log, "cooldown recovery") && strings.Contains(log, "restoring to available list") {
|
||||
hasRecoveryLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasRecoveryLog, "Token should be recovered after cooldown expires and tick fires")
|
||||
})
|
||||
|
||||
// 冷却未到期时 tick 不应恢复 token
|
||||
t.Run("token not recovered before cooldown expires", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownLongDurationConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 触发 429 使 token 被摘除
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 立即 tick,冷却未到期(cooldownDuration=600000ms)
|
||||
host.Tick()
|
||||
|
||||
// 验证 token 未被恢复
|
||||
infoLogs := host.GetInfoLogs()
|
||||
for _, log := range infoLogs {
|
||||
require.False(t, strings.Contains(log, "cooldown recovery"),
|
||||
"Token should NOT be recovered before cooldown expires")
|
||||
}
|
||||
})
|
||||
|
||||
// failureThreshold > 1 时,单次失败不应摘除 token
|
||||
t.Run("single failure does not remove token when threshold is 2", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownThreeTokensConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 只触发一次 429(failureThreshold=2)
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 验证 token 未被摘除(需要连续 2 次失败)
|
||||
infoLogs := host.GetInfoLogs()
|
||||
for _, log := range infoLogs {
|
||||
require.False(t, strings.Contains(log, "is unavailable now"),
|
||||
"Token should NOT be removed after single failure when threshold is 2")
|
||||
}
|
||||
|
||||
// 验证有 debug 日志记录失败次数
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasThresholdLog := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "has not reached the failure threshold") {
|
||||
hasThresholdLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasThresholdLog, "Should log that failure threshold not reached")
|
||||
})
|
||||
|
||||
// 单个 token 被摘除后,应使用不可用 token 兜底
|
||||
t.Run("single token fallback to unavailable when all removed", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownSingleTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 触发 429 使唯一的 token 被摘除
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CompleteHttp()
|
||||
|
||||
// 发起新请求,应使用不可用 token 兜底
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist even when all tokens unavailable")
|
||||
require.Contains(t, authValue, "sk-only-token", "Should fallback to the unavailable token")
|
||||
|
||||
// 验证有 warn 日志
|
||||
warnLogs := host.GetWarnLogs()
|
||||
hasAllUnavailableLog := false
|
||||
for _, log := range warnLogs {
|
||||
if strings.Contains(log, "all tokens are unavailable") {
|
||||
hasAllUnavailableLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasAllUnavailableLog, "Should warn that all tokens are unavailable")
|
||||
})
|
||||
|
||||
// 单个 token 被摘除后,冷却到期后恢复,新请求应正常使用
|
||||
t.Run("single token recovered after cooldown and used in new request", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownSingleTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 触发 429 使 token 被摘除
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CompleteHttp()
|
||||
|
||||
// 等待冷却到期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
host.Tick()
|
||||
|
||||
// 验证恢复日志
|
||||
infoLogs := host.GetInfoLogs()
|
||||
hasRecoveryLog := false
|
||||
for _, log := range infoLogs {
|
||||
if strings.Contains(log, "cooldown recovery") {
|
||||
hasRecoveryLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasRecoveryLog, "Token should be recovered after cooldown")
|
||||
|
||||
// 发起新请求,应正常使用恢复的 token(不再有 all tokens unavailable 警告)
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "sk-only-token", "Should use the recovered token")
|
||||
})
|
||||
|
||||
// 两个 token,一个被摘除后另一个继续使用
|
||||
t.Run("second token used after first is removed", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(cooldownOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 触发 429 使一个 token 被摘除
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CompleteHttp()
|
||||
|
||||
// 验证 token 被摘除
|
||||
infoLogs := host.GetInfoLogs()
|
||||
hasUnavailableLog := false
|
||||
for _, log := range infoLogs {
|
||||
if strings.Contains(log, "is unavailable now") {
|
||||
hasUnavailableLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasUnavailableLog, "One token should be marked unavailable")
|
||||
|
||||
// 发起新请求,应使用剩余的可用 token
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Should use one of the configured tokens")
|
||||
})
|
||||
|
||||
// 成功请求应重置失败计数
|
||||
t.Run("successful request resets failure count", func(t *testing.T) {
|
||||
resetCountConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-reset-token"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"failureThreshold": 2,
|
||||
"cooldownDuration": 100,
|
||||
"failoverOnStatus": []string{"429"},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(resetCountConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 第一次请求 429(failureThreshold=2,不会摘除)
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CompleteHttp()
|
||||
|
||||
// 第二次请求成功(200),应重置失败计数
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CompleteHttp()
|
||||
|
||||
// 第三次请求 429(因为计数已重置,不应摘除)
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
host.CallOnHttpRequestBody([]byte(`{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}]}`))
|
||||
setUpstreamResponse(host)
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 验证 token 未被摘除
|
||||
infoLogs := host.GetInfoLogs()
|
||||
for _, log := range infoLogs {
|
||||
require.False(t, strings.Contains(log, "is unavailable now"),
|
||||
"Token should NOT be removed because failure count was reset by successful request")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user