From 0f1afcdcca3b09653dff9c378de158f76a67d4d2 Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Mon, 18 Aug 2025 16:27:25 +0800 Subject: [PATCH] fix(ai-proxy): Do not change the configured components of Azure URL (#2782) --- .../extensions/ai-proxy/provider/azure.go | 60 ++++++++++++++----- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 22cdbf583..1a68af01b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -14,6 +14,8 @@ import ( "github.com/higress-group/wasm-go/pkg/wrapper" ) +type azureServiceUrlType int + const ( pathAzurePrefix = "/openai" pathAzureModelPlaceholder = "{model}" @@ -21,6 +23,12 @@ const ( queryAzureApiVersion = "api-version" ) +const ( + azureServiceUrlTypeFull azureServiceUrlType = iota + azureServiceUrlTypeWithDeployment + azureServiceUrlTypeDomainOnly +) + var ( azureModelIrrelevantApis = map[ApiName]bool{ ApiNameModels: true, @@ -31,7 +39,7 @@ var ( ApiNameRetrieveFile: true, ApiNameRetrieveFileContent: true, } - regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)") + regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)") ) // azureProvider is the provider for Azure OpenAI service. @@ -82,32 +90,44 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path) defaultModel := "placeholder" + var serviceUrlType azureServiceUrlType if modelSubMatch != nil { defaultModel = modelSubMatch[1] + if modelSubMatch[2] != "" { + serviceUrlType = azureServiceUrlTypeFull + } else { + serviceUrlType = azureServiceUrlTypeWithDeployment + } log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel) } else { + serviceUrlType = azureServiceUrlTypeDomainOnly log.Debugf("azureProvider: no default model found in serviceUrl") } + log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType) config.setDefaultCapabilities(m.DefaultCapabilities()) apiVersion := serviceUrl.Query().Get(queryAzureApiVersion) log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion) return &azureProvider{ - config: config, - serviceUrl: serviceUrl, - apiVersion: apiVersion, - defaultModel: defaultModel, - contextCache: createContextCache(&config), + config: config, + serviceUrl: serviceUrl, + serviceUrlType: serviceUrlType, + serviceUrlFullPath: serviceUrl.Path + "?" + serviceUrl.RawQuery, + apiVersion: apiVersion, + defaultModel: defaultModel, + contextCache: createContextCache(&config), }, nil } type azureProvider struct { config ProviderConfig - contextCache *contextCache - serviceUrl *url.URL - apiVersion string - defaultModel string + contextCache *contextCache + serviceUrl *url.URL + serviceUrlFullPath string + serviceUrlType azureServiceUrlType + apiVersion string + defaultModel string } func (m *azureProvider) GetProviderType() string { @@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap return originalPath } + if m.serviceUrlType == azureServiceUrlTypeFull { + log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath) + return m.serviceUrlFullPath + } + log.Debugf("azureProvider: original request path: %s", originalPath) path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities) log.Debugf("azureProvider: path: %s", path) if strings.Contains(path, pathAzureModelPlaceholder) { log.Debugf("azureProvider: path contains placeholder: %s", path) - model := ctx.GetStringContext(ctxKeyFinalRequestModel, "") - log.Debugf("azureProvider: model from context: %s", model) - if model == "" { + var model string + if m.serviceUrlType == azureServiceUrlTypeWithDeployment { model = m.defaultModel - log.Debugf("azureProvider: use default model: %s", model) + } else { + model = ctx.GetStringContext(ctxKeyFinalRequestModel, "") + log.Debugf("azureProvider: model from context: %s", model) + if model == "" { + model = m.defaultModel + log.Debugf("azureProvider: use default model: %s", model) + } } path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model) log.Debugf("azureProvider: model replaced path: %s", path) } - path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion) + path = path + "?" + m.serviceUrl.RawQuery log.Debugf("azureProvider: final path: %s", path) return path