diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 99c8641ee..0e10bcde9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -352,6 +352,12 @@ func getApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } + if strings.HasSuffix(path, "/v1/images/variations") { + return provider.ApiNameImageVariation + } + if strings.HasSuffix(path, "/v1/images/edits") { + return provider.ApiNameImageEdit + } if strings.HasSuffix(path, "/v1/batches") { return provider.ApiNameBatches } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index ff00b3c3e..4c60c2320 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -25,8 +25,7 @@ const ( geminiEmbeddingPath = "batchEmbedContents" ) -type geminiProviderInitializer struct { -} +type geminiProviderInitializer struct{} func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 154be2faa..c36b6d23b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -21,11 +21,14 @@ const ( defaultOpenaiEmbeddingsPath = "/v1/embeddings" defaultOpenaiAudioSpeech = "/v1/audio/speech" defaultOpenaiImageGeneration = "/v1/images/generations" + defaultOpenaiImageEdit = "/v1/images/edits" + defaultOpenaiImageVariation = "/v1/images/variations" defaultOpenaiModels = "/v1/models" + defaultOpenaiFiles = "/v1/files" + defaultOpenaiBatchs = "/v1/batches" ) -type openaiProviderInitializer struct { -} +type openaiProviderInitializer struct{} func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error { return nil @@ -37,8 +40,11 @@ func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string { string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, string(ApiNameImageGeneration): defaultOpenaiImageGeneration, + string(ApiNameImageEdit): defaultOpenaiImageEdit, + string(ApiNameImageVariation): defaultOpenaiImageVariation, string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech, string(ApiNameModels): defaultOpenaiModels, + string(ApiNameFiles): defaultOpenaiFiles, } } @@ -121,7 +127,7 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam } func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { - if apiName != ApiNameChatCompletion { + if !m.config.needToProcessRequestBody(apiName) { // We don't need to process the request body for other APIs. return types.ActionContinue, nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 9b169469d..ecc1dfc41 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -16,8 +16,10 @@ import ( "github.com/tidwall/sjson" ) -type ApiName string -type Pointcut string +type ( + ApiName string + Pointcut string +) const ( @@ -28,6 +30,8 @@ const ( ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" ApiNameEmbeddings ApiName = "openai/v1/embeddings" ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" + ApiNameImageEdit ApiName = "openai/v1/imageedit" + ApiNameImageVariation ApiName = "openai/v1/imagevariation" ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" ApiNameFiles ApiName = "openai/v1/files" ApiNameBatches ApiName = "openai/v1/batches" @@ -439,6 +443,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { case string(ApiNameChatCompletion), string(ApiNameEmbeddings), string(ApiNameImageGeneration), + string(ApiNameImageVariation), + string(ApiNameImageEdit), string(ApiNameAudioSpeech), string(ApiNameCohereV1Rerank): c.capabilities[capability] = pathJson.String() @@ -703,7 +709,8 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) } func (c *ProviderConfig) handleRequestBody( - provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, +) (types.Action, error) { // use original protocol if c.IsOriginal() { return types.ActionContinue, nil @@ -771,3 +778,16 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext headers.Del("Content-Length") } } + +func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool { + switch apiName { + case ApiNameChatCompletion, + ApiNameEmbeddings, + ApiNameImageGeneration, + ApiNameImageEdit, + ApiNameImageVariation, + ApiNameAudioSpeech: + return true + } + return false +}