mirror of
https://github.com/alibaba/higress.git
synced 2026-04-21 20:17:29 +08:00
feat(ai-proxy): add anthropic && gemini apiName (#2551)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
@@ -342,28 +342,28 @@ func checkStream(ctx wrapper.HttpContext) {
|
|||||||
|
|
||||||
func getApiName(path string) provider.ApiName {
|
func getApiName(path string) provider.ApiName {
|
||||||
// openai style
|
// openai style
|
||||||
if strings.HasSuffix(path, "/v1/chat/completions") {
|
if strings.HasSuffix(path, provider.PathOpenAIChatCompletions) {
|
||||||
return provider.ApiNameChatCompletion
|
return provider.ApiNameChatCompletion
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/completions") {
|
if strings.HasSuffix(path, provider.PathOpenAICompletions) {
|
||||||
return provider.ApiNameCompletion
|
return provider.ApiNameCompletion
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/embeddings") {
|
if strings.HasSuffix(path, provider.PathOpenAIEmbeddings) {
|
||||||
return provider.ApiNameEmbeddings
|
return provider.ApiNameEmbeddings
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/audio/speech") {
|
if strings.HasSuffix(path, provider.PathOpenAIAudioSpeech) {
|
||||||
return provider.ApiNameAudioSpeech
|
return provider.ApiNameAudioSpeech
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/images/generations") {
|
if strings.HasSuffix(path, provider.PathOpenAIImageGeneration) {
|
||||||
return provider.ApiNameImageGeneration
|
return provider.ApiNameImageGeneration
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/images/variations") {
|
if strings.HasSuffix(path, provider.PathOpenAIImageVariation) {
|
||||||
return provider.ApiNameImageVariation
|
return provider.ApiNameImageVariation
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/images/edits") {
|
if strings.HasSuffix(path, provider.PathOpenAIImageEdit) {
|
||||||
return provider.ApiNameImageEdit
|
return provider.ApiNameImageEdit
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/batches") {
|
if strings.HasSuffix(path, provider.PathOpenAIBatches) {
|
||||||
return provider.ApiNameBatches
|
return provider.ApiNameBatches
|
||||||
}
|
}
|
||||||
if util.RegRetrieveBatchPath.MatchString(path) {
|
if util.RegRetrieveBatchPath.MatchString(path) {
|
||||||
@@ -372,7 +372,7 @@ func getApiName(path string) provider.ApiName {
|
|||||||
if util.RegCancelBatchPath.MatchString(path) {
|
if util.RegCancelBatchPath.MatchString(path) {
|
||||||
return provider.ApiNameCancelBatch
|
return provider.ApiNameCancelBatch
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/files") {
|
if strings.HasSuffix(path, provider.PathOpenAIFiles) {
|
||||||
return provider.ApiNameFiles
|
return provider.ApiNameFiles
|
||||||
}
|
}
|
||||||
if util.RegRetrieveFilePath.MatchString(path) {
|
if util.RegRetrieveFilePath.MatchString(path) {
|
||||||
@@ -381,10 +381,10 @@ func getApiName(path string) provider.ApiName {
|
|||||||
if util.RegRetrieveFileContentPath.MatchString(path) {
|
if util.RegRetrieveFileContentPath.MatchString(path) {
|
||||||
return provider.ApiNameRetrieveFileContent
|
return provider.ApiNameRetrieveFileContent
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/models") {
|
if strings.HasSuffix(path, provider.PathOpenAIModels) {
|
||||||
return provider.ApiNameModels
|
return provider.ApiNameModels
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/fine_tuning/jobs") {
|
if strings.HasSuffix(path, provider.PathOpenAIFineTuningJobs) {
|
||||||
return provider.ApiNameFineTuningJobs
|
return provider.ApiNameFineTuningJobs
|
||||||
}
|
}
|
||||||
if util.RegRetrieveFineTuningJobPath.MatchString(path) {
|
if util.RegRetrieveFineTuningJobPath.MatchString(path) {
|
||||||
@@ -411,11 +411,28 @@ func getApiName(path string) provider.ApiName {
|
|||||||
if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) {
|
if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) {
|
||||||
return provider.ApiNameDeleteFineTuningCheckpointPermission
|
return provider.ApiNameDeleteFineTuningCheckpointPermission
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(path, "/v1/responses") {
|
if strings.HasSuffix(path, provider.PathOpenAIResponses) {
|
||||||
return provider.ApiNameResponses
|
return provider.ApiNameResponses
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Anthropic
|
||||||
|
if strings.HasSuffix(path, provider.PathAnthropicMessages) {
|
||||||
|
return provider.ApiNameAnthropicMessages
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(path, provider.PathAnthropicComplete) {
|
||||||
|
return provider.ApiNameAnthropicComplete
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gemini
|
||||||
|
if util.RegGeminiGenerateContent.MatchString(path) {
|
||||||
|
return provider.ApiNameGeminiGenerateContent
|
||||||
|
}
|
||||||
|
if util.RegGeminiStreamGenerateContent.MatchString(path) {
|
||||||
|
return provider.ApiNameGeminiStreamGenerateContent
|
||||||
|
}
|
||||||
|
|
||||||
// cohere style
|
// cohere style
|
||||||
if strings.HasSuffix(path, "/v1/rerank") {
|
if strings.HasSuffix(path, provider.PathCohereV1Rerank) {
|
||||||
return provider.ApiNameCohereV1Rerank
|
return provider.ApiNameCohereV1Rerank
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -9,18 +9,16 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// claudeProvider is the provider for Claude service.
|
// claudeProvider is the provider for Claude service.
|
||||||
const (
|
const (
|
||||||
claudeDomain = "api.anthropic.com"
|
claudeDomain = "api.anthropic.com"
|
||||||
claudeChatCompletionPath = "/v1/messages"
|
claudeDefaultVersion = "2023-06-01"
|
||||||
claudeCompletionPath = "/v1/complete"
|
claudeDefaultMaxTokens = 4096
|
||||||
claudeDefaultVersion = "2023-06-01"
|
|
||||||
claudeDefaultMaxTokens = 4096
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type claudeProviderInitializer struct{}
|
type claudeProviderInitializer struct{}
|
||||||
@@ -123,8 +121,8 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
|||||||
|
|
||||||
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
|
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): claudeChatCompletionPath,
|
string(ApiNameChatCompletion): PathAnthropicMessages,
|
||||||
string(ApiNameCompletion): claudeCompletionPath,
|
string(ApiNameCompletion): PathAnthropicComplete,
|
||||||
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
|
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
|
||||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||||
string(ApiNameModels): PathOpenAIModels,
|
string(ApiNameModels): PathOpenAIModels,
|
||||||
@@ -461,10 +459,10 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *claudeProvider) GetApiName(path string) ApiName {
|
func (c *claudeProvider) GetApiName(path string) ApiName {
|
||||||
if strings.Contains(path, claudeChatCompletionPath) {
|
if strings.Contains(path, PathAnthropicMessages) {
|
||||||
return ApiNameChatCompletion
|
return ApiNameChatCompletion
|
||||||
}
|
}
|
||||||
if strings.Contains(path, claudeCompletionPath) {
|
if strings.Contains(path, PathAnthropicComplete) {
|
||||||
return ApiNameCompletion
|
return ApiNameCompletion
|
||||||
}
|
}
|
||||||
if strings.Contains(path, PathOpenAIModels) {
|
if strings.Contains(path, PathOpenAIModels) {
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
|
||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
)
|
)
|
||||||
|
|
||||||
// geminiProvider is the provider for google gemini/gemini flash service.
|
// geminiProvider is the provider for google gemini/gemini flash service.
|
||||||
@@ -39,10 +39,12 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
|||||||
|
|
||||||
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): "",
|
string(ApiNameChatCompletion): "",
|
||||||
string(ApiNameEmbeddings): "",
|
string(ApiNameEmbeddings): "",
|
||||||
string(ApiNameModels): "",
|
string(ApiNameModels): "",
|
||||||
string(ApiNameImageGeneration): "",
|
string(ApiNameImageGeneration): "",
|
||||||
|
string(ApiNameGeminiGenerateContent): "",
|
||||||
|
string(ApiNameGeminiStreamGenerateContent): "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,6 +93,7 @@ func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
|||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration:
|
||||||
return g.onImageGenerationRequestBody(ctx, body, headers)
|
return g.onImageGenerationRequestBody(ctx, body, headers)
|
||||||
}
|
}
|
||||||
|
log.Debugf("TransformRequestBodyHeaders apiName:%s", apiName)
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,6 +262,10 @@ func (g *geminiProvider) getRequestPath(apiName ApiName, model string, stream bo
|
|||||||
}
|
}
|
||||||
case ApiNameImageGeneration:
|
case ApiNameImageGeneration:
|
||||||
action = geminiImageGenerationPath
|
action = geminiImageGenerationPath
|
||||||
|
case ApiNameGeminiGenerateContent:
|
||||||
|
action = geminiChatCompletionPath
|
||||||
|
case ApiNameGeminiStreamGenerateContent:
|
||||||
|
action = geminiChatCompletionStreamPath
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action)
|
return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
|
||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -54,6 +54,17 @@ const (
|
|||||||
ApiNameFineTuningCheckpointPermissions ApiName = "openai/v1/fine-tuningjobcheckpointpermissions"
|
ApiNameFineTuningCheckpointPermissions ApiName = "openai/v1/fine-tuningjobcheckpointpermissions"
|
||||||
ApiNameDeleteFineTuningCheckpointPermission ApiName = "openai/v1/deletefine-tuningjobcheckpointpermission"
|
ApiNameDeleteFineTuningCheckpointPermission ApiName = "openai/v1/deletefine-tuningjobcheckpointpermission"
|
||||||
|
|
||||||
|
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
||||||
|
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||||
|
ApiNameQwenAsyncAIGC ApiName = "qwen/v1/services/aigc"
|
||||||
|
ApiNameQwenAsyncTask ApiName = "qwen/v1/tasks"
|
||||||
|
ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank"
|
||||||
|
ApiNameGeminiGenerateContent ApiName = "gemini/v1beta/generatecontent"
|
||||||
|
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
|
||||||
|
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
|
||||||
|
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
||||||
|
|
||||||
|
// OpenAI
|
||||||
PathOpenAICompletions = "/v1/completions"
|
PathOpenAICompletions = "/v1/completions"
|
||||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||||
@@ -79,11 +90,12 @@ const (
|
|||||||
PathOpenAIFineTuningCheckpointPermissions = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions"
|
PathOpenAIFineTuningCheckpointPermissions = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions"
|
||||||
PathOpenAIFineDeleteTuningCheckpointPermission = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions/{permission_id}"
|
PathOpenAIFineDeleteTuningCheckpointPermission = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions/{permission_id}"
|
||||||
|
|
||||||
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
// Anthropic
|
||||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
PathAnthropicMessages = "/v1/messages"
|
||||||
ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank"
|
PathAnthropicComplete = "/v1/complete"
|
||||||
ApiNameQwenAsyncAIGC ApiName = "api/v1/services/aigc"
|
|
||||||
ApiNameQwenAsyncTask ApiName = "api/v1/tasks/"
|
// Cohere
|
||||||
|
PathCohereV1Rerank = "/v1/rerank"
|
||||||
|
|
||||||
providerTypeMoonshot = "moonshot"
|
providerTypeMoonshot = "moonshot"
|
||||||
providerTypeAzure = "azure"
|
providerTypeAzure = "azure"
|
||||||
@@ -901,7 +913,10 @@ func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
|
|||||||
ApiNameImageVariation,
|
ApiNameImageVariation,
|
||||||
ApiNameAudioSpeech,
|
ApiNameAudioSpeech,
|
||||||
ApiNameFineTuningJobs,
|
ApiNameFineTuningJobs,
|
||||||
ApiNameResponses:
|
ApiNameResponses,
|
||||||
|
ApiNameGeminiGenerateContent,
|
||||||
|
ApiNameGeminiStreamGenerateContent,
|
||||||
|
ApiNameAnthropicMessages:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ var (
|
|||||||
RegPauseFineTuningJobPath = regexp.MustCompile(`^.*/v1/fine_tuning/jobs/(?P<fine_tuning_job_id>[^/]+)/pause$`)
|
RegPauseFineTuningJobPath = regexp.MustCompile(`^.*/v1/fine_tuning/jobs/(?P<fine_tuning_job_id>[^/]+)/pause$`)
|
||||||
RegFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions$`)
|
RegFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions$`)
|
||||||
RegDeleteFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions/(?P<permission_id>[^/]+)$`)
|
RegDeleteFineTuningCheckpointPermissionPath = regexp.MustCompile(`^.*/v1/fine_tuning/checkpoints/(?P<fine_tuned_model_checkpoint>[^/]+)/permissions/(?P<permission_id>[^/]+)$`)
|
||||||
|
RegGeminiGenerateContent = regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):generateContent`)
|
||||||
|
RegGeminiStreamGenerateContent = regexp.MustCompile(`^.*/(?P<api_version>[^/]+)/models/(?P<model>[^:]+):streamGenerateContent`)
|
||||||
)
|
)
|
||||||
|
|
||||||
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
|
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
|
||||||
|
|||||||
Reference in New Issue
Block a user