feat(ai-proxy): support OpenAI-compatible image and audio model Mapping (#2341)

This commit is contained in:
Xijun Dai
2025-05-30 12:16:52 +08:00
committed by GitHub
parent 69b755a10d
commit a73c33f1da
4 changed files with 39 additions and 8 deletions

View File

@@ -352,6 +352,12 @@ func getApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/images/generations") { if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration return provider.ApiNameImageGeneration
} }
if strings.HasSuffix(path, "/v1/images/variations") {
return provider.ApiNameImageVariation
}
if strings.HasSuffix(path, "/v1/images/edits") {
return provider.ApiNameImageEdit
}
if strings.HasSuffix(path, "/v1/batches") { if strings.HasSuffix(path, "/v1/batches") {
return provider.ApiNameBatches return provider.ApiNameBatches
} }

View File

@@ -25,8 +25,7 @@ const (
geminiEmbeddingPath = "batchEmbedContents" geminiEmbeddingPath = "batchEmbedContents"
) )
type geminiProviderInitializer struct { type geminiProviderInitializer struct{}
}
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {

View File

@@ -21,11 +21,14 @@ const (
defaultOpenaiEmbeddingsPath = "/v1/embeddings" defaultOpenaiEmbeddingsPath = "/v1/embeddings"
defaultOpenaiAudioSpeech = "/v1/audio/speech" defaultOpenaiAudioSpeech = "/v1/audio/speech"
defaultOpenaiImageGeneration = "/v1/images/generations" defaultOpenaiImageGeneration = "/v1/images/generations"
defaultOpenaiImageEdit = "/v1/images/edits"
defaultOpenaiImageVariation = "/v1/images/variations"
defaultOpenaiModels = "/v1/models" defaultOpenaiModels = "/v1/models"
defaultOpenaiFiles = "/v1/files"
defaultOpenaiBatchs = "/v1/batches"
) )
type openaiProviderInitializer struct { type openaiProviderInitializer struct{}
}
func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error { func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
@@ -37,8 +40,11 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameImageGeneration): defaultOpenaiImageGeneration, string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
string(ApiNameImageEdit): defaultOpenaiImageEdit,
string(ApiNameImageVariation): defaultOpenaiImageVariation,
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech, string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
string(ApiNameModels): defaultOpenaiModels, string(ApiNameModels): defaultOpenaiModels,
string(ApiNameFiles): defaultOpenaiFiles,
} }
} }
@@ -121,7 +127,7 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
} }
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if apiName != ApiNameChatCompletion { if !m.config.needToProcessRequestBody(apiName) {
// We don't need to process the request body for other APIs. // We don't need to process the request body for other APIs.
return types.ActionContinue, nil return types.ActionContinue, nil
} }

View File

@@ -16,8 +16,10 @@ import (
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
type ApiName string type (
type Pointcut string ApiName string
Pointcut string
)
const ( const (
@@ -28,6 +30,8 @@ const (
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings" ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameImageEdit ApiName = "openai/v1/imageedit"
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
ApiNameFiles ApiName = "openai/v1/files" ApiNameFiles ApiName = "openai/v1/files"
ApiNameBatches ApiName = "openai/v1/batches" ApiNameBatches ApiName = "openai/v1/batches"
@@ -439,6 +443,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
case string(ApiNameChatCompletion), case string(ApiNameChatCompletion),
string(ApiNameEmbeddings), string(ApiNameEmbeddings),
string(ApiNameImageGeneration), string(ApiNameImageGeneration),
string(ApiNameImageVariation),
string(ApiNameImageEdit),
string(ApiNameAudioSpeech), string(ApiNameAudioSpeech),
string(ApiNameCohereV1Rerank): string(ApiNameCohereV1Rerank):
c.capabilities[capability] = pathJson.String() c.capabilities[capability] = pathJson.String()
@@ -703,7 +709,8 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
} }
func (c *ProviderConfig) handleRequestBody( func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte,
) (types.Action, error) {
// use original protocol // use original protocol
if c.IsOriginal() { if c.IsOriginal() {
return types.ActionContinue, nil return types.ActionContinue, nil
@@ -771,3 +778,16 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext
headers.Del("Content-Length") headers.Del("Content-Length")
} }
} }
func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
switch apiName {
case ApiNameChatCompletion,
ApiNameEmbeddings,
ApiNameImageGeneration,
ApiNameImageEdit,
ApiNameImageVariation,
ApiNameAudioSpeech:
return true
}
return false
}