mirror of
https://github.com/alibaba/higress.git
synced 2026-06-02 00:57:28 +08:00
fix(ai-proxy): Do not change the configured components of Azure URL (#2782)
This commit is contained in:
@@ -14,6 +14,8 @@ import (
|
|||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type azureServiceUrlType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
pathAzurePrefix = "/openai"
|
pathAzurePrefix = "/openai"
|
||||||
pathAzureModelPlaceholder = "{model}"
|
pathAzureModelPlaceholder = "{model}"
|
||||||
@@ -21,6 +23,12 @@ const (
|
|||||||
queryAzureApiVersion = "api-version"
|
queryAzureApiVersion = "api-version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
azureServiceUrlTypeFull azureServiceUrlType = iota
|
||||||
|
azureServiceUrlTypeWithDeployment
|
||||||
|
azureServiceUrlTypeDomainOnly
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
azureModelIrrelevantApis = map[ApiName]bool{
|
azureModelIrrelevantApis = map[ApiName]bool{
|
||||||
ApiNameModels: true,
|
ApiNameModels: true,
|
||||||
@@ -31,7 +39,7 @@ var (
|
|||||||
ApiNameRetrieveFile: true,
|
ApiNameRetrieveFile: true,
|
||||||
ApiNameRetrieveFileContent: true,
|
ApiNameRetrieveFileContent: true,
|
||||||
}
|
}
|
||||||
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)")
|
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")
|
||||||
)
|
)
|
||||||
|
|
||||||
// azureProvider is the provider for Azure OpenAI service.
|
// azureProvider is the provider for Azure OpenAI service.
|
||||||
@@ -82,12 +90,20 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
|
|||||||
|
|
||||||
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
|
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
|
||||||
defaultModel := "placeholder"
|
defaultModel := "placeholder"
|
||||||
|
var serviceUrlType azureServiceUrlType
|
||||||
if modelSubMatch != nil {
|
if modelSubMatch != nil {
|
||||||
defaultModel = modelSubMatch[1]
|
defaultModel = modelSubMatch[1]
|
||||||
|
if modelSubMatch[2] != "" {
|
||||||
|
serviceUrlType = azureServiceUrlTypeFull
|
||||||
|
} else {
|
||||||
|
serviceUrlType = azureServiceUrlTypeWithDeployment
|
||||||
|
}
|
||||||
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
|
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
|
||||||
} else {
|
} else {
|
||||||
|
serviceUrlType = azureServiceUrlTypeDomainOnly
|
||||||
log.Debugf("azureProvider: no default model found in serviceUrl")
|
log.Debugf("azureProvider: no default model found in serviceUrl")
|
||||||
}
|
}
|
||||||
|
log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType)
|
||||||
|
|
||||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||||
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
|
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
|
||||||
@@ -95,6 +111,8 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
|
|||||||
return &azureProvider{
|
return &azureProvider{
|
||||||
config: config,
|
config: config,
|
||||||
serviceUrl: serviceUrl,
|
serviceUrl: serviceUrl,
|
||||||
|
serviceUrlType: serviceUrlType,
|
||||||
|
serviceUrlFullPath: serviceUrl.Path + "?" + serviceUrl.RawQuery,
|
||||||
apiVersion: apiVersion,
|
apiVersion: apiVersion,
|
||||||
defaultModel: defaultModel,
|
defaultModel: defaultModel,
|
||||||
contextCache: createContextCache(&config),
|
contextCache: createContextCache(&config),
|
||||||
@@ -106,6 +124,8 @@ type azureProvider struct {
|
|||||||
|
|
||||||
contextCache *contextCache
|
contextCache *contextCache
|
||||||
serviceUrl *url.URL
|
serviceUrl *url.URL
|
||||||
|
serviceUrlFullPath string
|
||||||
|
serviceUrlType azureServiceUrlType
|
||||||
apiVersion string
|
apiVersion string
|
||||||
defaultModel string
|
defaultModel string
|
||||||
}
|
}
|
||||||
@@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
|
|||||||
return originalPath
|
return originalPath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.serviceUrlType == azureServiceUrlTypeFull {
|
||||||
|
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
|
||||||
|
return m.serviceUrlFullPath
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("azureProvider: original request path: %s", originalPath)
|
log.Debugf("azureProvider: original request path: %s", originalPath)
|
||||||
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
|
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
|
||||||
log.Debugf("azureProvider: path: %s", path)
|
log.Debugf("azureProvider: path: %s", path)
|
||||||
if strings.Contains(path, pathAzureModelPlaceholder) {
|
if strings.Contains(path, pathAzureModelPlaceholder) {
|
||||||
log.Debugf("azureProvider: path contains placeholder: %s", path)
|
log.Debugf("azureProvider: path contains placeholder: %s", path)
|
||||||
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
var model string
|
||||||
|
if m.serviceUrlType == azureServiceUrlTypeWithDeployment {
|
||||||
|
model = m.defaultModel
|
||||||
|
} else {
|
||||||
|
model = ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
||||||
log.Debugf("azureProvider: model from context: %s", model)
|
log.Debugf("azureProvider: model from context: %s", model)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
model = m.defaultModel
|
model = m.defaultModel
|
||||||
log.Debugf("azureProvider: use default model: %s", model)
|
log.Debugf("azureProvider: use default model: %s", model)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
|
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
|
||||||
log.Debugf("azureProvider: model replaced path: %s", path)
|
log.Debugf("azureProvider: model replaced path: %s", path)
|
||||||
}
|
}
|
||||||
path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion)
|
path = path + "?" + m.serviceUrl.RawQuery
|
||||||
log.Debugf("azureProvider: final path: %s", path)
|
log.Debugf("azureProvider: final path: %s", path)
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|||||||
Reference in New Issue
Block a user