diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 6f05120b9..1d24c1e53 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -69,6 +69,7 @@ var ( {provider.PathOpenAIResponses, provider.ApiNameResponses}, {provider.PathOpenAIVideos, provider.ApiNameVideos}, // Anthropic style + {provider.PathAnthropicMessagesCountTokens, provider.ApiNameAnthropicCountTokens}, {provider.PathAnthropicMessages, provider.ApiNameAnthropicMessages}, {provider.PathAnthropicComplete, provider.ApiNameAnthropicComplete}, // Cohere style diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index bb1bbd75c..522a0ea40 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -54,6 +54,7 @@ func Test_getApiName(t *testing.T) { {"openai delete fine tuning checkpoint permission", "/v1/fine_tuning/checkpoints/checkpointid/permissions/permissionid", provider.ApiNameDeleteFineTuningCheckpointPermission}, {"openai responses", "/v1/responses", provider.ApiNameResponses}, // Anthropic + {"anthropic count_tokens", "/v1/messages/count_tokens", provider.ApiNameAnthropicCountTokens}, {"anthropic messages", "/v1/messages", provider.ApiNameAnthropicMessages}, {"anthropic complete", "/v1/complete", provider.ApiNameAnthropicComplete}, // Gemini diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index e530901f2..4d74895b2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -77,6 +77,7 @@ const ( ApiNameGeminiGenerateContent ApiName = "gemini/v1beta/generatecontent" ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent" ApiNameAnthropicMessages ApiName = "anthropic/v1/messages" + ApiNameAnthropicCountTokens ApiName = "anthropic/v1/messages/count_tokens" ApiNameAnthropicComplete ApiName = "anthropic/v1/complete" ApiNameVertexRaw ApiName = "vertex/raw" @@ -115,8 +116,9 @@ const ( PathOpenAIRetrieveVideoContent = "/v1/videos/{video_id}/content" // Anthropic - PathAnthropicMessages = "/v1/messages" - PathAnthropicComplete = "/v1/complete" + PathAnthropicMessages = "/v1/messages" + PathAnthropicMessagesCountTokens = "/v1/messages/count_tokens" + PathAnthropicComplete = "/v1/complete" // Cohere PathCohereV1Rerank = "/v1/rerank" @@ -1478,7 +1480,8 @@ func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool { ApiNameResponses, ApiNameGeminiGenerateContent, ApiNameGeminiStreamGenerateContent, - ApiNameAnthropicMessages: + ApiNameAnthropicMessages, + ApiNameAnthropicCountTokens: return true } return false diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go b/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go index 994e03930..7a10a6be4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go @@ -14,10 +14,20 @@ const ( defaultVllmDomain = "vllm-service.cluster.local" ) -// isVllmDirectPath checks if the path is a known standard vLLM interface path. +// isVllmDirectPath checks if the path is a known standard vLLM interface path, +// i.e. the configured vllmCustomUrl already points at a concrete endpoint rather +// than a base path. Such paths are forwarded as-is; base paths get the per-API +// suffix appended. Must cover every endpoint in DefaultCapabilities that a user +// might configure directly, otherwise the path is mistakenly treated as a base +// and double-appended (e.g. /v1/responses -> /v1/responses/responses). func isVllmDirectPath(path string) bool { return strings.HasSuffix(path, "/completions") || - strings.HasSuffix(path, "/rerank") + strings.HasSuffix(path, "/rerank") || + strings.HasSuffix(path, "/responses") || + strings.HasSuffix(path, "/messages") || + strings.HasSuffix(path, "/count_tokens") || + strings.HasSuffix(path, "/transcriptions") || + strings.HasSuffix(path, "/translations") } type vllmProviderInitializer struct{} @@ -36,6 +46,13 @@ func (m *vllmProviderInitializer) DefaultCapabilities() map[string]string { string(ApiNameModels): PathOpenAIModels, string(ApiNameEmbeddings): PathOpenAIEmbeddings, string(ApiNameCohereV1Rerank): PathCohereV1Rerank, + // vLLM also natively serves the Anthropic Messages API and newer OpenAI + // endpoints; expose them as passthrough (no protocol translation). + string(ApiNameAnthropicMessages): PathAnthropicMessages, + string(ApiNameAnthropicCountTokens): PathAnthropicMessagesCountTokens, + string(ApiNameResponses): PathOpenAIResponses, + string(ApiNameAudioTranscription): PathOpenAIAudioTranscriptions, + string(ApiNameAudioTranslation): PathOpenAIAudioTranslations, } } @@ -154,6 +171,22 @@ func (m *vllmProvider) GetApiName(path string) ApiName { if strings.Contains(path, PathCohereV1Rerank) { return ApiNameCohereV1Rerank } + // count_tokens must be matched before /v1/messages: the former contains the latter. + if strings.Contains(path, PathAnthropicMessagesCountTokens) { + return ApiNameAnthropicCountTokens + } + if strings.Contains(path, PathAnthropicMessages) { + return ApiNameAnthropicMessages + } + if strings.Contains(path, PathOpenAIResponses) { + return ApiNameResponses + } + if strings.Contains(path, PathOpenAIAudioTranscriptions) { + return ApiNameAudioTranscription + } + if strings.Contains(path, PathOpenAIAudioTranslations) { + return ApiNameAudioTranslation + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vllm_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/vllm_test.go new file mode 100644 index 000000000..aff4805e6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vllm_test.go @@ -0,0 +1,193 @@ +package provider + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVllmProviderInitializer_DefaultCapabilities(t *testing.T) { + initializer := &vllmProviderInitializer{} + + capabilities := initializer.DefaultCapabilities() + expected := map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameCompletion): PathOpenAICompletions, + string(ApiNameModels): PathOpenAIModels, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameCohereV1Rerank): PathCohereV1Rerank, + string(ApiNameAnthropicMessages): PathAnthropicMessages, + string(ApiNameAnthropicCountTokens): PathAnthropicMessagesCountTokens, + string(ApiNameResponses): PathOpenAIResponses, + string(ApiNameAudioTranscription): PathOpenAIAudioTranscriptions, + string(ApiNameAudioTranslation): PathOpenAIAudioTranslations, + } + + assert.Equal(t, expected, capabilities) +} + +func TestVllmProvider_GetApiName(t *testing.T) { + provider := &vllmProvider{} + + cases := []struct { + path string + expected ApiName + }{ + // existing (regression guard) + {PathOpenAIChatCompletions, ApiNameChatCompletion}, + {PathOpenAICompletions, ApiNameCompletion}, + {PathOpenAIModels, ApiNameModels}, + {PathOpenAIEmbeddings, ApiNameEmbeddings}, + {PathCohereV1Rerank, ApiNameCohereV1Rerank}, + // new passthrough endpoints + // count_tokens must be checked before /v1/messages (substring) — guards the ordering + {PathAnthropicMessagesCountTokens, ApiNameAnthropicCountTokens}, + {PathAnthropicMessages, ApiNameAnthropicMessages}, + {PathOpenAIResponses, ApiNameResponses}, + {PathOpenAIAudioTranscriptions, ApiNameAudioTranscription}, + {PathOpenAIAudioTranslations, ApiNameAudioTranslation}, + // unknown path + {"/v1/unknown", ApiName("")}, + } + + for _, c := range cases { + t.Run(c.path, func(t *testing.T) { + assert.Equal(t, c.expected, provider.GetApiName(c.path)) + }) + } +} + +func TestVllm_isVllmDirectPath(t *testing.T) { + cases := []struct { + path string + want bool + }{ + // existing direct endpoints + {"/v1/chat/completions", true}, + {"/v1/completions", true}, + {"/v1/rerank", true}, + // newly added passthrough endpoints + {"/v1/responses", true}, + {"/v1/messages", true}, + {"/v1/messages/count_tokens", true}, + {"/v1/audio/transcriptions", true}, + {"/v1/audio/translations", true}, + // base paths must NOT be treated as direct endpoints + {"/v1", false}, + {"/", false}, + {"/custom", false}, + } + + for _, c := range cases { + t.Run(c.path, func(t *testing.T) { + assert.Equal(t, c.want, isVllmDirectPath(c.path)) + }) + } +} + +// TestVllmProviderInitializer_CreateProvider_customUrl verifies vllmCustomUrl +// handling: a base path gets the per-API suffix appended, while a direct endpoint +// URL is forwarded as-is (no double-append such as /v1/responses/responses). +func TestVllmProviderInitializer_CreateProvider_customUrl(t *testing.T) { + initializer := &vllmProviderInitializer{} + + cases := []struct { + name string + customUrl string + wantDirect bool + wantPath string // expected customPath when direct + wantDomain string + capability ApiName // sample capability to check for base paths + wantCap string + }{ + { + name: "base path v1", + customUrl: "http://host:8000/v1", + wantDirect: false, + wantDomain: "host:8000", + capability: ApiNameResponses, + wantCap: "/v1/responses", + }, + { + name: "custom base path", + customUrl: "http://host:8000/custom", + wantDirect: false, + wantDomain: "host:8000", + capability: ApiNameAnthropicMessages, + wantCap: "/custom/messages", + }, + { + name: "direct responses endpoint", + customUrl: "http://host:8000/v1/responses", + wantDirect: true, + wantPath: "/v1/responses", + wantDomain: "host:8000", + }, + { + name: "direct anthropic messages endpoint", + customUrl: "http://host:8000/v1/messages", + wantDirect: true, + wantPath: "/v1/messages", + wantDomain: "host:8000", + }, + { + name: "direct count_tokens endpoint", + customUrl: "http://host:8000/v1/messages/count_tokens", + wantDirect: true, + wantPath: "/v1/messages/count_tokens", + wantDomain: "host:8000", + }, + { + name: "direct audio transcription endpoint", + customUrl: "http://host:8000/v1/audio/transcriptions", + wantDirect: true, + wantPath: "/v1/audio/transcriptions", + wantDomain: "host:8000", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p, err := initializer.CreateProvider(ProviderConfig{vllmCustomUrl: c.customUrl}) + assert.NoError(t, err) + vp, ok := p.(*vllmProvider) + assert.True(t, ok) + assert.Equal(t, c.wantDirect, vp.isDirectCustomPath) + assert.Equal(t, c.wantDomain, vp.customDomain) + if c.wantDirect { + assert.Equal(t, c.wantPath, vp.customPath) + } + if c.capability != "" { + assert.Equal(t, c.wantCap, vp.config.capabilities[string(c.capability)]) + } + }) + } +} + +// TestVllm_passthroughBodyAndSupport guards the body-handling and capability +// behaviour of the passthrough endpoints. +func TestVllm_passthroughBodyAndSupport(t *testing.T) { + cfg := &ProviderConfig{} + // Audio endpoints carry multipart/form-data bodies and must be passed through + // untouched (no JSON processing). + assert.False(t, cfg.needToProcessRequestBody(ApiNameAudioTranscription)) + assert.False(t, cfg.needToProcessRequestBody(ApiNameAudioTranslation)) + // Anthropic messages / count_tokens / responses carry JSON bodies (model + // mapping etc.), so they are processed. + assert.True(t, cfg.needToProcessRequestBody(ApiNameAnthropicMessages)) + assert.True(t, cfg.needToProcessRequestBody(ApiNameAnthropicCountTokens)) + assert.True(t, cfg.needToProcessRequestBody(ApiNameResponses)) + + // A vLLM provider declares the count_tokens capability and supports it. + vllmCfg := &ProviderConfig{} + vllmCfg.setDefaultCapabilities((&vllmProviderInitializer{}).DefaultCapabilities()) + assert.True(t, vllmCfg.isSupportedAPI(ApiNameAnthropicCountTokens)) + + // A provider that does not declare it (the path is now globally recognized) + // rejects the request via isSupportedAPI instead of mishandling it. + otherCfg := &ProviderConfig{} + otherCfg.setDefaultCapabilities(map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + }) + assert.False(t, otherCfg.isSupportedAPI(ApiNameAnthropicCountTokens)) +}