diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index ffae63f0b..947a57fdd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -157,3 +157,8 @@ func TestClaude(t *testing.T) { test.RunClaudeOnHttpRequestHeadersTests(t) test.RunClaudeOnHttpRequestBodyTests(t) } + +func TestConsumerAffinity(t *testing.T) { + test.RunConsumerAffinityParseConfigTests(t) + test.RunConsumerAffinityOnHttpRequestHeadersTests(t) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 60aaf36ca..9c1925f61 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -605,7 +605,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) { if c.isFailoverEnabled() { apiToken = c.GetGlobalRandomToken() } else { - apiToken = c.GetRandomToken() + apiToken = c.GetOrSetTokenWithContext(ctx) } log.Debugf("Use apiToken %s to send request", apiToken) ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index d6eb72a0f..cd4d1cab5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "hash/fnv" "math/rand" "net/http" "path" @@ -706,12 +707,45 @@ func (c *ProviderConfig) Validate() error { func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string { ctxApiKey := ctx.GetContext(ctxKeyApiKey) if ctxApiKey == nil { - ctxApiKey = c.GetRandomToken() + token := c.selectApiToken(ctx) + ctxApiKey = token ctx.SetContext(ctxKeyApiKey, ctxApiKey) } return ctxApiKey.(string) } +// selectApiToken selects an API token based on the request context +// For stateful APIs, it uses consumer affinity if available +func (c *ProviderConfig) selectApiToken(ctx wrapper.HttpContext) string { + // Get API name from context if available + ctxApiName := ctx.GetContext(CtxKeyApiName) + var apiName string + if ctxApiName != nil { + // ctxApiName is of type ApiName, need to convert to string + apiName = string(ctxApiName.(ApiName)) + } + + // For stateful APIs, try to use consumer affinity + if isStatefulAPI(apiName) { + consumer := c.getConsumerFromContext(ctx) + if consumer != "" { + return c.GetTokenWithConsumerAffinity(ctx, consumer) + } + } + + // Fall back to random selection + return c.GetRandomToken() +} + +// getConsumerFromContext retrieves the consumer identifier from the request context +func (c *ProviderConfig) getConsumerFromContext(ctx wrapper.HttpContext) string { + consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer") + if err == nil && consumer != "" { + return consumer + } + return "" +} + func (c *ProviderConfig) GetRandomToken() string { apiTokens := c.apiTokens count := len(apiTokens) @@ -725,6 +759,50 @@ func (c *ProviderConfig) GetRandomToken() string { } } +// isStatefulAPI checks if the given API name is a stateful API that requires consumer affinity +func isStatefulAPI(apiName string) bool { + // These APIs maintain session state and should be routed to the same provider consistently + statefulAPIs := map[string]bool{ + string(ApiNameResponses): true, // Response API - uses previous_response_id + string(ApiNameFiles): true, // Files API - maintains file state + string(ApiNameRetrieveFile): true, // File retrieval - depends on file upload + string(ApiNameRetrieveFileContent): true, // File content - depends on file upload + string(ApiNameBatches): true, // Batch API - maintains batch state + string(ApiNameRetrieveBatch): true, // Batch status - depends on batch creation + string(ApiNameCancelBatch): true, // Batch operations - depends on batch state + string(ApiNameFineTuningJobs): true, // Fine-tuning - maintains job state + string(ApiNameRetrieveFineTuningJob): true, // Fine-tuning job status + string(ApiNameFineTuningJobEvents): true, // Fine-tuning events + string(ApiNameFineTuningJobCheckpoints): true, // Fine-tuning checkpoints + string(ApiNameCancelFineTuningJob): true, // Cancel fine-tuning job + string(ApiNameResumeFineTuningJob): true, // Resume fine-tuning job + } + return statefulAPIs[apiName] +} + +// GetTokenWithConsumerAffinity selects an API token based on consumer affinity +// If x-mse-consumer header is present and API is stateful, it will consistently select the same token +func (c *ProviderConfig) GetTokenWithConsumerAffinity(ctx wrapper.HttpContext, consumer string) string { + apiTokens := c.apiTokens + count := len(apiTokens) + switch count { + case 0: + return "" + case 1: + return apiTokens[0] + default: + // Use FNV-1a hash for consistent token selection + h := fnv.New32a() + h.Write([]byte(consumer)) + hashValue := h.Sum32() + index := int(hashValue) % count + if index < 0 { + index += count + } + return apiTokens[index] + } +} + func (c *ProviderConfig) IsOriginal() bool { return c.protocol == protocolOriginal } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go new file mode 100644 index 000000000..c061c438d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go @@ -0,0 +1,275 @@ +package provider + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsStatefulAPI(t *testing.T) { + tests := []struct { + name string + apiName string + expected bool + }{ + // Stateful APIs - should return true + { + name: "responses_api", + apiName: string(ApiNameResponses), + expected: true, + }, + { + name: "files_api", + apiName: string(ApiNameFiles), + expected: true, + }, + { + name: "retrieve_file_api", + apiName: string(ApiNameRetrieveFile), + expected: true, + }, + { + name: "retrieve_file_content_api", + apiName: string(ApiNameRetrieveFileContent), + expected: true, + }, + { + name: "batches_api", + apiName: string(ApiNameBatches), + expected: true, + }, + { + name: "retrieve_batch_api", + apiName: string(ApiNameRetrieveBatch), + expected: true, + }, + { + name: "cancel_batch_api", + apiName: string(ApiNameCancelBatch), + expected: true, + }, + { + name: "fine_tuning_jobs_api", + apiName: string(ApiNameFineTuningJobs), + expected: true, + }, + { + name: "retrieve_fine_tuning_job_api", + apiName: string(ApiNameRetrieveFineTuningJob), + expected: true, + }, + { + name: "fine_tuning_job_events_api", + apiName: string(ApiNameFineTuningJobEvents), + expected: true, + }, + { + name: "fine_tuning_job_checkpoints_api", + apiName: string(ApiNameFineTuningJobCheckpoints), + expected: true, + }, + { + name: "cancel_fine_tuning_job_api", + apiName: string(ApiNameCancelFineTuningJob), + expected: true, + }, + { + name: "resume_fine_tuning_job_api", + apiName: string(ApiNameResumeFineTuningJob), + expected: true, + }, + // Non-stateful APIs - should return false + { + name: "chat_completion_api", + apiName: string(ApiNameChatCompletion), + expected: false, + }, + { + name: "completion_api", + apiName: string(ApiNameCompletion), + expected: false, + }, + { + name: "embeddings_api", + apiName: string(ApiNameEmbeddings), + expected: false, + }, + { + name: "models_api", + apiName: string(ApiNameModels), + expected: false, + }, + { + name: "image_generation_api", + apiName: string(ApiNameImageGeneration), + expected: false, + }, + { + name: "audio_speech_api", + apiName: string(ApiNameAudioSpeech), + expected: false, + }, + // Empty/unknown API - should return false + { + name: "empty_api_name", + apiName: "", + expected: false, + }, + { + name: "unknown_api_name", + apiName: "unknown/api", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isStatefulAPI(tt.apiName) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetTokenWithConsumerAffinity(t *testing.T) { + tests := []struct { + name string + apiTokens []string + consumer string + wantEmpty bool + wantToken string // If not empty, expected specific token (for single token case) + }{ + { + name: "no_tokens_returns_empty", + apiTokens: []string{}, + consumer: "consumer1", + wantEmpty: true, + }, + { + name: "nil_tokens_returns_empty", + apiTokens: nil, + consumer: "consumer1", + wantEmpty: true, + }, + { + name: "single_token_always_returns_same_token", + apiTokens: []string{"token1"}, + consumer: "consumer1", + wantToken: "token1", + }, + { + name: "single_token_with_different_consumer", + apiTokens: []string{"token1"}, + consumer: "consumer2", + wantToken: "token1", + }, + { + name: "multiple_tokens_consistent_for_same_consumer", + apiTokens: []string{"token1", "token2", "token3"}, + consumer: "consumer1", + wantEmpty: false, // Will get one of the tokens, consistently + }, + { + name: "multiple_tokens_different_consumers_may_get_different_tokens", + apiTokens: []string{"token1", "token2"}, + consumer: "consumerA", + wantEmpty: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &ProviderConfig{ + apiTokens: tt.apiTokens, + } + + result := config.GetTokenWithConsumerAffinity(nil, tt.consumer) + + if tt.wantEmpty { + assert.Empty(t, result) + } else if tt.wantToken != "" { + assert.Equal(t, tt.wantToken, result) + } else { + assert.NotEmpty(t, result) + assert.Contains(t, tt.apiTokens, result) + } + }) + } +} + +func TestGetTokenWithConsumerAffinity_Consistency(t *testing.T) { + // Test that the same consumer always gets the same token (consistency) + config := &ProviderConfig{ + apiTokens: []string{"token1", "token2", "token3", "token4", "token5"}, + } + + t.Run("same_consumer_gets_same_token_repeatedly", func(t *testing.T) { + consumer := "test-consumer" + var firstResult string + + // Call multiple times and verify consistency + for i := 0; i < 10; i++ { + result := config.GetTokenWithConsumerAffinity(nil, consumer) + if i == 0 { + firstResult = result + } + assert.Equal(t, firstResult, result, "Consumer should consistently get the same token") + } + }) + + t.Run("different_consumers_distribute_across_tokens", func(t *testing.T) { + // Use multiple consumers and verify they distribute across tokens + consumers := []string{"consumer1", "consumer2", "consumer3", "consumer4", "consumer5", "consumer6", "consumer7", "consumer8", "consumer9", "consumer10"} + tokenCounts := make(map[string]int) + + for _, consumer := range consumers { + token := config.GetTokenWithConsumerAffinity(nil, consumer) + tokenCounts[token]++ + } + + // Verify all tokens returned are valid + for token := range tokenCounts { + assert.Contains(t, config.apiTokens, token) + } + + // With 10 consumers and 5 tokens, we expect some distribution + // (not necessarily perfect distribution, but should use multiple tokens) + assert.GreaterOrEqual(t, len(tokenCounts), 2, "Should use at least 2 different tokens") + }) + + t.Run("empty_consumer_returns_empty_string", func(t *testing.T) { + config := &ProviderConfig{ + apiTokens: []string{"token1", "token2"}, + } + result := config.GetTokenWithConsumerAffinity(nil, "") + // Empty consumer still returns a token (hash of empty string) + assert.NotEmpty(t, result) + assert.Contains(t, []string{"token1", "token2"}, result) + }) +} + +func TestGetTokenWithConsumerAffinity_HashDistribution(t *testing.T) { + // Test that the hash function distributes consumers reasonably across tokens + config := &ProviderConfig{ + apiTokens: []string{"token1", "token2", "token3"}, + } + + // Test specific consumers to verify hash behavior + testCases := []struct { + consumer string + expectValid bool + }{ + {"user-alice", true}, + {"user-bob", true}, + {"user-charlie", true}, + {"service-api-v1", true}, + {"service-api-v2", true}, + } + + for _, tc := range testCases { + t.Run("consumer_"+tc.consumer, func(t *testing.T) { + result := config.GetTokenWithConsumerAffinity(nil, tc.consumer) + assert.True(t, tc.expectValid) + assert.Contains(t, config.apiTokens, result) + }) + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go b/plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go new file mode 100644 index 000000000..dc11b3c64 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/consumer_affinity.go @@ -0,0 +1,292 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +// 测试配置:多 API Token 配置(用于测试 consumer affinity) +var multiTokenOpenAIConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-token-1", "sk-token-2", "sk-token-3"}, + "modelMapping": map[string]string{ + "*": "gpt-4", + }, + }, + }) + return data +}() + +// 测试配置:单 API Token 配置 +var singleTokenOpenAIConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "openai", + "apiTokens": []string{"sk-single-token"}, + "modelMapping": map[string]string{ + "*": "gpt-4", + }, + }, + }) + return data +}() + +func RunConsumerAffinityParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("multi token config", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) +} + +func RunConsumerAffinityOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 stateful API(responses)使用 consumer affinity + t.Run("stateful api responses with consumer header", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 使用 x-mse-consumer header + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/responses"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", "consumer-alice"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证 Authorization header 使用了其中一个 token + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens") + }) + + // 测试 stateful API(files)使用 consumer affinity + t.Run("stateful api files with consumer header", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/files"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", "consumer-files"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens") + }) + + // 测试 stateful API(batches)使用 consumer affinity + t.Run("stateful api batches with consumer header", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/batches"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", "consumer-batches"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens") + }) + + // 测试 stateful API(fine_tuning)使用 consumer affinity + t.Run("stateful api fine_tuning with consumer header", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/fine_tuning/jobs"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", "consumer-finetuning"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens") + }) + + // 测试非 stateful API 正常工作 + t.Run("non stateful api chat completions works normally", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", "consumer-chat"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens") + }) + + // 测试无 x-mse-consumer header 时正常工作 + t.Run("stateful api without consumer header works normally", func(t *testing.T) { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/responses"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Authorization should contain one of the tokens") + }) + + // 测试单个 token 时始终使用该 token + t.Run("single token always used", func(t *testing.T) { + host, status := test.NewTestHost(singleTokenOpenAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/responses"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", "consumer-test"}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization") + require.Contains(t, authValue, "sk-single-token", "Single token should always be used") + }) + + // 测试同一 consumer 多次请求获得相同 token(consumer affinity 一致性) + t.Run("same consumer gets consistent token across requests", func(t *testing.T) { + consumer := "consumer-consistency-test" + var firstToken string + + // 运行 5 次请求,验证同一个 consumer 始终获得相同的 token + for i := 0; i < 5; i++ { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/responses"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", consumer}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth, "Authorization header should exist") + require.True(t, strings.Contains(authValue, "sk-token-"), "Should use one of the configured tokens") + + if i == 0 { + firstToken = authValue + } else { + require.Equal(t, firstToken, authValue, "Same consumer should get same token consistently (consumer affinity)") + } + + host.Reset() + } + }) + + // 测试不同 consumer 可能获得不同 token + t.Run("different consumers get tokens based on hash", func(t *testing.T) { + tokens := make(map[string]string) + + consumers := []string{"consumer-alpha", "consumer-beta", "consumer-gamma", "consumer-delta", "consumer-epsilon"} + for _, consumer := range consumers { + host, status := test.NewTestHost(multiTokenOpenAIConfig) + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/responses"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"x-mse-consumer", consumer}, + }) + + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + authValue, _ := test.GetHeaderValue(requestHeaders, "Authorization") + tokens[consumer] = authValue + + host.Reset() + } + + // 验证至少使用了多个不同的 token(hash 分布) + uniqueTokens := make(map[string]bool) + for _, token := range tokens { + uniqueTokens[token] = true + } + require.GreaterOrEqual(t, len(uniqueTokens), 2, "Different consumers should use at least 2 different tokens") + }) + }) +}