diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 81905ce11..7b742def8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -331,12 +331,12 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { func getApiName(path string) provider.ApiName { // openai style - if strings.HasSuffix(path, "/v1/completions") { - return provider.ApiNameCompletion - } if strings.HasSuffix(path, "/v1/chat/completions") { return provider.ApiNameChatCompletion } + if strings.HasSuffix(path, "/v1/completions") { + return provider.ApiNameCompletion + } if strings.HasSuffix(path, "/v1/embeddings") { return provider.ApiNameEmbeddings } @@ -346,6 +346,12 @@ func getApiName(path string) provider.ApiName { if strings.HasSuffix(path, "/v1/images/generations") { return provider.ApiNameImageGeneration } + if strings.HasSuffix(path, "/v1/batches") { + return provider.ApiNameBatches + } + if strings.HasSuffix(path, "/v1/files") { + return provider.ApiNameFiles + } // cohere style if strings.HasSuffix(path, "/v1/rerank") { return provider.ApiNameCohereV1Rerank diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index c95b3bcbb..0b5fadb07 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -5,12 +5,18 @@ import ( "fmt" "net/http" "net/url" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) +const ( + pathAzureFiles = "/openai/files" + pathAzureBatches = "/openai/batches" +) + // azureProvider is the provider for Azure OpenAI service. type azureProviderInitializer struct { } @@ -20,6 +26,8 @@ func (m *azureProviderInitializer) DefaultCapabilities() map[string]string { // TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities string(ApiNameChatCompletion): PathOpenAIChatCompletions, string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameFiles): PathOpenAIFiles, + string(ApiNameBatches): PathOpenAIBatches, } } @@ -68,32 +76,47 @@ func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if !m.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName - } return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - if apiName != "" { - u, e := url.Parse(ctx.Path()) - if e == nil { - customApiVersion := u.Query().Get("api-version") - if customApiVersion == "" { - util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) - } else { - q := m.serviceUrl.Query() - q.Set("api-version", customApiVersion) - newUrl := *m.serviceUrl - newUrl.RawQuery = q.Encode() - util.OverwriteRequestPathHeader(headers, newUrl.RequestURI()) + finalRequestUrl := *m.serviceUrl + if u, e := url.Parse(ctx.Path()); e == nil { + if len(u.Query()) != 0 { + q := m.serviceUrl.Query() + for k, v := range u.Query() { + switch len(v) { + case 0: + break + case 1: + q.Set(k, v[0]) + break + default: + delete(q, k) + for _, vv := range v { + q.Add(k, vv) + } + } } - } else { - log.Errorf("failed to parse request path: %v", e) - util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) + finalRequestUrl.RawQuery = q.Encode() } + + if filesIndex := strings.Index(u.Path, "/files"); filesIndex != -1 { + finalRequestUrl.Path = pathAzureFiles + u.Path[filesIndex+len("/files"):] + } else if batchesIndex := strings.Index(u.Path, "/batches"); batchesIndex != -1 { + finalRequestUrl.Path = pathAzureBatches + u.Path[batchesIndex+len("/batches"):] + } + } else { + log.Errorf("failed to parse request path: %v", e) } + util.OverwriteRequestPathHeader(headers, finalRequestUrl.RequestURI()) + util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) headers.Set("api-key", m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") + + if !m.config.isSupportedAPI(apiName) { + // If the API is not supported, we should not read the request body and keep it as it is. + ctx.DontReadRequestBody() + } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 355bbf27c..c08dff807 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -28,10 +28,14 @@ const ( ApiNameEmbeddings ApiName = "openai/v1/embeddings" ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" + ApiNameFiles ApiName = "openai/v1/files" + ApiNameBatches ApiName = "openai/v1/batches" PathOpenAICompletions = "/v1/completions" PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIEmbeddings = "/v1/embeddings" + PathOpenAIFiles = "/v1/files" + PathOpenAIBatches = "/v1/batches" // TODO: 以下是一些非标准的API名称,需要进一步确认是否支持 ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"