feature: allow ai-proxy to forward standard AI capabilities that are … (#1704)

This commit is contained in:
pepesi
2025-02-12 15:23:44 +08:00
committed by GitHub
parent 477e44b9f1
commit a84a382f1d
32 changed files with 517 additions and 158 deletions

View File

@@ -78,7 +78,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
rawPath := ctx.Path()
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
@@ -103,20 +103,25 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody()
err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
if err != nil {
if providerConfig.PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestHeader] passthrough unsupported API: %v", err)
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.DontReadRequestBody()
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
hasRequestBody := wrapper.HasRequestBody()
if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Delay the header processing to allow changing in OnRequestBody
return types.HeaderStopIteration
}
ctx.DontReadRequestBody()
return types.ActionContinue
}
@@ -151,6 +156,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil {
return action
}
if pluginConfig.GetProviderConfig().PassthroughUnsupportedAPI() {
log.Warnf("[onHttpRequestBody] passthrough unsupported API: %v", err)
return types.ActionContinue
}
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
}
return types.ActionContinue
@@ -267,12 +276,23 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
}
}
func getOpenAiApiName(path string) provider.ApiName {
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
if strings.HasSuffix(path, "/v1/audio/speech") {
return provider.ApiNameAudioSpeech
}
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// cohere style
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
}
return ""
}