From 1c4fe1c9f993de2c134b729e165f0c4be7b391b8 Mon Sep 17 00:00:00 2001 From: Jingze <52855280+Jing-ze@users.noreply.github.com> Date: Tue, 12 May 2026 10:20:08 +0800 Subject: [PATCH] test(ai-proxy): expand wasm integration tests, coverage, and provider matrix (#3790) Signed-off-by: jingze --- .../extensions/ai-proxy/config/config.go | 5 + .../extensions/ai-proxy/config/config_test.go | 164 +++++++++++++ .../extensions/ai-proxy/export_test.go | 27 ++ .../wasm-go/extensions/ai-proxy/main_test.go | 182 ++++++++++++++ .../extensions/ai-proxy/parse_config_test.go | 60 +++++ .../provider/claude_to_openai_test.go | 19 ++ .../ai-proxy/provider/provider_test.go | 39 +++ .../ai-proxy/provider/retry_test.go | 126 ++++++++++ .../provider/streaming_extract_test.go | 51 ++++ .../ai-proxy/streaming_matrix_test.go | 135 ++++++++++ .../extensions/ai-proxy/test/deepseek.go | 92 +++++++ .../extensions/ai-proxy/test/doubao.go | 92 +++++++ .../extensions/ai-proxy/test/github.go | 93 +++++++ .../wasm-go/extensions/ai-proxy/test/grok.go | 92 +++++++ .../wasm-go/extensions/ai-proxy/test/groq.go | 92 +++++++ .../extensions/ai-proxy/test/main_edges.go | 144 +++++++++++ .../extensions/ai-proxy/test/mistral.go | 92 +++++++ .../extensions/ai-proxy/test/moonshot.go | 92 +++++++ .../extensions/ai-proxy/test/openai.go | 4 + .../ai-proxy/test/provider_wasm_smoke.go | 230 ++++++++++++++++++ .../wasm-go/extensions/ai-proxy/test/spark.go | 79 ++++++ .../extensions/ai-proxy/test/together_ai.go | 92 +++++++ .../wasm-go/extensions/ai-proxy/test/util.go | 8 + .../ai-proxy/util/header_slice_test.go | 39 +++ .../extensions/ai-proxy/util/string_test.go | 27 ++ 25 files changed, 2076 insertions(+) create mode 100644 plugins/wasm-go/extensions/ai-proxy/config/config_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/export_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/parse_config_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/retry_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/streaming_extract_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/streaming_matrix_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/deepseek.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/doubao.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/github.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/grok.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/groq.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/main_edges.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/mistral.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/moonshot.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/provider_wasm_smoke.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/spark.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/together_ai.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/util/header_slice_test.go diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index cd76fd80..fa81e668 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -98,3 +98,8 @@ func (c *PluginConfig) GetProvider() provider.Provider { func (c *PluginConfig) GetProviderConfig() *provider.ProviderConfig { return c.activeProviderConfig } + +// SetActiveProviderForTest replaces the runtime Provider after Complete(); intended for unit tests in package main only. +func (c *PluginConfig) SetActiveProviderForTest(p provider.Provider) { + c.activeProvider = p +} diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config_test.go b/plugins/wasm-go/extensions/ai-proxy/config/config_test.go new file mode 100644 index 00000000..352526ff --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/config/config_test.go @@ -0,0 +1,164 @@ +package config + +import ( + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func TestPluginConfig_FromJsonAndValidate(t *testing.T) { + tests := []struct { + name string + json string + wantErr string + wantNilPC bool + wantID string + wantType string + }{ + { + name: "legacy_single_provider_object", + json: `{"provider":{"type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}}`, + wantNilPC: false, + wantType: "generic", + }, + { + name: "providers_without_active_id_validate_ok", + json: `{"providers":[ + {"id":"a","type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}, + {"id":"b","type":"generic","genericHost":"http://127.0.0.1:8081","apiTokens":["u"]} + ]}`, + wantNilPC: true, + }, + { + name: "providers_with_active_id", + json: `{"providers":[ + {"id":"p1","type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}, + {"id":"p2","type":"generic","genericHost":"http://127.0.0.1:8081","apiTokens":["u"]} + ],"activeProviderId":"p2"}`, + wantNilPC: false, + wantID: "p2", + wantType: "generic", + }, + { + name: "active_id_not_found", + json: `{"providers":[ + {"id":"p1","type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]} + ],"activeProviderId":"missing"}`, + wantNilPC: true, + }, + { + name: "invalid_protocol", + json: `{"providers":[{"id":"x","type":"generic","protocol":"badproto","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}],"activeProviderId":"x"}`, + wantErr: "invalid protocol", + }, + { + name: "missing_type", + json: `{"providers":[{"id":"x","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}],"activeProviderId":"x"}`, + wantErr: "missing type", + }, + { + name: "unknown_provider_type", + json: `{"providers":[{"id":"x","type":"not-a-real-provider","apiTokens":["t"]}],"activeProviderId":"x"}`, + wantErr: "unknown provider type", + }, + { + name: "initializer_validate_azure_missing_url", + json: `{"providers":[{"id":"x","type":"azure","apiTokens":["t"]}],"activeProviderId":"x"}`, + wantErr: "missing azureServiceUrl", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var c PluginConfig + c.FromJson(gjson.Parse(tt.json)) + err := c.Validate() + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() err = %v, want substring %q", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("Validate() = %v", err) + } + pc := c.GetProviderConfig() + if tt.wantNilPC { + if pc != nil { + t.Fatalf("GetProviderConfig() = %p, want nil", pc) + } + } else { + if pc == nil { + t.Fatal("GetProviderConfig() = nil, want non-nil") + } + if tt.wantID != "" && pc.GetId() != tt.wantID { + t.Errorf("GetId() = %q, want %q", pc.GetId(), tt.wantID) + } + if tt.wantType != "" && pc.GetType() != tt.wantType { + t.Errorf("GetType() = %q, want %q", pc.GetType(), tt.wantType) + } + } + }) + } +} + +func TestPluginConfig_OverrideMergeSimulatesParseOverride(t *testing.T) { + globalJSON := `{"providers":[ + {"id":"p1","type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}, + {"id":"p2","type":"generic","genericHost":"http://127.0.0.1:8081","apiTokens":["u"]} + ],"activeProviderId":"p1"}` + + t.Run("switch_active_provider_id", func(t *testing.T) { + var global PluginConfig + global.FromJson(gjson.Parse(globalJSON)) + if err := global.Validate(); err != nil { + t.Fatal(err) + } + if global.GetProviderConfig().GetId() != "p1" { + t.Fatalf("global active id = %q", global.GetProviderConfig().GetId()) + } + + rule := global + rule.FromJson(gjson.Parse(`{"activeProviderId":"p2"}`)) + if err := rule.Validate(); err != nil { + t.Fatal(err) + } + if got := rule.GetProviderConfig().GetId(); got != "p2" { + t.Errorf("after override active id = %q, want p2", got) + } + }) + + t.Run("empty_override_json_clears_active", func(t *testing.T) { + var global PluginConfig + global.FromJson(gjson.Parse(globalJSON)) + if err := global.Validate(); err != nil { + t.Fatal(err) + } + + rule := global + rule.FromJson(gjson.Parse(`{}`)) + if err := rule.Validate(); err != nil { + t.Fatal(err) + } + if rule.GetProviderConfig() != nil { + t.Errorf("after empty override, GetProviderConfig() = %v, want nil", rule.GetProviderConfig()) + } + }) + + t.Run("clear_active_with_empty_string_id", func(t *testing.T) { + var global PluginConfig + global.FromJson(gjson.Parse(globalJSON)) + if err := global.Validate(); err != nil { + t.Fatal(err) + } + + rule := global + rule.FromJson(gjson.Parse(`{"activeProviderId":""}`)) + if err := rule.Validate(); err != nil { + t.Fatal(err) + } + if rule.GetProviderConfig() != nil { + t.Errorf("GetProviderConfig() = %v, want nil", rule.GetProviderConfig()) + } + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/export_test.go b/plugins/wasm-go/extensions/ai-proxy/export_test.go new file mode 100644 index 00000000..e2d7800f --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/export_test.go @@ -0,0 +1,27 @@ +package main + +import ( + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +// NeedsClaudeResponseConversionForTest exposes needsClaudeResponseConversion for unit tests. +func NeedsClaudeResponseConversionForTest(ctx wrapper.HttpContext) bool { + return needsClaudeResponseConversion(ctx) +} + +// ParseGlobalConfigForTest exposes parseGlobalConfig for unit tests. +func ParseGlobalConfigForTest(json gjson.Result, pluginConfig *config.PluginConfig) error { + return parseGlobalConfig(json, pluginConfig) +} + +// ParseOverrideRuleConfigForTest exposes parseOverrideRuleConfig for unit tests. +func ParseOverrideRuleConfigForTest(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error { + return parseOverrideRuleConfig(json, global, pluginConfig) +} + +// OnStreamingResponseBodyForTest exposes onStreamingResponseBody for unit tests. +func OnStreamingResponseBodyForTest(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte { + return onStreamingResponseBody(ctx, pluginConfig, chunk, isLastChunk) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index f06e4a33..178e1e1b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -1,10 +1,12 @@ package main import ( + "strings" "testing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test" + "github.com/tidwall/gjson" ) func Test_getApiName(t *testing.T) { @@ -24,6 +26,9 @@ func Test_getApiName(t *testing.T) { {"openai realtime", "/v1/realtime", provider.ApiNameRealtime}, {"openai realtime with prefix", "/proxy/v1/realtime", provider.ApiNameRealtime}, {"openai realtime with trailing slash", "/v1/realtime/", ""}, + {"openai chat completions with path_prefix", "/gateway/proxy/v1/chat/completions", provider.ApiNameChatCompletion}, + {"openai chat completions_extra_path_not_suffix_match", "/v1/chat/completions/extra", ""}, + {"openai realtime_with_query_not_matched_as_suffix", "/v1/realtime?stream=1", ""}, {"openai image generation", "/v1/images/generations", provider.ApiNameImageGeneration}, {"openai image variation", "/v1/images/variations", provider.ApiNameImageVariation}, {"openai image edit", "/v1/images/edits", provider.ApiNameImageEdit}, @@ -109,6 +114,30 @@ func Test_isSupportedRequestContentType(t *testing.T) { contentType: "text/plain", want: false, }, + { + name: "json_with_charset", + apiName: provider.ApiNameChatCompletion, + contentType: "application/json; charset=utf-8", + want: true, + }, + { + name: "multipart_uppercase_image_edit", + apiName: provider.ApiNameImageEdit, + contentType: "MULTIPART/FORM-DATA; boundary=abc", + want: true, + }, + { + name: "multipart_image_generation_not_allowed", + apiName: provider.ApiNameImageGeneration, + contentType: "multipart/form-data; boundary=----boundary", + want: false, + }, + { + name: "multipart_embeddings_not_allowed", + apiName: provider.ApiNameEmbeddings, + contentType: "multipart/form-data; boundary=----boundary", + want: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -120,6 +149,94 @@ func Test_isSupportedRequestContentType(t *testing.T) { } } +func Test_normalizeOpenAiRequestBody(t *testing.T) { + t.Run("stream_adds_include_usage", func(t *testing.T) { + in := []byte(`{"model":"x","stream":true}`) + got := normalizeOpenAiRequestBody(in) + if !gjson.GetBytes(got, "stream_options.include_usage").Bool() { + t.Fatalf("want stream_options.include_usage true, got %s", string(got)) + } + }) + t.Run("stream_false_no_stream_options", func(t *testing.T) { + in := []byte(`{"model":"x","stream":false}`) + got := normalizeOpenAiRequestBody(in) + if gjson.GetBytes(got, "stream_options").Exists() { + t.Fatalf("did not expect stream_options, got %s", string(got)) + } + }) + t.Run("respect_explicit_include_usage_false", func(t *testing.T) { + in := []byte(`{"model":"x","stream":true,"stream_options":{"include_usage":false}}`) + got := normalizeOpenAiRequestBody(in) + if gjson.GetBytes(got, "stream_options.include_usage").Bool() { + t.Fatalf("want include_usage false, got %s", string(got)) + } + }) + t.Run("stream_missing_no_stream_options", func(t *testing.T) { + in := []byte(`{"model":"x"}`) + got := normalizeOpenAiRequestBody(in) + if gjson.GetBytes(got, "stream_options").Exists() { + t.Fatalf("unexpected stream_options: %s", string(got)) + } + }) + t.Run("stream_non_bool_treated_as_false", func(t *testing.T) { + in := []byte(`{"model":"x","stream":"yes"}`) + got := normalizeOpenAiRequestBody(in) + if gjson.GetBytes(got, "stream_options").Exists() { + t.Fatalf("unexpected stream_options for non-bool stream: %s", string(got)) + } + }) +} + +func Test_convertResponseBodyToClaude_glue(t *testing.T) { + ctx := test.NewMockHttpContext() + openaiBody := []byte(`{"id":"id1","object":"chat.completion","created":1,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"hello"}}]}`) + + out, err := convertResponseBodyToClaude(ctx, openaiBody) + if err != nil || string(out) != string(openaiBody) { + t.Fatalf("without flag: err=%v out=%s", err, string(out)) + } + // Full OpenAI→Claude conversion runs log.Debugf inside the provider and requires a Wasm host + // when this package's init() has registered the plugin (see provider/claude_to_openai_test.go). +} + +func Test_convertStreamingResponseToClaude_glue(t *testing.T) { + chunk := []byte("data: {\"x\":1}\n\n") + ctx := test.NewMockHttpContext() + out, err := convertStreamingResponseToClaude(ctx, chunk) + if err != nil || string(out) != string(chunk) { + t.Fatalf("without conversion flag: err=%v out=%q", err, string(out)) + } +} + +func Test_needsClaudeResponseConversion(t *testing.T) { + ctx := test.NewMockHttpContext() + if NeedsClaudeResponseConversionForTest(ctx) { + t.Fatal("expected false without context flag") + } + ctx.SetContext("needClaudeResponseConversion", true) + if !NeedsClaudeResponseConversionForTest(ctx) { + t.Fatal("expected true when flag set") + } +} + +func Test_promoteThinkingInStreamingChunk(t *testing.T) { + ctx := test.NewMockHttpContext() + reasoningJSON := `{"choices":[{"index":0,"delta":{"reasoning_content":"only-thinking"}}]}` + sse := "data: " + reasoningJSON + "\n" + out := promoteThinkingInStreamingChunk(ctx, []byte(sse), true) + if len(out) == 0 { + t.Fatal("expected non-empty output") + } + // Last chunk should prepend flush SSE when no content delta was seen + if !strings.HasPrefix(string(out), "data: ") { + t.Fatalf("expected flush data line prepended, got prefix %q", string(out)) + } + // Original line should still be present (possibly stripped reasoning) + if !strings.Contains(string(out), "data:") { + t.Fatalf("expected SSE data lines: %s", string(out)) + } +} + func TestAi360(t *testing.T) { test.RunAi360ParseConfigTests(t) test.RunAi360OnHttpRequestHeadersTests(t) @@ -183,6 +300,10 @@ func TestUtil(t *testing.T) { test.RunMapRequestPathByCapabilityTests(t) } +func TestMainEdgeCases(t *testing.T) { + test.RunMainEdgeCaseTests(t) +} + func TestApiPathRegression(t *testing.T) { test.RunApiPathRegressionTests(t) } @@ -239,3 +360,64 @@ func TestOpenRouter(t *testing.T) { func TestZhipuAI(t *testing.T) { test.RunZhipuAIClaudeAutoConversionTests(t) } + +func TestDeepSeek(t *testing.T) { + test.RunDeepSeekParseConfigTests(t) + test.RunDeepSeekOnHttpRequestHeadersTests(t) +} + +func TestDoubao(t *testing.T) { + test.RunDoubaoParseConfigTests(t) + test.RunDoubaoOnHttpRequestHeadersTests(t) +} + +func TestGroq(t *testing.T) { + test.RunGroqParseConfigTests(t) + test.RunGroqOnHttpRequestHeadersTests(t) +} + +func TestMistral(t *testing.T) { + test.RunMistralParseConfigTests(t) + test.RunMistralOnHttpRequestHeadersTests(t) +} + +func TestMoonshot(t *testing.T) { + test.RunMoonshotParseConfigTests(t) + test.RunMoonshotOnHttpRequestHeadersTests(t) +} + +func TestSpark(t *testing.T) { + test.RunSparkParseConfigTests(t) + test.RunSparkOnHttpRequestHeadersTests(t) +} + +func TestTogetherAI(t *testing.T) { + test.RunTogetherAIParseConfigTests(t) + test.RunTogetherAIOnHttpRequestHeadersTests(t) +} + +func TestGithub(t *testing.T) { + test.RunGithubParseConfigTests(t) + test.RunGithubOnHttpRequestHeadersTests(t) +} + +func TestGrok(t *testing.T) { + test.RunGrokParseConfigTests(t) + test.RunGrokOnHttpRequestHeadersTests(t) +} + +func TestProviderWasmSmoke(t *testing.T) { + test.RunBaichuanWasmSmokeTests(t) + test.RunYiWasmSmokeTests(t) + test.RunOllamaWasmSmokeTests(t) + test.RunBaiduWasmSmokeTests(t) + test.RunHunyuanWasmSmokeTests(t) + test.RunStepfunWasmSmokeTests(t) + test.RunCloudflareWasmSmokeTests(t) + test.RunDeeplWasmSmokeTests(t) + test.RunCohereWasmSmokeTests(t) + test.RunCozeWasmSmokeTests(t) + test.RunDifyWasmSmokeTests(t) + test.RunTritonWasmSmokeTests(t) + test.RunVllmWasmSmokeTests(t) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/parse_config_test.go b/plugins/wasm-go/extensions/ai-proxy/parse_config_test.go new file mode 100644 index 00000000..2951a187 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/parse_config_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + wasmtest "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestParseGlobalAndOverrideConfig(t *testing.T) { + wasmtest.RunGoTest(t, func(t *testing.T) { + bootstrap, st := wasmtest.NewTestHost(json.RawMessage(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:1","apiTokens":["bootstrap"]}}`)) + require.Equal(t, types.OnPluginStartStatusOK, st) + defer bootstrap.Reset() + + t.Run("parse_global_empty_ok", func(t *testing.T) { + var c config.PluginConfig + err := ParseGlobalConfigForTest(gjson.Parse(`{}`), &c) + require.NoError(t, err) + require.Nil(t, c.GetProviderConfig()) + }) + + t.Run("parse_global_invalid_provider", func(t *testing.T) { + var c config.PluginConfig + err := ParseGlobalConfigForTest(gjson.Parse(`{"provider":{"type":"not-a-real-provider","apiTokens":["x"]}}`), &c) + require.Error(t, err) + }) + + t.Run("parse_override_switches_active_provider", func(t *testing.T) { + globalJSON := `{"providers":[ + {"id":"p1","type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}, + {"id":"p2","type":"generic","genericHost":"http://127.0.0.1:8081","apiTokens":["u"]} + ],"activeProviderId":"p1"}` + + var global config.PluginConfig + require.NoError(t, ParseGlobalConfigForTest(gjson.Parse(globalJSON), &global)) + require.Equal(t, "p1", global.GetProviderConfig().GetId()) + + var rule config.PluginConfig + err := ParseOverrideRuleConfigForTest(gjson.Parse(`{"activeProviderId":"p2"}`), global, &rule) + require.NoError(t, err) + require.Equal(t, "p2", rule.GetProviderConfig().GetId()) + }) + + t.Run("parse_override_invalid_fails", func(t *testing.T) { + var global config.PluginConfig + require.NoError(t, ParseGlobalConfigForTest(gjson.Parse(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:1","apiTokens":["a"]}}`), &global)) + + var rule config.PluginConfig + err := ParseOverrideRuleConfigForTest(gjson.Parse(`{"provider":{"type":"azure","apiTokens":["t"]}}`), global, &rule) + require.Error(t, err) + require.Contains(t, strings.ToLower(err.Error()), "azure") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go index b0d1817b..4cbc93d0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai_test.go @@ -1141,6 +1141,25 @@ func TestNormalizeFinishReason(t *testing.T) { } } +func TestClaudeToOpenAIConverter_streaming_tool_call_smoke(t *testing.T) { + converter := &ClaudeToOpenAIConverter{} + + start := `data: {"id":"tc1","choices":[{"index":0,"delta":{"role":"assistant","content":""}}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n" + _, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(start)) + require.NoError(t, err) + + toolChunk := `data: {"id":"tc1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"my_fn","arguments":""}}]}}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n" + out, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(toolChunk)) + require.NoError(t, err) + require.Contains(t, string(out), "content_block_start") + require.Contains(t, string(out), "tool_use") + + argChunk := `data: {"id":"tc1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"x\":1}"}}]}}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n" + out2, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(argChunk)) + require.NoError(t, err) + require.Contains(t, string(out2), "input_json_delta") +} + func TestClaudeToOpenAIConverter_ConvertOpenAIStreamResponseToClaude_Compatibility(t *testing.T) { t.Run("finish_reason empty string should not stop stream", func(t *testing.T) { converter := &ClaudeToOpenAIConverter{} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go index 9dccd76b..3f2f7b12 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go @@ -544,6 +544,19 @@ func TestFailover_FromJson_FailoverOnStatus(t *testing.T) { }) } +func TestFailover_Validate(t *testing.T) { + t.Run("missing_healthCheckModel", func(t *testing.T) { + f := &failover{} + f.FromJson(gjson.Parse(`{"enabled":true}`)) + assert.Error(t, f.Validate()) + }) + t.Run("ok_with_healthCheckModel", func(t *testing.T) { + f := &failover{} + f.FromJson(gjson.Parse(`{"enabled":true,"healthCheckModel":"gpt-4o-mini"}`)) + assert.NoError(t, f.Validate()) + }) +} + func TestHealthCheckEndpoint_Struct(t *testing.T) { t.Run("health_check_endpoint_fields", func(t *testing.T) { endpoint := HealthCheckEndpoint{ @@ -679,6 +692,32 @@ func TestProviderConfig_SetDefaultCapabilities(t *testing.T) { }) } +func TestCreateProvider(t *testing.T) { + t.Run("generic_success", func(t *testing.T) { + var pc ProviderConfig + pc.FromJson(gjson.Parse(`{"type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}`)) + p, err := CreateProvider(pc) + assert.NoError(t, err) + assert.Equal(t, providerTypeGeneric, p.GetProviderType()) + }) + + t.Run("openai_minimal_success", func(t *testing.T) { + var pc ProviderConfig + pc.FromJson(gjson.Parse(`{"type":"openai","apiTokens":["sk-test"]}`)) + p, err := CreateProvider(pc) + assert.NoError(t, err) + assert.Equal(t, providerTypeOpenAI, p.GetProviderType()) + }) + + t.Run("unknown_type", func(t *testing.T) { + var pc ProviderConfig + pc.FromJson(gjson.Parse(`{"type":"no-such-provider-xyz","apiTokens":["t"]}`)) + _, err := CreateProvider(pc) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown provider type") + }) +} + func TestStripClaudeInternalMessageFields(t *testing.T) { body := []byte(`{ "model":"claude", diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/retry_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/retry_test.go new file mode 100644 index 00000000..43a4ac48 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/retry_test.go @@ -0,0 +1,126 @@ +package provider + +import ( + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/iface" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// mapCtx is a minimal wrapper.HttpContext for offline tests (no import cycle with test package). +type mapCtx struct { + kv map[string]interface{} +} + +func newMapCtx() *mapCtx { + return &mapCtx{kv: make(map[string]interface{})} +} + +func (m *mapCtx) SetContext(key string, value interface{}) { m.kv[key] = value } +func (m *mapCtx) GetContext(key string) interface{} { return m.kv[key] } +func (m *mapCtx) GetBoolContext(key string, def bool) bool { return def } +func (m *mapCtx) GetStringContext(key, def string) string { return def } +func (m *mapCtx) GetByteSliceContext(key string, def []byte) []byte { return def } +func (m *mapCtx) Scheme() string { return "" } +func (m *mapCtx) Host() string { return "" } +func (m *mapCtx) Path() string { return "" } +func (m *mapCtx) Method() string { return "" } +func (m *mapCtx) GetUserAttribute(key string) interface{} { return nil } +func (m *mapCtx) SetUserAttribute(key string, value interface{}) {} +func (m *mapCtx) SetUserAttributeMap(kvmap map[string]interface{}) {} +func (m *mapCtx) GetUserAttributeMap() map[string]interface{} { return nil } +func (m *mapCtx) WriteUserAttributeToLog() error { return nil } +func (m *mapCtx) WriteUserAttributeToLogWithKey(key string) error { return nil } +func (m *mapCtx) WriteUserAttributeToTrace() error { return nil } +func (m *mapCtx) DontReadRequestBody() {} +func (m *mapCtx) DontReadResponseBody() {} +func (m *mapCtx) BufferRequestBody() {} +func (m *mapCtx) BufferResponseBody() {} +func (m *mapCtx) NeedPauseStreamingResponse() {} +func (m *mapCtx) PushBuffer(buffer []byte) {} +func (m *mapCtx) PopBuffer() []byte { return nil } +func (m *mapCtx) BufferQueueSize() int { return 0 } +func (m *mapCtx) DisableReroute() {} +func (m *mapCtx) SetRequestBodyBufferLimit(byteSize uint32) {} +func (m *mapCtx) SetResponseBodyBufferLimit(byteSize uint32) {} +func (m *mapCtx) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error { + return nil +} +func (m *mapCtx) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 } +func (m *mapCtx) HasRequestBody() bool { return false } +func (m *mapCtx) HasResponseBody() bool { return false } +func (m *mapCtx) IsWebsocket() bool { return false } +func (m *mapCtx) IsBinaryRequestBody() bool { return false } +func (m *mapCtx) IsBinaryResponseBody() bool { return false } + +var _ wrapper.HttpContext = (*mapCtx)(nil) + +type stubProviderType struct{} + +func (stubProviderType) GetProviderType() string { return providerTypeOpenAI } + +func TestRemoveApiTokenFromRetryList(t *testing.T) { + t.Run("removes_token", func(t *testing.T) { + got := removeApiTokenFromRetryList([]string{"a", "b", "c"}, "b") + assert.Equal(t, []string{"a", "c"}, got) + }) + t.Run("removes_all_when_single", func(t *testing.T) { + got := removeApiTokenFromRetryList([]string{"x"}, "x") + assert.Empty(t, got) + }) + t.Run("no_match_unchanged", func(t *testing.T) { + got := removeApiTokenFromRetryList([]string{"a", "b"}, "z") + assert.Equal(t, []string{"a", "b"}, got) + }) + t.Run("empty_input", func(t *testing.T) { + got := removeApiTokenFromRetryList(nil, "a") + assert.Empty(t, got) + }) +} + +func TestGetRandomToken(t *testing.T) { + assert.Equal(t, "", GetRandomToken(nil)) + assert.Equal(t, "", GetRandomToken([]string{})) + assert.Equal(t, "only", GetRandomToken([]string{"only"})) + tokens := []string{"a", "b", "c"} + for i := 0; i < 20; i++ { + got := GetRandomToken(tokens) + assert.Contains(t, tokens, got) + } +} + +func TestRetryOnFailure_FromJson_defaults(t *testing.T) { + var c ProviderConfig + c.FromJson(gjson.Parse(`{"type":"openai","apiTokens":["t"],"retryOnFailure":{"enabled":true}}`)) + require.True(t, c.IsRetryOnFailureEnabled()) + assert.Equal(t, int64(1), c.retryOnFailure.maxRetries) + assert.Equal(t, int64(60*1000), c.retryOnFailure.retryTimeout) + assert.Equal(t, []string{"4.*", "5.*"}, c.retryOnFailure.retryOnStatus) +} + +func TestOnRequestFailed_offlineBranches(t *testing.T) { + t.Run("no_failover_no_retry_always_continue", func(t *testing.T) { + var c ProviderConfig + c.FromJson(gjson.Parse(`{"type":"openai","apiTokens":["t"]}`)) + ctx := newMapCtx() + act := c.OnRequestFailed(stubProviderType{}, ctx, "t", []string{"t"}, "503") + assert.Equal(t, types.ActionContinue, act) + }) + + t.Run("retry_enabled_single_token_returns_continue_before_post", func(t *testing.T) { + var c ProviderConfig + c.FromJson(gjson.Parse(`{ + "type":"openai", + "apiTokens":["only"], + "retryOnFailure":{"enabled":true,"retryOnStatus":["429","503"]} + }`)) + ctx := newMapCtx() + ctx.SetContext(CtxKeyApiName, ApiNameChatCompletion) + act := c.OnRequestFailed(stubProviderType{}, ctx, "only", []string{"only"}, "503") + assert.Equal(t, types.ActionContinue, act) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/streaming_extract_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/streaming_extract_test.go new file mode 100644 index 00000000..911b7ee9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/streaming_extract_test.go @@ -0,0 +1,51 @@ +package provider + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractStreamingEvents(t *testing.T) { + t.Run("empty_chunk", func(t *testing.T) { + ctx := newMapCtx() + events := ExtractStreamingEvents(ctx, nil) + assert.Empty(t, events) + }) + + t.Run("crlf_normalized", func(t *testing.T) { + ctx := newMapCtx() + chunk := "event:msg\r\ndata:{\"k\":1}\r\n\r\n" + events := ExtractStreamingEvents(ctx, []byte(chunk)) + require.NotEmpty(t, events) + }) + + t.Run("qwen_style_block", func(t *testing.T) { + ctx := newMapCtx() + chunk := "event:result\n:HTTP_STATUS/200\ndata:{\"output\":1}\n\n" + events := ExtractStreamingEvents(ctx, []byte(chunk)) + require.NotEmpty(t, events) + foundData := false + for _, e := range events { + if strings.Contains(e.RawEvent, "data:") { + foundData = true + } + } + assert.True(t, foundData, "expected a data line in parsed events: %#v", events) + }) + + t.Run("split_chunk_buffers_incomplete", func(t *testing.T) { + ctx := newMapCtx() + part1 := []byte("event:a\n") + _ = ExtractStreamingEvents(ctx, part1) + buf, has := ctx.GetContext(ctxKeyStreamingBody).([]byte) + require.True(t, has, "expected streaming body buffer after incomplete chunk") + require.NotEmpty(t, buf) + + part2 := []byte("data:{}\n\n") + events := ExtractStreamingEvents(ctx, part2) + require.NotEmpty(t, events) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/streaming_matrix_test.go b/plugins/wasm-go/extensions/ai-proxy/streaming_matrix_test.go new file mode 100644 index 00000000..1295c3d6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/streaming_matrix_test.go @@ -0,0 +1,135 @@ +package main + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + wasmtest "github.com/higress-group/wasm-go/pkg/test" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type streamBodyStub struct { + out []byte + err error +} + +func (s *streamBodyStub) GetProviderType() string { return "stream-body-stub" } + +func (s *streamBodyStub) OnStreamingResponseBody(ctx wrapper.HttpContext, name provider.ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { + _ = ctx + _ = name + _ = chunk + _ = isLastChunk + return s.out, s.err +} + +type streamEventStub struct { + eventsOut []provider.StreamEvent + err error +} + +func (s *streamEventStub) GetProviderType() string { return "stream-event-stub" } + +func (s *streamEventStub) OnStreamingEvent(ctx wrapper.HttpContext, name provider.ApiName, event provider.StreamEvent) ([]provider.StreamEvent, error) { + _ = ctx + _ = name + _ = event + return s.eventsOut, s.err +} + +func pluginConfigWithStubProvider(t *testing.T, p provider.Provider) config.PluginConfig { + t.Helper() + var c config.PluginConfig + c.FromJson(gjson.Parse(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["tok"]}}`)) + require.NoError(t, c.Validate()) + require.NoError(t, c.Complete()) + c.SetActiveProviderForTest(p) + return c +} + +func TestOnStreamingResponseBody_matrix(t *testing.T) { + wasmtest.RunGoTest(t, func(t *testing.T) { + bootstrap, st := wasmtest.NewTestHost(json.RawMessage(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:1","apiTokens":["bootstrap"]}}`)) + require.Equal(t, types.OnPluginStartStatusOK, st) + defer bootstrap.Reset() + + t.Run("nil_provider_returns_chunk", func(t *testing.T) { + var c config.PluginConfig + c.FromJson(gjson.Parse(`{"providers":[{"id":"x","type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["t"]}]}`)) + require.NoError(t, c.Validate()) + require.NoError(t, c.Complete()) + ctx := test.NewMockHttpContext() + in := []byte("keep") + out := OnStreamingResponseBodyForTest(ctx, c, in, false) + require.Equal(t, in, out) + }) + + t.Run("streaming_body_handler_err_returns_original_chunk", func(t *testing.T) { + stub := &streamBodyStub{out: []byte("x"), err: errors.New("handler failed")} + pc := pluginConfigWithStubProvider(t, stub) + ctx := test.NewMockHttpContext() + ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion) + in := []byte("original") + out := OnStreamingResponseBodyForTest(ctx, pc, in, false) + require.Equal(t, in, out) + }) + + t.Run("streaming_body_handler_nil_modified_returns_original_chunk", func(t *testing.T) { + stub := &streamBodyStub{out: nil, err: nil} + pc := pluginConfigWithStubProvider(t, stub) + ctx := test.NewMockHttpContext() + ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion) + in := []byte("original") + out := OnStreamingResponseBodyForTest(ctx, pc, in, false) + require.Equal(t, in, out) + }) + + t.Run("streaming_body_handler_ok_returns_modified", func(t *testing.T) { + stub := &streamBodyStub{out: []byte("modified"), err: nil} + pc := pluginConfigWithStubProvider(t, stub) + ctx := test.NewMockHttpContext() + ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion) + in := []byte("in") + out := OnStreamingResponseBodyForTest(ctx, pc, in, false) + require.Equal(t, "modified", string(out)) + }) + + t.Run("streaming_event_handler_zero_events_returns_empty", func(t *testing.T) { + stub := &streamEventStub{} + pc := pluginConfigWithStubProvider(t, stub) + ctx := test.NewMockHttpContext() + ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion) + out := OnStreamingResponseBodyForTest(ctx, pc, []byte("incomplete"), false) + require.Equal(t, []byte(""), out) + }) + + t.Run("streaming_event_handler_on_event_err_returns_chunk", func(t *testing.T) { + stub := &streamEventStub{err: errors.New("event failed")} + pc := pluginConfigWithStubProvider(t, stub) + ctx := test.NewMockHttpContext() + ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion) + chunk := []byte("data: {\"x\":1}\n\n") + out := OnStreamingResponseBodyForTest(ctx, pc, chunk, false) + require.Equal(t, chunk, out) + }) + + t.Run("no_handler_no_flags_returns_chunk", func(t *testing.T) { + var c config.PluginConfig + c.FromJson(gjson.Parse(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["t"]}}`)) + require.NoError(t, c.Validate()) + require.NoError(t, c.Complete()) + ctx := test.NewMockHttpContext() + ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion) + in := []byte("passthrough") + out := OnStreamingResponseBodyForTest(ctx, c, in, false) + require.Equal(t, in, out) + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/test/deepseek.go new file mode 100644 index 00000000..ad8631d2 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/deepseek.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicDeepSeekConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "deepseek", + "apiTokens": []string{"sk-deepseek-test"}, + "modelMapping": map[string]string{ + "*": "deepseek-chat", + }, + }) +}() + +var invalidDeepSeekConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "deepseek", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "deepseek-chat"}, + }) +}() + +// RunDeepSeekParseConfigTests exercises DeepSeek plugin config loading. +func RunDeepSeekParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic deepseek config", func(t *testing.T) { + host, status := test.NewTestHost(basicDeepSeekConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid deepseek config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidDeepSeekConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunDeepSeekOnHttpRequestHeadersTests exercises request header transforms for DeepSeek. +func RunDeepSeekOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("deepseek chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicDeepSeekConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "api.deepseek.com", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer sk-deepseek-test") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "deepseek") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or deepseek debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/doubao.go b/plugins/wasm-go/extensions/ai-proxy/test/doubao.go new file mode 100644 index 00000000..267f8c07 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/doubao.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicDoubaoConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "doubao", + "apiTokens": []string{"doubao-token-test"}, + "modelMapping": map[string]string{ + "*": "ep-20240101000000-example", + }, + }) +}() + +var invalidDoubaoConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "doubao", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "ep-example"}, + }) +}() + +// RunDoubaoParseConfigTests exercises Doubao (Volcengine Ark) plugin config loading. +func RunDoubaoParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic doubao config", func(t *testing.T) { + host, status := test.NewTestHost(basicDoubaoConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid doubao config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidDoubaoConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunDoubaoOnHttpRequestHeadersTests exercises Doubao request header transforms. +func RunDoubaoOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("doubao chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicDoubaoConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "ark.cn-beijing.volces.com", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer doubao-token-test") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/api/v3/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "doubao") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or doubao debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/github.go b/plugins/wasm-go/extensions/ai-proxy/test/github.go new file mode 100644 index 00000000..a0610776 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/github.go @@ -0,0 +1,93 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicGithubConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "github", + "apiTokens": []string{"github_models_pat_test"}, + "modelMapping": map[string]string{ + "*": "gpt-4o", + }, + }) +}() + +var invalidGithubConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "github", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "gpt-4o"}, + }) +}() + +// RunGithubParseConfigTests exercises GitHub Models plugin config loading. +func RunGithubParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic github config", func(t *testing.T) { + host, status := test.NewTestHost(basicGithubConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid github config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidGithubConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunGithubOnHttpRequestHeadersTests exercises GitHub Models request header transforms. +func RunGithubOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("github chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicGithubConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "models.inference.ai.azure.com", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + // GitHub provider sets raw token without "Bearer " prefix + require.Equal(t, "github_models_pat_test", authValue) + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "github") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or github debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/grok.go b/plugins/wasm-go/extensions/ai-proxy/test/grok.go new file mode 100644 index 00000000..9991750c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/grok.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicGrokConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "grok", + "apiTokens": []string{"xai-grok-test-key"}, + "modelMapping": map[string]string{ + "*": "grok-2-latest", + }, + }) +}() + +var invalidGrokConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "grok", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "grok-2-latest"}, + }) +}() + +// RunGrokParseConfigTests exercises Grok plugin config loading. +func RunGrokParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic grok config", func(t *testing.T) { + host, status := test.NewTestHost(basicGrokConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid grok config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidGrokConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunGrokOnHttpRequestHeadersTests exercises Grok request header transforms. +func RunGrokOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("grok chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicGrokConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "api.x.ai", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer xai-grok-test-key") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "grok") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or grok debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/groq.go b/plugins/wasm-go/extensions/ai-proxy/test/groq.go new file mode 100644 index 00000000..ab0e296c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/groq.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicGroqConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "groq", + "apiTokens": []string{"gsk_groq_test"}, + "modelMapping": map[string]string{ + "*": "llama-3.1-8b-instant", + }, + }) +}() + +var invalidGroqConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "groq", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "llama-3.1-8b-instant"}, + }) +}() + +// RunGroqParseConfigTests exercises Groq plugin config loading. +func RunGroqParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic groq config", func(t *testing.T) { + host, status := test.NewTestHost(basicGroqConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid groq config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidGroqConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunGroqOnHttpRequestHeadersTests exercises Groq request header transforms. +func RunGroqOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("groq chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicGroqConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "api.groq.com", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer gsk_groq_test") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/openai/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "groq") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or groq debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/main_edges.go b/plugins/wasm-go/extensions/ai-proxy/test/main_edges.go new file mode 100644 index 00000000..319d1ba5 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/main_edges.go @@ -0,0 +1,144 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + wasmhost "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var edgeNoActiveProviderConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "providers": []map[string]interface{}{ + {"id": "a", "type": "generic", "genericHost": "http://127.0.0.1:8080", "apiTokens": []string{"t"}}, + }, + }) + return data +}() + +var edgeGenericConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "generic", + "genericHost": "http://127.0.0.1:9999", + "apiTokens": []string{"tok"}, + "modelMapping": map[string]string{"*": "mapped-model"}, + }) +}() + +// RunMainEdgeCaseTests covers main.go branches: no active provider, unknown path, bad Content-Type, generic Claude path rewrite. +func RunMainEdgeCaseTests(t *testing.T) { + wasmhost.RunGoTest(t, func(t *testing.T) { + t.Run("no_active_provider_skips_body", func(t *testing.T) { + host, status := wasmhost.NewTestHost(edgeNoActiveProviderConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + }) + + t.Run("unknown_path_logs_unsupported", func(t *testing.T) { + host, status := wasmhost.NewTestHost(edgeGenericConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + _ = host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/not-a-real-openai-path"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + all := append(append([]string{}, host.GetDebugLogs()...), host.GetWarnLogs()...) + found := false + for _, line := range all { + if strings.Contains(line, "unsupported path") { + found = true + break + } + } + require.True(t, found, "logs: %v", all) + }) + + t.Run("multipart_on_chat_logs_unsupported_content_type", func(t *testing.T) { + host, status := wasmhost.NewTestHost(edgeGenericConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + _ = host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "multipart/form-data; boundary=----abc"}, + }) + debugLogs := host.GetDebugLogs() + found := false + for _, line := range debugLogs { + if strings.Contains(line, "unsupported content type") { + found = true + break + } + } + require.True(t, found, "debug logs: %v", debugLogs) + }) + }) + + wasmhost.RunTest(t, func(t *testing.T) { + t.Run("generic_claude_path_rewrites_to_chat_completions", func(t *testing.T) { + host, status := wasmhost.NewTestHost(edgeGenericConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/messages"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + headers := host.GetRequestHeaders() + pathVal, ok := wasmhost.GetHeaderValue(headers, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathVal) + }) + + t.Run("response_json_content_type_buffers_for_non_sse", func(t *testing.T) { + host, status := wasmhost.NewTestHost(edgeGenericConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + host.CallOnHttpRequestBody([]byte(`{"model":"m","messages":[{"role":"user","content":"hi"}]}`)) + + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + action := host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + + debugLogs := host.GetDebugLogs() + found := false + for _, line := range debugLogs { + if strings.Contains(line, "onHttpResponseHeaders") { + found = true + break + } + } + require.True(t, found, "expected onHttpResponseHeaders log, got: %v", debugLogs) + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/mistral.go b/plugins/wasm-go/extensions/ai-proxy/test/mistral.go new file mode 100644 index 00000000..72bf60d9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/mistral.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicMistralConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "mistral", + "apiTokens": []string{"mistral-test-key"}, + "modelMapping": map[string]string{ + "*": "mistral-small-latest", + }, + }) +}() + +var invalidMistralConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "mistral", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "mistral-small-latest"}, + }) +}() + +// RunMistralParseConfigTests exercises Mistral plugin config loading. +func RunMistralParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic mistral config", func(t *testing.T) { + host, status := test.NewTestHost(basicMistralConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid mistral config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidMistralConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunMistralOnHttpRequestHeadersTests exercises Mistral request header transforms. +func RunMistralOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("mistral chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicMistralConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "api.mistral.ai", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer mistral-test-key") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "mistral") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or mistral debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/test/moonshot.go new file mode 100644 index 00000000..2286538a --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/moonshot.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicMoonshotConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "moonshot", + "apiTokens": []string{"sk-moonshot-test"}, + "modelMapping": map[string]string{ + "*": "moonshot-v1-8k", + }, + }) +}() + +var invalidMoonshotConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "moonshot", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "moonshot-v1-8k"}, + }) +}() + +// RunMoonshotParseConfigTests exercises Moonshot plugin config loading. +func RunMoonshotParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic moonshot config", func(t *testing.T) { + host, status := test.NewTestHost(basicMoonshotConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid moonshot config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidMoonshotConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunMoonshotOnHttpRequestHeadersTests exercises Moonshot request header transforms. +func RunMoonshotOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("moonshot chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicMoonshotConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "api.moonshot.cn", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer sk-moonshot-test") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "moonshot") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or moonshot debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/openai.go b/plugins/wasm-go/extensions/ai-proxy/test/openai.go index 9c5d0562..d67c5264 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/openai.go @@ -984,6 +984,10 @@ func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) { action4 := host.CallOnHttpStreamingResponseBody([]byte(chunk4), true) require.Equal(t, types.ActionContinue, action4) + // Empty chunk should not panic + actionEmpty := host.CallOnHttpStreamingResponseBody([]byte{}, false) + require.Equal(t, types.ActionContinue, actionEmpty) + // 验证流式响应处理 // 注意:流式响应可能不会在GetResponseBody中累积,需要检查日志或其他方式验证 debugLogs := host.GetDebugLogs() diff --git a/plugins/wasm-go/extensions/ai-proxy/test/provider_wasm_smoke.go b/plugins/wasm-go/extensions/ai-proxy/test/provider_wasm_smoke.go new file mode 100644 index 00000000..0cb4faa6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/provider_wasm_smoke.go @@ -0,0 +1,230 @@ +// Package test contains Wasm smoke tests for additional AI providers (legacy top-level +// "provider" JSON, plugin start, request headers, and minimal body where useful). This file +// complements per-vendor files such as openai.go and deepseek.go. +package test + +import ( + "encoding/json" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + wasmhost "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +func providerSmokeLegacyJSON(m map[string]interface{}) json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{"provider": m}) + return data +} + +func RunBaichuanWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{ + "type": "baichuan", "apiTokens": []string{"sk-bc"}, "modelMapping": map[string]string{"*": "bc-model"}, + }) + t.Run("parse", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + }) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + act := h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, act) + auth, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), "Authorization") + require.Contains(t, auth, "Bearer") + }) + }) +} + +func RunYiWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "yi", "apiTokens": []string{"sk-yi"}}) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + host, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":authority") + require.Contains(t, host, "lingyi") + }) + }) +} + +func RunOllamaWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{ + "type": "ollama", "ollamaServerHost": "127.0.0.1", "ollamaServerPort": 11434, "apiTokens": []string{"x"}, + }) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + host, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":authority") + require.Contains(t, host, "127.0.0.1") + }) + }) +} + +func RunBaiduWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "baidu", "apiTokens": []string{"sk-bd"}}) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + path, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":path") + require.Contains(t, path, "/v2/") + }) + }) +} + +func RunHunyuanWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "hunyuan", "apiTokens": []string{"tok"}}) + t.Run("parse", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + }) + }) +} + +func RunStepfunWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "stepfun", "apiTokens": []string{"sk-sf"}}) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + host, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":authority") + require.Contains(t, host, "stepfun") + }) + }) +} + +func RunCloudflareWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{ + "type": "cloudflare", "apiTokens": []string{"cf"}, "cloudflareAccountId": "acc1", + }) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + path, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":path") + require.Contains(t, path, "acc1") + }) + }) +} + +func RunDeeplWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{ + "type": "deepl", "apiTokens": []string{"k"}, "targetLang": "EN", + }) + t.Run("parse", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + }) + }) +} + +func RunCohereWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "cohere", "apiTokens": []string{"ck"}}) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + host, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":authority") + require.Contains(t, host, "cohere") + }) + }) +} + +func RunCozeWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "coze", "apiTokens": []string{"cz"}}) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + host, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":authority") + require.Contains(t, host, "coze") + }) + }) +} + +func RunDifyWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{ + "type": "dify", "apiTokens": []string{"d"}, "botType": "Chat", + }) + t.Run("headers", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + path, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":path") + require.Contains(t, path, "chat-messages") + }) + }) +} + +func RunTritonWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{ + "type": "triton", "apiTokens": []string{"t"}, "tritonDomain": "triton.example.com", + }) + t.Run("headers_and_body", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + _ = h.CallOnHttpRequestHeaders([][2]string{ + {":authority", "ex.com"}, {":path", "/v1/chat/completions"}, {":method", "POST"}, {"Content-Type", "application/json"}, + }) + _ = h.CallOnHttpRequestBody([]byte(`{"model":"m1","messages":[{"role":"user","content":"hi"}]}`)) + path, _ := wasmhost.GetHeaderValue(h.GetRequestHeaders(), ":path") + require.Contains(t, path, "m1") + }) + }) +} + +func RunVllmWasmSmokeTests(t *testing.T) { + wasmhost.RunTest(t, func(t *testing.T) { + cfg := providerSmokeLegacyJSON(map[string]interface{}{"type": "vllm"}) + t.Run("parse", func(t *testing.T) { + h, st := wasmhost.NewTestHost(cfg) + defer h.Reset() + require.Equal(t, types.OnPluginStartStatusOK, st) + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/spark.go b/plugins/wasm-go/extensions/ai-proxy/test/spark.go new file mode 100644 index 00000000..333ad89b --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/spark.go @@ -0,0 +1,79 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicSparkConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "spark", + "apiTokens": []string{"spark-test-token"}, + "modelMapping": map[string]string{ + "*": "generalv3.5", + }, + }) +}() + +// RunSparkParseConfigTests exercises Spark plugin config loading (ValidateConfig is a no-op). +func RunSparkParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic spark config", func(t *testing.T) { + host, status := test.NewTestHost(basicSparkConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) +} + +// RunSparkOnHttpRequestHeadersTests exercises Spark request header transforms. +func RunSparkOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("spark chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicSparkConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "spark-api-open.xf-yun.com", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer spark-test-token") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "spark") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or spark debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/together_ai.go b/plugins/wasm-go/extensions/ai-proxy/test/together_ai.go new file mode 100644 index 00000000..c0e2499c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/together_ai.go @@ -0,0 +1,92 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +var basicTogetherAIConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "together-ai", + "apiTokens": []string{"together-test-key"}, + "modelMapping": map[string]string{ + "*": "meta-llama/Llama-3-8b-chat-hf", + }, + }) +}() + +var invalidTogetherAIConfig = func() json.RawMessage { + return LegacyProviderPluginJSON(map[string]interface{}{ + "type": "together-ai", + "apiTokens": []string{}, + "modelMapping": map[string]string{"*": "meta-llama/Llama-3-8b-chat-hf"}, + }) +}() + +// RunTogetherAIParseConfigTests exercises Together AI plugin config loading. +func RunTogetherAIParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("basic together-ai config", func(t *testing.T) { + host, status := test.NewTestHost(basicTogetherAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + t.Run("invalid together-ai config missing apiToken", func(t *testing.T) { + host, status := test.NewTestHost(invalidTogetherAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) + }) +} + +// RunTogetherAIOnHttpRequestHeadersTests exercises Together AI request header transforms. +func RunTogetherAIOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("together-ai chat completions headers", func(t *testing.T) { + host, status := test.NewTestHost(basicTogetherAIConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + hostValue, ok := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, ok) + require.Equal(t, "api.together.xyz", hostValue) + + authValue, ok := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, ok) + require.Contains(t, authValue, "Bearer together-test-key") + + pathValue, ok := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, ok) + require.Equal(t, "/v1/chat/completions", pathValue) + + debugLogs := host.GetDebugLogs() + found := false + for _, log := range debugLogs { + if strings.Contains(log, "together") || strings.Contains(log, "ai-proxy") { + found = true + break + } + } + require.True(t, found, "expected ai-proxy or together debug logs") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/util.go b/plugins/wasm-go/extensions/ai-proxy/test/util.go index 9bd40990..bfa0e720 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/util.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/util.go @@ -1,11 +1,19 @@ package test import ( + "encoding/json" "testing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" ) +// LegacyProviderPluginJSON builds the top-level plugin JSON with a single legacy "provider" +// object, matching historical wasm integration tests (see test/openai.go). +func LegacyProviderPluginJSON(provider map[string]interface{}) json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{"provider": provider}) + return json.RawMessage(data) +} + func RunMapRequestPathByCapabilityTests(t *testing.T) { testCases := []struct { name string diff --git a/plugins/wasm-go/extensions/ai-proxy/util/header_slice_test.go b/plugins/wasm-go/extensions/ai-proxy/util/header_slice_test.go new file mode 100644 index 00000000..765cb8b6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/util/header_slice_test.go @@ -0,0 +1,39 @@ +package util + +import ( + "net/http" + "reflect" + "testing" +) + +func TestCreateHeaders(t *testing.T) { + h := CreateHeaders("Content-Type", "application/json", ":status", "200") + if len(h) != 2 { + t.Fatalf("len=%d", len(h)) + } + if h[0][0] != "Content-Type" || h[0][1] != "application/json" { + t.Fatalf("first pair: %v", h[0]) + } +} + +func TestHeaderToSliceAndSliceToHeader_roundTrip(t *testing.T) { + src := make(http.Header) + src.Set("A", "1") + src.Add("A", "2") + src.Set("B", "3") + + slice := HeaderToSlice(src) + round := SliceToHeader(slice) + + if !reflect.DeepEqual(src["A"], round["A"]) || !reflect.DeepEqual(src["B"], round["B"]) { + t.Fatalf("roundTrip mismatch: %#v vs %#v", src, round) + } +} + +func TestOverwriteRequestPathHeader(t *testing.T) { + h := make(http.Header) + OverwriteRequestPathHeader(h, "/v1/chat/completions") + if h.Get(":path") != "/v1/chat/completions" { + t.Fatalf("path=%q", h.Get(":path")) + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/util/string_test.go b/plugins/wasm-go/extensions/ai-proxy/util/string_test.go index 4042baf0..8ec78704 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/string_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/string_test.go @@ -6,6 +6,33 @@ import ( "github.com/stretchr/testify/assert" ) +func TestMatchStatus(t *testing.T) { + defaultRetryPatterns := []string{"4.*", "5.*"} + tests := []struct { + name string + status string + patterns []string + want bool + }{ + {"200_no_match", "200", defaultRetryPatterns, false}, + {"201_no_match", "201", defaultRetryPatterns, false}, + {"429_matches_4xx", "429", defaultRetryPatterns, true}, + {"400_matches_4xx", "400", defaultRetryPatterns, true}, + {"503_matches_5xx", "503", defaultRetryPatterns, true}, + {"500_matches_5xx", "500", defaultRetryPatterns, true}, + {"exact_503_pattern", "503", []string{"503"}, true}, + {"exact_503_miss", "502", []string{"503"}, false}, + {"empty_patterns", "500", []string{}, false}, + {"empty_status", "", defaultRetryPatterns, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchStatus(tt.status, tt.patterns) + assert.Equal(t, tt.want, got) + }) + } +} + func TestDecodeUnicodeEscapes(t *testing.T) { tests := []struct { name string