mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
feat(ai-proxy): add anthropic && gemini apiName (#2551)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
@@ -342,28 +342,28 @@ func checkStream(ctx wrapper.HttpContext) {
|
||||
|
||||
func getApiName(path string) provider.ApiName {
|
||||
// openai style
|
||||
if strings.HasSuffix(path, "/v1/chat/completions") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIChatCompletions) {
|
||||
return provider.ApiNameChatCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/completions") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAICompletions) {
|
||||
return provider.ApiNameCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/embeddings") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIEmbeddings) {
|
||||
return provider.ApiNameEmbeddings
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/audio/speech") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIAudioSpeech) {
|
||||
return provider.ApiNameAudioSpeech
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/generations") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIImageGeneration) {
|
||||
return provider.ApiNameImageGeneration
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/variations") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIImageVariation) {
|
||||
return provider.ApiNameImageVariation
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/edits") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIImageEdit) {
|
||||
return provider.ApiNameImageEdit
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/batches") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIBatches) {
|
||||
return provider.ApiNameBatches
|
||||
}
|
||||
if util.RegRetrieveBatchPath.MatchString(path) {
|
||||
@@ -372,7 +372,7 @@ func getApiName(path string) provider.ApiName {
|
||||
if util.RegCancelBatchPath.MatchString(path) {
|
||||
return provider.ApiNameCancelBatch
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/files") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIFiles) {
|
||||
return provider.ApiNameFiles
|
||||
}
|
||||
if util.RegRetrieveFilePath.MatchString(path) {
|
||||
@@ -381,10 +381,10 @@ func getApiName(path string) provider.ApiName {
|
||||
if util.RegRetrieveFileContentPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFileContent
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/models") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIModels) {
|
||||
return provider.ApiNameModels
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/fine_tuning/jobs") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIFineTuningJobs) {
|
||||
return provider.ApiNameFineTuningJobs
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobPath.MatchString(path) {
|
||||
@@ -411,11 +411,28 @@ func getApiName(path string) provider.ApiName {
|
||||
if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) {
|
||||
return provider.ApiNameDeleteFineTuningCheckpointPermission
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/responses") {
|
||||
if strings.HasSuffix(path, provider.PathOpenAIResponses) {
|
||||
return provider.ApiNameResponses
|
||||
}
|
||||
|
||||
// Anthropic
|
||||
if strings.HasSuffix(path, provider.PathAnthropicMessages) {
|
||||
return provider.ApiNameAnthropicMessages
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathAnthropicComplete) {
|
||||
return provider.ApiNameAnthropicComplete
|
||||
}
|
||||
|
||||
// Gemini
|
||||
if util.RegGeminiGenerateContent.MatchString(path) {
|
||||
return provider.ApiNameGeminiGenerateContent
|
||||
}
|
||||
if util.RegGeminiStreamGenerateContent.MatchString(path) {
|
||||
return provider.ApiNameGeminiStreamGenerateContent
|
||||
}
|
||||
|
||||
// cohere style
|
||||
if strings.HasSuffix(path, "/v1/rerank") {
|
||||
if strings.HasSuffix(path, provider.PathCohereV1Rerank) {
|
||||
return provider.ApiNameCohereV1Rerank
|
||||
}
|
||||
return ""
|
||||
|
||||
@@ -9,18 +9,16 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// claudeProvider is the provider for Claude service.
|
||||
const (
|
||||
claudeDomain = "api.anthropic.com"
|
||||
claudeChatCompletionPath = "/v1/messages"
|
||||
claudeCompletionPath = "/v1/complete"
|
||||
claudeDefaultVersion = "2023-06-01"
|
||||
claudeDefaultMaxTokens = 4096
|
||||
claudeDomain = "api.anthropic.com"
|
||||
claudeDefaultVersion = "2023-06-01"
|
||||
claudeDefaultMaxTokens = 4096
|
||||
)
|
||||
|
||||
type claudeProviderInitializer struct{}
|
||||
@@ -123,8 +121,8 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): claudeChatCompletionPath,
|
||||
string(ApiNameCompletion): claudeCompletionPath,
|
||||
string(ApiNameChatCompletion): PathAnthropicMessages,
|
||||
string(ApiNameCompletion): PathAnthropicComplete,
|
||||
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
string(ApiNameModels): PathOpenAIModels,
|
||||
@@ -461,10 +459,10 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
|
||||
}
|
||||
|
||||
func (c *claudeProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, claudeChatCompletionPath) {
|
||||
if strings.Contains(path, PathAnthropicMessages) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, claudeCompletionPath) {
|
||||
if strings.Contains(path, PathAnthropicComplete) {
|
||||
return ApiNameCompletion
|
||||
}
|
||||
if strings.Contains(path, PathOpenAIModels) {
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
// geminiProvider is the provider for google gemini/gemini flash service.
|
||||
@@ -39,10 +39,12 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): "",
|
||||
string(ApiNameEmbeddings): "",
|
||||
string(ApiNameModels): "",
|
||||
string(ApiNameImageGeneration): "",
|
||||
string(ApiNameChatCompletion): "",
|
||||
string(ApiNameEmbeddings): "",
|
||||
string(ApiNameModels): "",
|
||||
string(ApiNameImageGeneration): "",
|
||||
string(ApiNameGeminiGenerateContent): "",
|
||||
string(ApiNameGeminiStreamGenerateContent): "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +93,7 @@ func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
case ApiNameImageGeneration:
|
||||
return g.onImageGenerationRequestBody(ctx, body, headers)
|
||||
}
|
||||
log.Debugf("TransformRequestBodyHeaders apiName:%s", apiName)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -259,6 +262,10 @@ func (g *geminiProvider) getRequestPath(apiName ApiName, model string, stream bo
|
||||
}
|
||||
case ApiNameImageGeneration:
|
||||
action = geminiImageGenerationPath
|
||||
case ApiNameGeminiGenerateContent:
|
||||
action = geminiChatCompletionPath
|
||||
case ApiNameGeminiStreamGenerateContent:
|
||||
action = geminiChatCompletionStreamPath
|
||||
}
|
||||
return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action)
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -54,6 +54,17 @@ const (
|
||||
ApiNameFineTuningCheckpointPermissions ApiName = "openai/v1/fine-tuningjobcheckpointpermissions"
|
||||
ApiNameDeleteFineTuningCheckpointPermission ApiName = "openai/v1/deletefine-tuningjobcheckpointpermission"
|
||||
|
||||
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||
ApiNameQwenAsyncAIGC ApiName = "qwen/v1/services/aigc"
|
||||
ApiNameQwenAsyncTask ApiName = "qwen/v1/tasks"
|
||||
ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank"
|
||||
ApiNameGeminiGenerateContent ApiName = "gemini/v1beta/generatecontent"
|
||||
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
|
||||
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
|
||||
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
||||
|
||||
// OpenAI
|
||||
PathOpenAICompletions = "/v1/completions"
|
||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||
@@ -79,11 +90,12 @@ const (
|
||||
PathOpenAIFineTuningCheckpointPermissions = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions"
|
||||
PathOpenAIFineDeleteTuningCheckpointPermission = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions/{permission_id}"
|
||||
|
||||
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||
ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank"
|
||||
ApiNameQwenAsyncAIGC ApiName = "api/v1/services/aigc"
|
||||
ApiNameQwenAsyncTask ApiName = "api/v1/tasks/"
|
||||
// Anthropic
|
||||
PathAnthropicMessages = "/v1/messages"
|
||||
PathAnthropicComplete = "/v1/complete"
|
||||
|
||||
// Cohere
|
||||
PathCohereV1Rerank = "/v1/rerank"
|
||||
|
||||
providerTypeMoonshot = "moonshot"
|
||||
providerTypeAzure = "azure"
|
||||
@@ -901,7 +913,10 @@ func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
|
||||
ApiNameImageVariation,
|
||||
ApiNameAudioSpeech,
|
||||
ApiNameFineTuningJobs,
|
||||
ApiNameResponses:
|
||||
ApiNameResponses,
|
||||
ApiNameGeminiGenerateContent,
|
||||
ApiNameGeminiStreamGenerateContent,
|
||||
ApiNameAnthropicMessages:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -29,6 +29,8 @@ var (
|
||||
RegPauseFineTuningJobPath = regexp.MustCompile(`^.*/v1/fine_tuning/jobs/(?P<fine_tuning_job_id>[^/]+)/pause$`)
|
||||
RegFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions$`)
|
||||
RegDeleteFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions/(?P<permission_id>[^/]+)$`)
|
||||
RegGeminiGenerateContent = regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):generateContent`)
|
||||
RegGeminiStreamGenerateContent = regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):streamGenerateContent`)
|
||||
)
|
||||
|
||||
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
|
||||
|
||||
Reference in New Issue
Block a user