mirror of
https://github.com/alibaba/higress.git
synced 2026-05-27 22:27:29 +08:00
feat(ai-proxy): support OpenAI-compatible image and audio model Mapping (#2341)
This commit is contained in:
@@ -352,6 +352,12 @@ func getApiName(path string) provider.ApiName {
|
|||||||
if strings.HasSuffix(path, "/v1/images/generations") {
|
if strings.HasSuffix(path, "/v1/images/generations") {
|
||||||
return provider.ApiNameImageGeneration
|
return provider.ApiNameImageGeneration
|
||||||
}
|
}
|
||||||
|
if strings.HasSuffix(path, "/v1/images/variations") {
|
||||||
|
return provider.ApiNameImageVariation
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(path, "/v1/images/edits") {
|
||||||
|
return provider.ApiNameImageEdit
|
||||||
|
}
|
||||||
if strings.HasSuffix(path, "/v1/batches") {
|
if strings.HasSuffix(path, "/v1/batches") {
|
||||||
return provider.ApiNameBatches
|
return provider.ApiNameBatches
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ const (
|
|||||||
geminiEmbeddingPath = "batchEmbedContents"
|
geminiEmbeddingPath = "batchEmbedContents"
|
||||||
)
|
)
|
||||||
|
|
||||||
type geminiProviderInitializer struct {
|
type geminiProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||||
|
|||||||
@@ -21,11 +21,14 @@ const (
|
|||||||
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
|
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
|
||||||
defaultOpenaiAudioSpeech = "/v1/audio/speech"
|
defaultOpenaiAudioSpeech = "/v1/audio/speech"
|
||||||
defaultOpenaiImageGeneration = "/v1/images/generations"
|
defaultOpenaiImageGeneration = "/v1/images/generations"
|
||||||
|
defaultOpenaiImageEdit = "/v1/images/edits"
|
||||||
|
defaultOpenaiImageVariation = "/v1/images/variations"
|
||||||
defaultOpenaiModels = "/v1/models"
|
defaultOpenaiModels = "/v1/models"
|
||||||
|
defaultOpenaiFiles = "/v1/files"
|
||||||
|
defaultOpenaiBatchs = "/v1/batches"
|
||||||
)
|
)
|
||||||
|
|
||||||
type openaiProviderInitializer struct {
|
type openaiProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
return nil
|
return nil
|
||||||
@@ -37,8 +40,11 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
|||||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
||||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
||||||
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
|
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
|
||||||
|
string(ApiNameImageEdit): defaultOpenaiImageEdit,
|
||||||
|
string(ApiNameImageVariation): defaultOpenaiImageVariation,
|
||||||
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
|
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
|
||||||
string(ApiNameModels): defaultOpenaiModels,
|
string(ApiNameModels): defaultOpenaiModels,
|
||||||
|
string(ApiNameFiles): defaultOpenaiFiles,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +127,7 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||||
if apiName != ApiNameChatCompletion {
|
if !m.config.needToProcessRequestBody(apiName) {
|
||||||
// We don't need to process the request body for other APIs.
|
// We don't need to process the request body for other APIs.
|
||||||
return types.ActionContinue, nil
|
return types.ActionContinue, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,8 +16,10 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ApiName string
|
type (
|
||||||
type Pointcut string
|
ApiName string
|
||||||
|
Pointcut string
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
||||||
@@ -28,6 +30,8 @@ const (
|
|||||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||||
|
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
||||||
|
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
||||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||||
ApiNameFiles ApiName = "openai/v1/files"
|
ApiNameFiles ApiName = "openai/v1/files"
|
||||||
ApiNameBatches ApiName = "openai/v1/batches"
|
ApiNameBatches ApiName = "openai/v1/batches"
|
||||||
@@ -439,6 +443,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
case string(ApiNameChatCompletion),
|
case string(ApiNameChatCompletion),
|
||||||
string(ApiNameEmbeddings),
|
string(ApiNameEmbeddings),
|
||||||
string(ApiNameImageGeneration),
|
string(ApiNameImageGeneration),
|
||||||
|
string(ApiNameImageVariation),
|
||||||
|
string(ApiNameImageEdit),
|
||||||
string(ApiNameAudioSpeech),
|
string(ApiNameAudioSpeech),
|
||||||
string(ApiNameCohereV1Rerank):
|
string(ApiNameCohereV1Rerank):
|
||||||
c.capabilities[capability] = pathJson.String()
|
c.capabilities[capability] = pathJson.String()
|
||||||
@@ -703,7 +709,8 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) handleRequestBody(
|
func (c *ProviderConfig) handleRequestBody(
|
||||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte,
|
||||||
|
) (types.Action, error) {
|
||||||
// use original protocol
|
// use original protocol
|
||||||
if c.IsOriginal() {
|
if c.IsOriginal() {
|
||||||
return types.ActionContinue, nil
|
return types.ActionContinue, nil
|
||||||
@@ -771,3 +778,16 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext
|
|||||||
headers.Del("Content-Length")
|
headers.Del("Content-Length")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
|
||||||
|
switch apiName {
|
||||||
|
case ApiNameChatCompletion,
|
||||||
|
ApiNameEmbeddings,
|
||||||
|
ApiNameImageGeneration,
|
||||||
|
ApiNameImageEdit,
|
||||||
|
ApiNameImageVariation,
|
||||||
|
ApiNameAudioSpeech:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user