feat(ai-proxy): support OpenAI-compatible image and audio model Mapping (#2341)

This commit is contained in:
Xijun Dai
2025-05-30 12:16:52 +08:00
committed by GitHub
parent 69b755a10d
commit a73c33f1da
4 changed files with 39 additions and 8 deletions

View File

@@ -352,6 +352,12 @@ func getApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
if strings.HasSuffix(path, "/v1/images/variations") {
return provider.ApiNameImageVariation
}
if strings.HasSuffix(path, "/v1/images/edits") {
return provider.ApiNameImageEdit
}
if strings.HasSuffix(path, "/v1/batches") {
return provider.ApiNameBatches
}

View File

@@ -25,8 +25,7 @@ const (
geminiEmbeddingPath = "batchEmbedContents"
)
type geminiProviderInitializer struct {
}
type geminiProviderInitializer struct{}
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {

View File

@@ -21,11 +21,14 @@ const (
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
defaultOpenaiAudioSpeech = "/v1/audio/speech"
defaultOpenaiImageGeneration = "/v1/images/generations"
defaultOpenaiImageEdit = "/v1/images/edits"
defaultOpenaiImageVariation = "/v1/images/variations"
defaultOpenaiModels = "/v1/models"
defaultOpenaiFiles = "/v1/files"
defaultOpenaiBatchs = "/v1/batches"
)
type openaiProviderInitializer struct {
}
type openaiProviderInitializer struct{}
func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil
@@ -37,8 +40,11 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
string(ApiNameImageEdit): defaultOpenaiImageEdit,
string(ApiNameImageVariation): defaultOpenaiImageVariation,
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
string(ApiNameModels): defaultOpenaiModels,
string(ApiNameFiles): defaultOpenaiFiles,
}
}
@@ -121,7 +127,7 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
}
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if apiName != ApiNameChatCompletion {
if !m.config.needToProcessRequestBody(apiName) {
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}

View File

@@ -16,8 +16,10 @@ import (
"github.com/tidwall/sjson"
)
type ApiName string
type Pointcut string
type (
ApiName string
Pointcut string
)
const (
@@ -28,6 +30,8 @@ const (
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameImageEdit ApiName = "openai/v1/imageedit"
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
ApiNameFiles ApiName = "openai/v1/files"
ApiNameBatches ApiName = "openai/v1/batches"
@@ -439,6 +443,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
case string(ApiNameChatCompletion),
string(ApiNameEmbeddings),
string(ApiNameImageGeneration),
string(ApiNameImageVariation),
string(ApiNameImageEdit),
string(ApiNameAudioSpeech),
string(ApiNameCohereV1Rerank):
c.capabilities[capability] = pathJson.String()
@@ -703,7 +709,8 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
}
func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte,
) (types.Action, error) {
// use original protocol
if c.IsOriginal() {
return types.ActionContinue, nil
@@ -771,3 +778,16 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext
headers.Del("Content-Length")
}
}
func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
switch apiName {
case ApiNameChatCompletion,
ApiNameEmbeddings,
ApiNameImageGeneration,
ApiNameImageEdit,
ApiNameImageVariation,
ApiNameAudioSpeech:
return true
}
return false
}