fix(ai-proxy): Do not change the configured components of Azure URL (#2782)

This commit is contained in:
Kent Dong
2025-08-18 16:27:25 +08:00
committed by GitHub
parent 19d1548971
commit 0f1afcdcca

View File

@@ -14,6 +14,8 @@ import (
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/wasm-go/pkg/wrapper"
) )
type azureServiceUrlType int
const ( const (
pathAzurePrefix = "/openai" pathAzurePrefix = "/openai"
pathAzureModelPlaceholder = "{model}" pathAzureModelPlaceholder = "{model}"
@@ -21,6 +23,12 @@ const (
queryAzureApiVersion = "api-version" queryAzureApiVersion = "api-version"
) )
const (
azureServiceUrlTypeFull azureServiceUrlType = iota
azureServiceUrlTypeWithDeployment
azureServiceUrlTypeDomainOnly
)
var ( var (
azureModelIrrelevantApis = map[ApiName]bool{ azureModelIrrelevantApis = map[ApiName]bool{
ApiNameModels: true, ApiNameModels: true,
@@ -31,7 +39,7 @@ var (
ApiNameRetrieveFile: true, ApiNameRetrieveFile: true,
ApiNameRetrieveFileContent: true, ApiNameRetrieveFileContent: true,
} }
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)") regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")
) )
// azureProvider is the provider for Azure OpenAI service. // azureProvider is the provider for Azure OpenAI service.
@@ -82,12 +90,20 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path) modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
defaultModel := "placeholder" defaultModel := "placeholder"
var serviceUrlType azureServiceUrlType
if modelSubMatch != nil { if modelSubMatch != nil {
defaultModel = modelSubMatch[1] defaultModel = modelSubMatch[1]
if modelSubMatch[2] != "" {
serviceUrlType = azureServiceUrlTypeFull
} else {
serviceUrlType = azureServiceUrlTypeWithDeployment
}
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel) log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
} else { } else {
serviceUrlType = azureServiceUrlTypeDomainOnly
log.Debugf("azureProvider: no default model found in serviceUrl") log.Debugf("azureProvider: no default model found in serviceUrl")
} }
log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType)
config.setDefaultCapabilities(m.DefaultCapabilities()) config.setDefaultCapabilities(m.DefaultCapabilities())
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion) apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
@@ -95,6 +111,8 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
return &azureProvider{ return &azureProvider{
config: config, config: config,
serviceUrl: serviceUrl, serviceUrl: serviceUrl,
serviceUrlType: serviceUrlType,
serviceUrlFullPath: serviceUrl.Path + "?" + serviceUrl.RawQuery,
apiVersion: apiVersion, apiVersion: apiVersion,
defaultModel: defaultModel, defaultModel: defaultModel,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -106,6 +124,8 @@ type azureProvider struct {
contextCache *contextCache contextCache *contextCache
serviceUrl *url.URL serviceUrl *url.URL
serviceUrlFullPath string
serviceUrlType azureServiceUrlType
apiVersion string apiVersion string
defaultModel string defaultModel string
} }
@@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
return originalPath return originalPath
} }
if m.serviceUrlType == azureServiceUrlTypeFull {
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
return m.serviceUrlFullPath
}
log.Debugf("azureProvider: original request path: %s", originalPath) log.Debugf("azureProvider: original request path: %s", originalPath)
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities) path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
log.Debugf("azureProvider: path: %s", path) log.Debugf("azureProvider: path: %s", path)
if strings.Contains(path, pathAzureModelPlaceholder) { if strings.Contains(path, pathAzureModelPlaceholder) {
log.Debugf("azureProvider: path contains placeholder: %s", path) log.Debugf("azureProvider: path contains placeholder: %s", path)
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "") var model string
if m.serviceUrlType == azureServiceUrlTypeWithDeployment {
model = m.defaultModel
} else {
model = ctx.GetStringContext(ctxKeyFinalRequestModel, "")
log.Debugf("azureProvider: model from context: %s", model) log.Debugf("azureProvider: model from context: %s", model)
if model == "" { if model == "" {
model = m.defaultModel model = m.defaultModel
log.Debugf("azureProvider: use default model: %s", model) log.Debugf("azureProvider: use default model: %s", model)
} }
}
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model) path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
log.Debugf("azureProvider: model replaced path: %s", path) log.Debugf("azureProvider: model replaced path: %s", path)
} }
path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion) path = path + "?" + m.serviceUrl.RawQuery
log.Debugf("azureProvider: final path: %s", path) log.Debugf("azureProvider: final path: %s", path)
return path return path