From c9fa8d15dbdbbfcbdcfafd0b6256552887ebef00 Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Mon, 18 Aug 2025 19:05:23 +0800 Subject: [PATCH] chore: Restructure the path-to-api-name mapping logic in ai-proxy (#2773) --- plugins/wasm-go/extensions/ai-proxy/main.go | 144 +++++++----------- .../wasm-go/extensions/ai-proxy/main_test.go | 59 +++++++ 2 files changed, 113 insertions(+), 90 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/main_test.go diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 43961e747..512eda775 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -6,6 +6,7 @@ package main import ( "fmt" "net/url" + "regexp" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config" @@ -31,6 +32,11 @@ const ( ctxOriginalAuth = "original_auth" ) +type pair[K, V any] struct { + key K + value V +} + var ( headersCtxKeyMapping = map[string]string{ util.HeaderAuthority: ctxOriginalHost, @@ -42,6 +48,44 @@ var ( util.HeaderPath: util.HeaderOriginalPath, util.HeaderAuthorization: util.HeaderOriginalAuth, } + pathSuffixToApiName = []pair[string, provider.ApiName]{ + // OpenAI style + {provider.PathOpenAIChatCompletions, provider.ApiNameChatCompletion}, + {provider.PathOpenAICompletions, provider.ApiNameCompletion}, + {provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings}, + {provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech}, + {provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration}, + {provider.PathOpenAIImageVariation, provider.ApiNameImageVariation}, + {provider.PathOpenAIImageEdit, provider.ApiNameImageEdit}, + {provider.PathOpenAIBatches, provider.ApiNameBatches}, + {provider.PathOpenAIFiles, provider.ApiNameFiles}, + {provider.PathOpenAIModels, provider.ApiNameModels}, + {provider.PathOpenAIFineTuningJobs, provider.ApiNameFineTuningJobs}, + {provider.PathOpenAIResponses, provider.ApiNameResponses}, + // Anthropic style + {provider.PathAnthropicMessages, provider.ApiNameAnthropicMessages}, + {provider.PathAnthropicComplete, provider.ApiNameAnthropicComplete}, + // Cohere style + {provider.PathCohereV1Rerank, provider.ApiNameCohereV1Rerank}, + } + pathPatternToApiName = []pair[*regexp.Regexp, provider.ApiName]{ + // OpenAI style + {util.RegRetrieveBatchPath, provider.ApiNameRetrieveBatch}, + {util.RegCancelBatchPath, provider.ApiNameCancelBatch}, + {util.RegRetrieveFilePath, provider.ApiNameRetrieveFile}, + {util.RegRetrieveFileContentPath, provider.ApiNameRetrieveFileContent}, + {util.RegRetrieveFineTuningJobPath, provider.ApiNameRetrieveFineTuningJob}, + {util.RegRetrieveFineTuningJobEventsPath, provider.ApiNameFineTuningJobEvents}, + {util.RegRetrieveFineTuningJobCheckpointsPath, provider.ApiNameFineTuningJobCheckpoints}, + {util.RegCancelFineTuningJobPath, provider.ApiNameCancelFineTuningJob}, + {util.RegResumeFineTuningJobPath, provider.ApiNameResumeFineTuningJob}, + {util.RegPauseFineTuningJobPath, provider.ApiNamePauseFineTuningJob}, + {util.RegFineTuningCheckpointPermissionPath, provider.ApiNameFineTuningCheckpointPermissions}, + {util.RegDeleteFineTuningCheckpointPermissionPath, provider.ApiNameDeleteFineTuningCheckpointPermission}, + // Gemini style + {util.RegGeminiGenerateContent, provider.ApiNameGeminiGenerateContent}, + {util.RegGeminiStreamGenerateContent, provider.ApiNameGeminiStreamGenerateContent}, + } ) func main() {} @@ -397,99 +441,19 @@ func checkStream(ctx wrapper.HttpContext) { } func getApiName(path string) provider.ApiName { - // openai style - if strings.HasSuffix(path, provider.PathOpenAIChatCompletions) { - return provider.ApiNameChatCompletion - } - if strings.HasSuffix(path, provider.PathOpenAICompletions) { - return provider.ApiNameCompletion - } - if strings.HasSuffix(path, provider.PathOpenAIEmbeddings) { - return provider.ApiNameEmbeddings - } - if strings.HasSuffix(path, provider.PathOpenAIAudioSpeech) { - return provider.ApiNameAudioSpeech - } - if strings.HasSuffix(path, provider.PathOpenAIImageGeneration) { - return provider.ApiNameImageGeneration - } - if strings.HasSuffix(path, provider.PathOpenAIImageVariation) { - return provider.ApiNameImageVariation - } - if strings.HasSuffix(path, provider.PathOpenAIImageEdit) { - return provider.ApiNameImageEdit - } - if strings.HasSuffix(path, provider.PathOpenAIBatches) { - return provider.ApiNameBatches - } - if util.RegRetrieveBatchPath.MatchString(path) { - return provider.ApiNameRetrieveBatch - } - if util.RegCancelBatchPath.MatchString(path) { - return provider.ApiNameCancelBatch - } - if strings.HasSuffix(path, provider.PathOpenAIFiles) { - return provider.ApiNameFiles - } - if util.RegRetrieveFilePath.MatchString(path) { - return provider.ApiNameRetrieveFile - } - if util.RegRetrieveFileContentPath.MatchString(path) { - return provider.ApiNameRetrieveFileContent - } - if strings.HasSuffix(path, provider.PathOpenAIModels) { - return provider.ApiNameModels - } - if strings.HasSuffix(path, provider.PathOpenAIFineTuningJobs) { - return provider.ApiNameFineTuningJobs - } - if util.RegRetrieveFineTuningJobPath.MatchString(path) { - return provider.ApiNameRetrieveFineTuningJob - } - if util.RegRetrieveFineTuningJobEventsPath.MatchString(path) { - return provider.ApiNameFineTuningJobEvents - } - if util.RegRetrieveFineTuningJobCheckpointsPath.MatchString(path) { - return provider.ApiNameFineTuningJobCheckpoints - } - if util.RegCancelFineTuningJobPath.MatchString(path) { - return provider.ApiNameCancelFineTuningJob - } - if util.RegResumeFineTuningJobPath.MatchString(path) { - return provider.ApiNameResumeFineTuningJob - } - if util.RegPauseFineTuningJobPath.MatchString(path) { - return provider.ApiNamePauseFineTuningJob - } - if util.RegFineTuningCheckpointPermissionPath.MatchString(path) { - return provider.ApiNameFineTuningCheckpointPermissions - } - if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) { - return provider.ApiNameDeleteFineTuningCheckpointPermission - } - if strings.HasSuffix(path, provider.PathOpenAIResponses) { - return provider.ApiNameResponses + // Check path suffix matches first + for _, p := range pathSuffixToApiName { + if strings.HasSuffix(path, p.key) { + return p.value + } } - // Anthropic - if strings.HasSuffix(path, provider.PathAnthropicMessages) { - return provider.ApiNameAnthropicMessages - } - if strings.HasSuffix(path, provider.PathAnthropicComplete) { - return provider.ApiNameAnthropicComplete + // Check path pattern matches + for _, p := range pathPatternToApiName { + if p.key.MatchString(path) { + return p.value + } } - // Gemini - if util.RegGeminiGenerateContent.MatchString(path) { - return provider.ApiNameGeminiGenerateContent - } - if util.RegGeminiStreamGenerateContent.MatchString(path) { - return provider.ApiNameGeminiStreamGenerateContent - } - - // cohere style - if strings.HasSuffix(path, provider.PathCohereV1Rerank) { - return provider.ApiNameCohereV1Rerank - } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go new file mode 100644 index 000000000..26bf07846 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -0,0 +1,59 @@ +package main + +import ( + "testing" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" +) + +func Test_getApiName(t *testing.T) { + tests := []struct { + name string + path string + want provider.ApiName + }{ + // OpenAI style + {"openai chat completions", "/v1/chat/completions", provider.ApiNameChatCompletion}, + {"openai completions", "/v1/completions", provider.ApiNameCompletion}, + {"openai embeddings", "/v1/embeddings", provider.ApiNameEmbeddings}, + {"openai audio speech", "/v1/audio/speech", provider.ApiNameAudioSpeech}, + {"openai image generation", "/v1/images/generations", provider.ApiNameImageGeneration}, + {"openai image variation", "/v1/images/variations", provider.ApiNameImageVariation}, + {"openai image edit", "/v1/images/edits", provider.ApiNameImageEdit}, + {"openai batches", "/v1/batches", provider.ApiNameBatches}, + {"openai retrieve batch", "/v1/batches/batchid", provider.ApiNameRetrieveBatch}, + {"openai cancel batch", "/v1/batches/batchid/cancel", provider.ApiNameCancelBatch}, + {"openai files", "/v1/files", provider.ApiNameFiles}, + {"openai retrieve file", "/v1/files/fileid", provider.ApiNameRetrieveFile}, + {"openai retrieve file content", "/v1/files/fileid/content", provider.ApiNameRetrieveFileContent}, + {"openai models", "/v1/models", provider.ApiNameModels}, + {"openai fine tuning jobs", "/v1/fine_tuning/jobs", provider.ApiNameFineTuningJobs}, + {"openai retrieve fine tuning job", "/v1/fine_tuning/jobs/jobid", provider.ApiNameRetrieveFineTuningJob}, + {"openai fine tuning job events", "/v1/fine_tuning/jobs/jobid/events", provider.ApiNameFineTuningJobEvents}, + {"openai fine tuning job checkpoints", "/v1/fine_tuning/jobs/jobid/checkpoints", provider.ApiNameFineTuningJobCheckpoints}, + {"openai cancel fine tuning job", "/v1/fine_tuning/jobs/jobid/cancel", provider.ApiNameCancelFineTuningJob}, + {"openai resume fine tuning job", "/v1/fine_tuning/jobs/jobid/resume", provider.ApiNameResumeFineTuningJob}, + {"openai pause fine tuning job", "/v1/fine_tuning/jobs/jobid/pause", provider.ApiNamePauseFineTuningJob}, + {"openai fine tuning checkpoint permissions", "/v1/fine_tuning/checkpoints/checkpointid/permissions", provider.ApiNameFineTuningCheckpointPermissions}, + {"openai delete fine tuning checkpoint permission", "/v1/fine_tuning/checkpoints/checkpointid/permissions/permissionid", provider.ApiNameDeleteFineTuningCheckpointPermission}, + {"openai responses", "/v1/responses", provider.ApiNameResponses}, + // Anthropic + {"anthropic messages", "/v1/messages", provider.ApiNameAnthropicMessages}, + {"anthropic complete", "/v1/complete", provider.ApiNameAnthropicComplete}, + // Gemini + {"gemini generate content", "/v1beta/models/gemini-1.0-pro:generateContent", provider.ApiNameGeminiGenerateContent}, + {"gemini stream generate content", "/v1beta/models/gemini-1.0-pro:streamGenerateContent", provider.ApiNameGeminiStreamGenerateContent}, + // Cohere + {"cohere rerank", "/v1/rerank", provider.ApiNameCohereV1Rerank}, + // Unknown + {"unknown", "/v1/unknown", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getApiName(tt.path) + if got != tt.want { + t.Errorf("getApiName(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +}