mirror of
https://github.com/alibaba/higress.git
synced 2026-03-19 01:37:28 +08:00
feat(ai-proxy): add consumer affinity for stateful APIs (#3499)
This commit is contained in:
@@ -157,3 +157,8 @@ func TestClaude(t *testing.T) {
|
||||
test.RunClaudeOnHttpRequestHeadersTests(t)
|
||||
test.RunClaudeOnHttpRequestBodyTests(t)
|
||||
}
|
||||
|
||||
func TestConsumerAffinity(t *testing.T) {
|
||||
test.RunConsumerAffinityParseConfigTests(t)
|
||||
test.RunConsumerAffinityOnHttpRequestHeadersTests(t)
|
||||
}
|
||||
|
||||
@@ -605,7 +605,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) {
|
||||
if c.isFailoverEnabled() {
|
||||
apiToken = c.GetGlobalRandomToken()
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
apiToken = c.GetOrSetTokenWithContext(ctx)
|
||||
}
|
||||
log.Debugf("Use apiToken %s to send request", apiToken)
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"path"
|
||||
@@ -706,12 +707,45 @@ func (c *ProviderConfig) Validate() error {
|
||||
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
||||
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
|
||||
if ctxApiKey == nil {
|
||||
ctxApiKey = c.GetRandomToken()
|
||||
token := c.selectApiToken(ctx)
|
||||
ctxApiKey = token
|
||||
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
|
||||
}
|
||||
return ctxApiKey.(string)
|
||||
}
|
||||
|
||||
// selectApiToken selects an API token based on the request context
|
||||
// For stateful APIs, it uses consumer affinity if available
|
||||
func (c *ProviderConfig) selectApiToken(ctx wrapper.HttpContext) string {
|
||||
// Get API name from context if available
|
||||
ctxApiName := ctx.GetContext(CtxKeyApiName)
|
||||
var apiName string
|
||||
if ctxApiName != nil {
|
||||
// ctxApiName is of type ApiName, need to convert to string
|
||||
apiName = string(ctxApiName.(ApiName))
|
||||
}
|
||||
|
||||
// For stateful APIs, try to use consumer affinity
|
||||
if isStatefulAPI(apiName) {
|
||||
consumer := c.getConsumerFromContext(ctx)
|
||||
if consumer != "" {
|
||||
return c.GetTokenWithConsumerAffinity(ctx, consumer)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to random selection
|
||||
return c.GetRandomToken()
|
||||
}
|
||||
|
||||
// getConsumerFromContext retrieves the consumer identifier from the request context
|
||||
func (c *ProviderConfig) getConsumerFromContext(ctx wrapper.HttpContext) string {
|
||||
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
|
||||
if err == nil && consumer != "" {
|
||||
return consumer
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetRandomToken() string {
|
||||
apiTokens := c.apiTokens
|
||||
count := len(apiTokens)
|
||||
@@ -725,6 +759,50 @@ func (c *ProviderConfig) GetRandomToken() string {
|
||||
}
|
||||
}
|
||||
|
||||
// isStatefulAPI checks if the given API name is a stateful API that requires consumer affinity
|
||||
func isStatefulAPI(apiName string) bool {
|
||||
// These APIs maintain session state and should be routed to the same provider consistently
|
||||
statefulAPIs := map[string]bool{
|
||||
string(ApiNameResponses): true, // Response API - uses previous_response_id
|
||||
string(ApiNameFiles): true, // Files API - maintains file state
|
||||
string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload
|
||||
string(ApiNameRetrieveFileContent): true, // File content - depends on file upload
|
||||
string(ApiNameBatches): true, // Batch API - maintains batch state
|
||||
string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation
|
||||
string(ApiNameCancelBatch): true, // Batch operations - depends on batch state
|
||||
string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state
|
||||
string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status
|
||||
string(ApiNameFineTuningJobEvents): true, // Fine-tuning events
|
||||
string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints
|
||||
string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job
|
||||
string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job
|
||||
}
|
||||
return statefulAPIs[apiName]
|
||||
}
|
||||
|
||||
// GetTokenWithConsumerAffinity selects an API token based on consumer affinity
|
||||
// If x-mse-consumer header is present and API is stateful, it will consistently select the same token
|
||||
func (c *ProviderConfig) GetTokenWithConsumerAffinity(ctx wrapper.HttpContext, consumer string) string {
|
||||
apiTokens := c.apiTokens
|
||||
count := len(apiTokens)
|
||||
switch count {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return apiTokens[0]
|
||||
default:
|
||||
// Use FNV-1a hash for consistent token selection
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(consumer))
|
||||
hashValue := h.Sum32()
|
||||
index := int(hashValue) % count
|
||||
if index < 0 {
|
||||
index += count
|
||||
}
|
||||
return apiTokens[index]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsOriginal() bool {
|
||||
return c.protocol == protocolOriginal
|
||||
}
|
||||
|
||||
275
plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Normal file
275
plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsStatefulAPI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiName string
|
||||
expected bool
|
||||
}{
|
||||
// Stateful APIs - should return true
|
||||
{
|
||||
name: "responses_api",
|
||||
apiName: string(ApiNameResponses),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "files_api",
|
||||
apiName: string(ApiNameFiles),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_file_api",
|
||||
apiName: string(ApiNameRetrieveFile),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_file_content_api",
|
||||
apiName: string(ApiNameRetrieveFileContent),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "batches_api",
|
||||
apiName: string(ApiNameBatches),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_batch_api",
|
||||
apiName: string(ApiNameRetrieveBatch),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cancel_batch_api",
|
||||
apiName: string(ApiNameCancelBatch),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "fine_tuning_jobs_api",
|
||||
apiName: string(ApiNameFineTuningJobs),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve_fine_tuning_job_api",
|
||||
apiName: string(ApiNameRetrieveFineTuningJob),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "fine_tuning_job_events_api",
|
||||
apiName: string(ApiNameFineTuningJobEvents),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "fine_tuning_job_checkpoints_api",
|
||||
apiName: string(ApiNameFineTuningJobCheckpoints),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cancel_fine_tuning_job_api",
|
||||
apiName: string(ApiNameCancelFineTuningJob),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "resume_fine_tuning_job_api",
|
||||
apiName: string(ApiNameResumeFineTuningJob),
|
||||
expected: true,
|
||||
},
|
||||
// Non-stateful APIs - should return false
|
||||
{
|
||||
name: "chat_completion_api",
|
||||
apiName: string(ApiNameChatCompletion),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "completion_api",
|
||||
apiName: string(ApiNameCompletion),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "embeddings_api",
|
||||
apiName: string(ApiNameEmbeddings),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "models_api",
|
||||
apiName: string(ApiNameModels),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "image_generation_api",
|
||||
apiName: string(ApiNameImageGeneration),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "audio_speech_api",
|
||||
apiName: string(ApiNameAudioSpeech),
|
||||
expected: false,
|
||||
},
|
||||
// Empty/unknown API - should return false
|
||||
{
|
||||
name: "empty_api_name",
|
||||
apiName: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unknown_api_name",
|
||||
apiName: "unknown/api",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isStatefulAPI(tt.apiName)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenWithConsumerAffinity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiTokens []string
|
||||
consumer string
|
||||
wantEmpty bool
|
||||
wantToken string // If not empty, expected specific token (for single token case)
|
||||
}{
|
||||
{
|
||||
name: "no_tokens_returns_empty",
|
||||
apiTokens: []string{},
|
||||
consumer: "consumer1",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "nil_tokens_returns_empty",
|
||||
apiTokens: nil,
|
||||
consumer: "consumer1",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "single_token_always_returns_same_token",
|
||||
apiTokens: []string{"token1"},
|
||||
consumer: "consumer1",
|
||||
wantToken: "token1",
|
||||
},
|
||||
{
|
||||
name: "single_token_with_different_consumer",
|
||||
apiTokens: []string{"token1"},
|
||||
consumer: "consumer2",
|
||||
wantToken: "token1",
|
||||
},
|
||||
{
|
||||
name: "multiple_tokens_consistent_for_same_consumer",
|
||||
apiTokens: []string{"token1", "token2", "token3"},
|
||||
consumer: "consumer1",
|
||||
wantEmpty: false, // Will get one of the tokens, consistently
|
||||
},
|
||||
{
|
||||
name: "multiple_tokens_different_consumers_may_get_different_tokens",
|
||||
apiTokens: []string{"token1", "token2"},
|
||||
consumer: "consumerA",
|
||||
wantEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
apiTokens: tt.apiTokens,
|
||||
}
|
||||
|
||||
result := config.GetTokenWithConsumerAffinity(nil, tt.consumer)
|
||||
|
||||
if tt.wantEmpty {
|
||||
assert.Empty(t, result)
|
||||
} else if tt.wantToken != "" {
|
||||
assert.Equal(t, tt.wantToken, result)
|
||||
} else {
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, tt.apiTokens, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenWithConsumerAffinity_Consistency(t *testing.T) {
|
||||
// Test that the same consumer always gets the same token (consistency)
|
||||
config := &ProviderConfig{
|
||||
apiTokens: []string{"token1", "token2", "token3", "token4", "token5"},
|
||||
}
|
||||
|
||||
t.Run("same_consumer_gets_same_token_repeatedly", func(t *testing.T) {
|
||||
consumer := "test-consumer"
|
||||
var firstResult string
|
||||
|
||||
// Call multiple times and verify consistency
|
||||
for i := 0; i < 10; i++ {
|
||||
result := config.GetTokenWithConsumerAffinity(nil, consumer)
|
||||
if i == 0 {
|
||||
firstResult = result
|
||||
}
|
||||
assert.Equal(t, firstResult, result, "Consumer should consistently get the same token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different_consumers_distribute_across_tokens", func(t *testing.T) {
|
||||
// Use multiple consumers and verify they distribute across tokens
|
||||
consumers := []string{"consumer1", "consumer2", "consumer3", "consumer4", "consumer5", "consumer6", "consumer7", "consumer8", "consumer9", "consumer10"}
|
||||
tokenCounts := make(map[string]int)
|
||||
|
||||
for _, consumer := range consumers {
|
||||
token := config.GetTokenWithConsumerAffinity(nil, consumer)
|
||||
tokenCounts[token]++
|
||||
}
|
||||
|
||||
// Verify all tokens returned are valid
|
||||
for token := range tokenCounts {
|
||||
assert.Contains(t, config.apiTokens, token)
|
||||
}
|
||||
|
||||
// With 10 consumers and 5 tokens, we expect some distribution
|
||||
// (not necessarily perfect distribution, but should use multiple tokens)
|
||||
assert.GreaterOrEqual(t, len(tokenCounts), 2, "Should use at least 2 different tokens")
|
||||
})
|
||||
|
||||
t.Run("empty_consumer_returns_empty_string", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
apiTokens: []string{"token1", "token2"},
|
||||
}
|
||||
result := config.GetTokenWithConsumerAffinity(nil, "")
|
||||
// Empty consumer still returns a token (hash of empty string)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, []string{"token1", "token2"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) {
|
||||
// Test that the hash function distributes consumers reasonably across tokens
|
||||
config := &ProviderConfig{
|
||||
apiTokens: []string{"token1", "token2", "token3"},
|
||||
}
|
||||
|
||||
// Test specific consumers to verify hash behavior
|
||||
testCases := []struct {
|
||||
consumer string
|
||||
expectValid bool
|
||||
}{
|
||||
{"user-alice", true},
|
||||
{"user-bob", true},
|
||||
{"user-charlie", true},
|
||||
{"service-api-v1", true},
|
||||
{"service-api-v2", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("consumer_"+tc.consumer, func(t *testing.T) {
|
||||
result := config.GetTokenWithConsumerAffinity(nil, tc.consumer)
|
||||
assert.True(t, tc.expectValid)
|
||||
assert.Contains(t, config.apiTokens, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
292
plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go
Normal file
292
plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:多 API Token 配置(用于测试 consumer affinity)
|
||||
var multiTokenOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-token-1", "sk-token-2", "sk-token-3"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:单 API Token 配置
|
||||
var singleTokenOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-single-token"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunConsumerAffinityParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("multi token config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunConsumerAffinityOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 stateful API(responses)使用 consumer affinity
|
||||
t.Run("stateful api responses with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 使用 x-mse-consumer header
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-alice"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证 Authorization header 使用了其中一个 token
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试 stateful API(files)使用 consumer affinity
|
||||
t.Run("stateful api files with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/files"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-files"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试 stateful API(batches)使用 consumer affinity
|
||||
t.Run("stateful api batches with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/batches"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-batches"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试 stateful API(fine_tuning)使用 consumer affinity
|
||||
t.Run("stateful api fine_tuning with consumer header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/fine_tuning/jobs"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-finetuning"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试非 stateful API 正常工作
|
||||
t.Run("non stateful api chat completions works normally", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-chat"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试无 x-mse-consumer header 时正常工作
|
||||
t.Run("stateful api without consumer header works normally", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens")
|
||||
})
|
||||
|
||||
// 测试单个 token 时始终使用该 token
|
||||
t.Run("single token always used", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(singleTokenOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", "consumer-test"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.Contains(t, authValue, "sk-single-token", "Single token should always be used")
|
||||
})
|
||||
|
||||
// 测试同一 consumer 多次请求获得相同 token(consumer affinity 一致性)
|
||||
t.Run("same consumer gets consistent token across requests", func(t *testing.T) {
|
||||
consumer := "consumer-consistency-test"
|
||||
var firstToken string
|
||||
|
||||
// 运行 5 次请求,验证同一个 consumer 始终获得相同的 token
|
||||
for i := 0; i < 5; i++ {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", consumer},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.True(t, strings.Contains(authValue, "sk-token-"), "Should use one of the configured tokens")
|
||||
|
||||
if i == 0 {
|
||||
firstToken = authValue
|
||||
} else {
|
||||
require.Equal(t, firstToken, authValue, "Same consumer should get same token consistently (consumer affinity)")
|
||||
}
|
||||
|
||||
host.Reset()
|
||||
}
|
||||
})
|
||||
|
||||
// 测试不同 consumer 可能获得不同 token
|
||||
t.Run("different consumers get tokens based on hash", func(t *testing.T) {
|
||||
tokens := make(map[string]string)
|
||||
|
||||
consumers := []string{"consumer-alpha", "consumer-beta", "consumer-gamma", "consumer-delta", "consumer-epsilon"}
|
||||
for _, consumer := range consumers {
|
||||
host, status := test.NewTestHost(multiTokenOpenAIConfig)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/responses"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"x-mse-consumer", consumer},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
tokens[consumer] = authValue
|
||||
|
||||
host.Reset()
|
||||
}
|
||||
|
||||
// 验证至少使用了多个不同的 token(hash 分布)
|
||||
uniqueTokens := make(map[string]bool)
|
||||
for _, token := range tokens {
|
||||
uniqueTokens[token] = true
|
||||
}
|
||||
require.GreaterOrEqual(t, len(uniqueTokens), 2, "Different consumers should use at least 2 different tokens")
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user