From c3eb8d0447ffccec55dc9770f314263487042aaf Mon Sep 17 00:00:00 2001 From: Xijun Dai Date: Tue, 15 Jul 2025 19:15:07 +0800 Subject: [PATCH] feat(ai-proxy): add anthropic && gemini apiName (#2551) Signed-off-by: Xijun Dai --- plugins/wasm-go/extensions/ai-proxy/main.go | 43 +++++++++++++------ .../extensions/ai-proxy/provider/claude.go | 18 ++++---- .../extensions/ai-proxy/provider/gemini.go | 19 +++++--- .../extensions/ai-proxy/provider/provider.go | 31 +++++++++---- .../wasm-go/extensions/ai-proxy/util/http.go | 2 + 5 files changed, 76 insertions(+), 37 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 7ab71956a..fe3a539e7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -342,28 +342,28 @@ func checkStream(ctx wrapper.HttpContext) { func getApiName(path string) provider.ApiName { // openai style - if strings.HasSuffix(path, "/v1/chat/completions") { + if strings.HasSuffix(path, provider.PathOpenAIChatCompletions) { return provider.ApiNameChatCompletion } - if strings.HasSuffix(path, "/v1/completions") { + if strings.HasSuffix(path, provider.PathOpenAICompletions) { return provider.ApiNameCompletion } - if strings.HasSuffix(path, "/v1/embeddings") { + if strings.HasSuffix(path, provider.PathOpenAIEmbeddings) { return provider.ApiNameEmbeddings } - if strings.HasSuffix(path, "/v1/audio/speech") { + if strings.HasSuffix(path, provider.PathOpenAIAudioSpeech) { return provider.ApiNameAudioSpeech } - if strings.HasSuffix(path, "/v1/images/generations") { + if strings.HasSuffix(path, provider.PathOpenAIImageGeneration) { return provider.ApiNameImageGeneration } - if strings.HasSuffix(path, "/v1/images/variations") { + if strings.HasSuffix(path, provider.PathOpenAIImageVariation) { return provider.ApiNameImageVariation } - if strings.HasSuffix(path, "/v1/images/edits") { + if strings.HasSuffix(path, provider.PathOpenAIImageEdit) { return 
provider.ApiNameImageEdit } - if strings.HasSuffix(path, "/v1/batches") { + if strings.HasSuffix(path, provider.PathOpenAIBatches) { return provider.ApiNameBatches } if util.RegRetrieveBatchPath.MatchString(path) { @@ -372,7 +372,7 @@ func getApiName(path string) provider.ApiName { if util.RegCancelBatchPath.MatchString(path) { return provider.ApiNameCancelBatch } - if strings.HasSuffix(path, "/v1/files") { + if strings.HasSuffix(path, provider.PathOpenAIFiles) { return provider.ApiNameFiles } if util.RegRetrieveFilePath.MatchString(path) { @@ -381,10 +381,10 @@ func getApiName(path string) provider.ApiName { if util.RegRetrieveFileContentPath.MatchString(path) { return provider.ApiNameRetrieveFileContent } - if strings.HasSuffix(path, "/v1/models") { + if strings.HasSuffix(path, provider.PathOpenAIModels) { return provider.ApiNameModels } - if strings.HasSuffix(path, "/v1/fine_tuning/jobs") { + if strings.HasSuffix(path, provider.PathOpenAIFineTuningJobs) { return provider.ApiNameFineTuningJobs } if util.RegRetrieveFineTuningJobPath.MatchString(path) { @@ -411,11 +411,28 @@ func getApiName(path string) provider.ApiName { if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) { return provider.ApiNameDeleteFineTuningCheckpointPermission } - if strings.HasSuffix(path, "/v1/responses") { + if strings.HasSuffix(path, provider.PathOpenAIResponses) { return provider.ApiNameResponses } + + // Anthropic + if strings.HasSuffix(path, provider.PathAnthropicMessages) { + return provider.ApiNameAnthropicMessages + } + if strings.HasSuffix(path, provider.PathAnthropicComplete) { + return provider.ApiNameAnthropicComplete + } + + // Gemini + if util.RegGeminiGenerateContent.MatchString(path) { + return provider.ApiNameGeminiGenerateContent + } + if util.RegGeminiStreamGenerateContent.MatchString(path) { + return provider.ApiNameGeminiStreamGenerateContent + } + // cohere style - if strings.HasSuffix(path, "/v1/rerank") { + if strings.HasSuffix(path, 
provider.PathCohereV1Rerank) { return provider.ApiNameCohereV1Rerank } return "" diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index a324fe85e..0842e8529 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -9,18 +9,16 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) // claudeProvider is the provider for Claude service. const ( - claudeDomain = "api.anthropic.com" - claudeChatCompletionPath = "/v1/messages" - claudeCompletionPath = "/v1/complete" - claudeDefaultVersion = "2023-06-01" - claudeDefaultMaxTokens = 4096 + claudeDomain = "api.anthropic.com" + claudeDefaultVersion = "2023-06-01" + claudeDefaultMaxTokens = 4096 ) type claudeProviderInitializer struct{} @@ -123,8 +121,8 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - string(ApiNameChatCompletion): claudeChatCompletionPath, - string(ApiNameCompletion): claudeCompletionPath, + string(ApiNameChatCompletion): PathAnthropicMessages, + string(ApiNameCompletion): PathAnthropicComplete, // docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api string(ApiNameEmbeddings): PathOpenAIEmbeddings, string(ApiNameModels): PathOpenAIModels, @@ -461,10 +459,10 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o } func (c *claudeProvider) GetApiName(path string) ApiName { - if strings.Contains(path, claudeChatCompletionPath) { + if strings.Contains(path, PathAnthropicMessages) { return ApiNameChatCompletion } - if 
strings.Contains(path, claudeCompletionPath) { + if strings.Contains(path, PathAnthropicComplete) { return ApiNameCompletion } if strings.Contains(path, PathOpenAIModels) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 5bdc2f425..b3a9499f7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -9,10 +9,10 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" - "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/google/uuid" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" ) // geminiProvider is the provider for google gemini/gemini flash service. @@ -39,10 +39,12 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - string(ApiNameChatCompletion): "", - string(ApiNameEmbeddings): "", - string(ApiNameModels): "", - string(ApiNameImageGeneration): "", + string(ApiNameChatCompletion): "", + string(ApiNameEmbeddings): "", + string(ApiNameModels): "", + string(ApiNameImageGeneration): "", + string(ApiNameGeminiGenerateContent): "", + string(ApiNameGeminiStreamGenerateContent): "", } } @@ -91,6 +93,7 @@ func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap case ApiNameImageGeneration: return g.onImageGenerationRequestBody(ctx, body, headers) } + log.Debugf("TransformRequestBodyHeaders apiName:%s", apiName) return body, nil } @@ -259,6 +262,10 @@ func (g *geminiProvider) getRequestPath(apiName ApiName, model string, stream bo } case ApiNameImageGeneration: action = geminiImageGenerationPath + case ApiNameGeminiGenerateContent: + action = geminiChatCompletionPath + 
case ApiNameGeminiStreamGenerateContent: + action = geminiChatCompletionStreamPath } return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 48a485cea..90a419d46 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -10,10 +10,10 @@ import ( "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" - "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -54,6 +54,17 @@ const ( ApiNameFineTuningCheckpointPermissions ApiName = "openai/v1/fine-tuningjobcheckpointpermissions" ApiNameDeleteFineTuningCheckpointPermission ApiName = "openai/v1/deletefine-tuningjobcheckpointpermission" + // TODO: 以下是一些非标准的API名称,需要进一步确认是否支持 + ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank" + ApiNameQwenAsyncAIGC ApiName = "qwen/v1/services/aigc" + ApiNameQwenAsyncTask ApiName = "qwen/v1/tasks" + ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank" + ApiNameGeminiGenerateContent ApiName = "gemini/v1beta/generatecontent" + ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent" + ApiNameAnthropicMessages ApiName = "anthropic/v1/messages" + ApiNameAnthropicComplete ApiName = "anthropic/v1/complete" + + // OpenAI PathOpenAICompletions = "/v1/completions" PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIEmbeddings = "/v1/embeddings" @@ -79,11 +90,12 @@ const ( PathOpenAIFineTuningCheckpointPermissions = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions" 
PathOpenAIFineDeleteTuningCheckpointPermission = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions/{permission_id}" - // TODO: 以下是一些非标准的API名称,需要进一步确认是否支持 - ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank" - ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank" - ApiNameQwenAsyncAIGC ApiName = "api/v1/services/aigc" - ApiNameQwenAsyncTask ApiName = "api/v1/tasks/" + // Anthropic + PathAnthropicMessages = "/v1/messages" + PathAnthropicComplete = "/v1/complete" + + // Cohere + PathCohereV1Rerank = "/v1/rerank" providerTypeMoonshot = "moonshot" providerTypeAzure = "azure" @@ -901,7 +913,10 @@ func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool { ApiNameImageVariation, ApiNameAudioSpeech, ApiNameFineTuningJobs, - ApiNameResponses: + ApiNameResponses, + ApiNameGeminiGenerateContent, + ApiNameGeminiStreamGenerateContent, + ApiNameAnthropicMessages: return true } return false diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 041c2d8f3..37fda2ed5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -29,6 +29,8 @@ var ( RegPauseFineTuningJobPath = regexp.MustCompile(`^.*/v1/fine_tuning/jobs/(?P<fine_tuning_job_id>[^/]+)/pause$`) RegFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions$`) RegDeleteFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions/(?P<permission_id>[^/]+)$`) + RegGeminiGenerateContent = regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):generateContent`) + RegGeminiStreamGenerateContent = regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):streamGenerateContent`) ) type ErrorHandlerFunc func(statusCodeDetails string, err error) error