mirror of
https://github.com/alibaba/higress.git
synced 2026-03-24 04:57:31 +08:00
feat(ai-proxy): support OpenAI-compatible image and audio model Mapping (#2341)
This commit is contained in:
@@ -352,6 +352,12 @@ func getApiName(path string) provider.ApiName {
|
||||
if strings.HasSuffix(path, "/v1/images/generations") {
|
||||
return provider.ApiNameImageGeneration
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/variations") {
|
||||
return provider.ApiNameImageVariation
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/edits") {
|
||||
return provider.ApiNameImageEdit
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/batches") {
|
||||
return provider.ApiNameBatches
|
||||
}
|
||||
|
||||
@@ -25,8 +25,7 @@ const (
|
||||
geminiEmbeddingPath = "batchEmbedContents"
|
||||
)
|
||||
|
||||
type geminiProviderInitializer struct {
|
||||
}
|
||||
type geminiProviderInitializer struct{}
|
||||
|
||||
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
|
||||
@@ -21,11 +21,14 @@ const (
|
||||
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
|
||||
defaultOpenaiAudioSpeech = "/v1/audio/speech"
|
||||
defaultOpenaiImageGeneration = "/v1/images/generations"
|
||||
defaultOpenaiImageEdit = "/v1/images/edits"
|
||||
defaultOpenaiImageVariation = "/v1/images/variations"
|
||||
defaultOpenaiModels = "/v1/models"
|
||||
defaultOpenaiFiles = "/v1/files"
|
||||
defaultOpenaiBatchs = "/v1/batches"
|
||||
)
|
||||
|
||||
type openaiProviderInitializer struct {
|
||||
}
|
||||
type openaiProviderInitializer struct{}
|
||||
|
||||
func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
@@ -37,8 +40,11 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
||||
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
|
||||
string(ApiNameImageEdit): defaultOpenaiImageEdit,
|
||||
string(ApiNameImageVariation): defaultOpenaiImageVariation,
|
||||
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
|
||||
string(ApiNameModels): defaultOpenaiModels,
|
||||
string(ApiNameFiles): defaultOpenaiFiles,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +127,7 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.needToProcessRequestBody(apiName) {
|
||||
// We don't need to process the request body for other APIs.
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -16,8 +16,10 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type ApiName string
|
||||
type Pointcut string
|
||||
type (
|
||||
ApiName string
|
||||
Pointcut string
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -28,6 +30,8 @@ const (
|
||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
||||
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
ApiNameFiles ApiName = "openai/v1/files"
|
||||
ApiNameBatches ApiName = "openai/v1/batches"
|
||||
@@ -439,6 +443,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
case string(ApiNameChatCompletion),
|
||||
string(ApiNameEmbeddings),
|
||||
string(ApiNameImageGeneration),
|
||||
string(ApiNameImageVariation),
|
||||
string(ApiNameImageEdit),
|
||||
string(ApiNameAudioSpeech),
|
||||
string(ApiNameCohereV1Rerank):
|
||||
c.capabilities[capability] = pathJson.String()
|
||||
@@ -703,7 +709,8 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte,
|
||||
) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
@@ -771,3 +778,16 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion,
|
||||
ApiNameEmbeddings,
|
||||
ApiNameImageGeneration,
|
||||
ApiNameImageEdit,
|
||||
ApiNameImageVariation,
|
||||
ApiNameAudioSpeech:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user