mirror of
https://github.com/alibaba/higress.git
synced 2026-05-25 21:28:17 +08:00
Signed-off-by: wydream <yaodiwu618@gmail.com> Co-authored-by: EndlessSeeker <153817598+EndlessSeeker@users.noreply.github.com>
377 lines
16 KiB
Go
377 lines
16 KiB
Go
package provider
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
"path"
|
||
"regexp"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/higress-group/wasm-go/pkg/log"
|
||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||
)
|
||
|
||
// azureServiceUrlType classifies the shape of the configured azureServiceUrl.
// The value is decided once in CreateProvider and later selects how
// transformRequestPath builds the upstream request path.
type azureServiceUrlType int
|
||
|
||
// Path and query building blocks used when rewriting requests for Azure OpenAI.
const (
	// pathAzurePrefix is the path root used for APIs that are not bound to a deployment.
	pathAzurePrefix = "/openai"

	// pathAzureOpenAIV1 is the base path of the Azure OpenAI v1 API.
	pathAzureOpenAIV1 = "/openai/v1"

	// pathAzureModelPlaceholder marks where the deployment (model) name is
	// substituted into a capability path at request time.
	pathAzureModelPlaceholder = "{model}"

	// pathAzureWithModelPrefix is the path root used for deployment-scoped APIs.
	pathAzureWithModelPrefix = "/openai/deployments/" + pathAzureModelPlaceholder

	// queryAzureApiVersion is the name of the query parameter carrying the API version.
	queryAzureApiVersion = "api-version"
)
|
||
|
||
const (
|
||
azureServiceUrlTypeFull azureServiceUrlType = iota
|
||
azureServiceUrlTypeWithDeployment
|
||
azureServiceUrlTypeDomainOnly
|
||
azureServiceUrlTypeOpenAIV1Base
|
||
)
|
||
|
||
var (
|
||
azureModelIrrelevantApis = map[ApiName]bool{
|
||
ApiNameModels: true,
|
||
ApiNameBatches: true,
|
||
ApiNameRetrieveBatch: true,
|
||
ApiNameCancelBatch: true,
|
||
ApiNameFiles: true,
|
||
ApiNameRetrieveFile: true,
|
||
ApiNameRetrieveFileContent: true,
|
||
ApiNameResponses: true,
|
||
}
|
||
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")
|
||
)
|
||
|
||
// azureProviderInitializer validates Azure OpenAI provider configuration and
// builds azureProvider instances from it. It carries no state.
type azureProviderInitializer struct{}
|
||
|
||
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
|
||
var capabilities = map[string]string{}
|
||
for k, v := range (&openaiProviderInitializer{}).DefaultCapabilities() {
|
||
if !strings.HasPrefix(v, PathOpenAIPrefix) {
|
||
log.Warnf("azureProviderInitializer: capability %s has an unexpected path %s, skipping", k, v)
|
||
continue
|
||
}
|
||
path := strings.TrimPrefix(v, PathOpenAIPrefix)
|
||
if azureModelIrrelevantApis[ApiName(k)] {
|
||
path = pathAzurePrefix + path
|
||
} else {
|
||
path = pathAzureWithModelPrefix + path
|
||
}
|
||
capabilities[k] = path
|
||
log.Debugf("azureProviderInitializer: capability %s -> %s", k, path)
|
||
}
|
||
return capabilities
|
||
}
|
||
|
||
// DefaultOpenAIV1Capabilities 生成 Azure OpenAI v1 base_url 模式下的 capability path。
|
||
// 输入约束:basePath 必须是 Azure v1 base path,通常为 /openai/v1,调用方应先完成 URL mode 判定。
|
||
// 输出语义:把 OpenAI 标准 /v1/... 能力路径平移到 Azure /openai/v1/...,不引入 deployment 和 api-version。
|
||
// 边界场景:若后续 Azure v1 base path 发生变化,只需要同步调用方传入的 basePath 与 spec 契约。
|
||
func (m *azureProviderInitializer) DefaultOpenAIV1Capabilities(basePath string) map[string]string {
|
||
return defaultOpenAIV1Capabilities(basePath, (&openaiProviderInitializer{}).DefaultCapabilities())
|
||
}
|
||
|
||
// defaultOpenAIV1Capabilities 将 OpenAI capability map 转换为 Azure OpenAI v1 path map。
|
||
// 输入约束:openAICapabilities 中的 path 应以 /v1 开头;异常 path 会被跳过,避免生成不可用上游路径。
|
||
// 输出语义:返回的 path 统一位于 basePath 下,不包含 deployment 占位符和日期型 api-version。
|
||
// 边界场景:保留异常 capability 跳过逻辑,防止未来 OpenAI capability 定义变更时污染 Azure v1 默认映射。
|
||
func defaultOpenAIV1Capabilities(basePath string, openAICapabilities map[string]string) map[string]string {
|
||
var capabilities = map[string]string{}
|
||
for k, v := range openAICapabilities {
|
||
if !strings.HasPrefix(v, PathOpenAIPrefix) {
|
||
log.Warnf("azureProviderInitializer: capability %s has an unexpected path %s, skipping", k, v)
|
||
continue
|
||
}
|
||
capabilities[k] = path.Join(basePath, strings.TrimPrefix(v, PathOpenAIPrefix))
|
||
log.Debugf("azureProviderInitializer: v1 capability %s -> %s", k, capabilities[k])
|
||
}
|
||
return capabilities
|
||
}
|
||
|
||
// ValidateConfig 校验 Azure OpenAI provider 配置是否满足启动条件。
|
||
// 输入约束:azureServiceUrl 必须是合法 URL,apiTokens 至少包含一个可用 token。
|
||
// 输出语义:返回 nil 表示插件可启动;返回 error 时会阻止当前 provider 配置生效。
|
||
// 边界场景:/openai/v1 新版路径不要求日期型 api-version,legacy deployment 或 domain-only 模式仍要求非空 api-version。
|
||
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||
if config.azureServiceUrl == "" {
|
||
return errors.New("missing azureServiceUrl in provider config")
|
||
}
|
||
if azureServiceUrl, err := url.Parse(config.azureServiceUrl); err != nil {
|
||
return fmt.Errorf("invalid azureServiceUrl: %w", err)
|
||
} else if err := validateAzureServiceURLAPIVersion(azureServiceUrl, config.azureServiceUrl); err != nil {
|
||
return err
|
||
}
|
||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||
return errors.New("no apiToken found in provider config")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// validateAzureServiceURLAPIVersion 按 Azure OpenAI URL path mode 校验 api-version。
|
||
// 输入约束:serviceURL 应来自 azureServiceUrl 的解析结果,rawServiceURL 仅用于错误信息回显。
|
||
// 输出语义:v1 路径直接通过;legacy 路径只有携带非空 api-version 时通过。
|
||
// 边界场景:/openai/v10、/openai/deployments/... 和仅域名 URL 都不属于 v1 模式。
|
||
func validateAzureServiceURLAPIVersion(serviceURL *url.URL, rawServiceURL string) error {
|
||
if isAzureOpenAIV1Path(serviceURL.Path) {
|
||
// APIGO-CONTRACT: azure-service-url-api-version-optional
|
||
// 新版 Azure OpenAI v1 API 以 /openai/v1 作为 base_url,产品机制不再要求日期型 api-version;
|
||
// 若微软 v1 API 路径或版本机制再次变化,需要同步更新 apigo 控制面与 ai-proxy 插件判定。
|
||
return nil
|
||
}
|
||
if serviceURL.Query().Get(queryAzureApiVersion) == "" {
|
||
// APIGO-CONTRACT: azure-service-url-api-version-optional
|
||
// legacy deployment/domain-only 代理模式仍依赖 api-version 生成旧版 Azure data-plane 请求;
|
||
// 只有确认旧版 /openai/deployments 路径也取消版本参数要求时,才能放宽该分支。
|
||
return fmt.Errorf("missing %s query parameter in azureServiceUrl: %s", queryAzureApiVersion, rawServiceURL)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// isAzureOpenAIV1Path 判断 Azure OpenAI 服务 URL 是否使用新版 /openai/v1 path mode。
|
||
// 输入约束:rawPath 可为空或带尾随斜杠,函数内部会用 path.Clean 规整。
|
||
// 输出语义:仅 /openai/v1 及其子路径返回 true,避免把 /openai/v10 误判为 v1。
|
||
// 边界场景:空路径会被规整为 ".",必须返回 false,让 domain-only legacy 模式继续要求 api-version。
|
||
func isAzureOpenAIV1Path(rawPath string) bool {
|
||
cleanPath := path.Clean(rawPath)
|
||
return cleanPath == pathAzureOpenAIV1 || strings.HasPrefix(cleanPath, pathAzureOpenAIV1+"/")
|
||
}
|
||
|
||
// isAzureOpenAIV1BasePath 判断 URL 是否是 Azure OpenAI v1 的 base_url,而不是具体接口完整路径。
|
||
// 输入约束:rawPath 来自 azureServiceUrl.Path,可为空或带尾随斜杠。
|
||
// 输出语义:仅精确 /openai/v1 返回 true;/openai/v1/chat/completions 仍保留 full path 语义。
|
||
// 边界场景:/openai/v10 或 /openai/v1beta 都不能被误判为 v1 base_url。
|
||
func isAzureOpenAIV1BasePath(rawPath string) bool {
|
||
return path.Clean(rawPath) == pathAzureOpenAIV1
|
||
}
|
||
|
||
// appendAzureServiceURLRawQuery appends the raw query of azureServiceUrl to a
// target path and returns a value that can be written directly to :path.
//
// Rules, in order:
//   - rawQuery is empty: basePath is returned unchanged, so no dangling "?"
//     is produced (Azure v1 URLs commonly have no query).
//   - basePath has no "?": the query is attached with "?".
//   - basePath already ends with "?" or "&": the query is attached directly,
//     avoiding an empty "?&" or doubled "&&" separator.
//   - otherwise: the query is attached with "&", after the existing query.
func appendAzureServiceURLRawQuery(basePath, rawQuery string) string {
	switch {
	case rawQuery == "":
		return basePath
	case !strings.Contains(basePath, "?"):
		return basePath + "?" + rawQuery
	case strings.HasSuffix(basePath, "?"), strings.HasSuffix(basePath, "&"):
		// A separator is already in place; adding another would create an empty parameter.
		return basePath + rawQuery
	default:
		return basePath + "&" + rawQuery
	}
}
|
||
|
||
// CreateProvider 基于已校验的 Azure OpenAI 配置创建运行时 provider。
|
||
// 输入约束:config.azureServiceUrl 必须可解析;ValidateConfig 会在正常启动路径提前检查 api-version 与 token。
|
||
// 输出语义:返回的 provider 会记录 URL 类型、默认模型、原始 query 和上下文缓存,用于请求阶段改写 Host/Path/Header。
|
||
// 边界场景:/openai/deployments/{deployment} 可提取默认 deployment;/openai/v1 作为 v1 base_url 映射 OpenAI path,不补 api-version。
|
||
func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||
var serviceUrl *url.URL
|
||
if u, err := url.Parse(config.azureServiceUrl); err != nil {
|
||
return nil, fmt.Errorf("invalid azureServiceUrl: %w", err)
|
||
} else {
|
||
serviceUrl = u
|
||
}
|
||
|
||
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
|
||
defaultModel := "placeholder"
|
||
var serviceUrlType azureServiceUrlType
|
||
if modelSubMatch != nil {
|
||
defaultModel = modelSubMatch[1]
|
||
if modelSubMatch[2] != "" {
|
||
serviceUrlType = azureServiceUrlTypeFull
|
||
} else {
|
||
serviceUrlType = azureServiceUrlTypeWithDeployment
|
||
}
|
||
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
|
||
} else if isAzureOpenAIV1BasePath(serviceUrl.Path) {
|
||
// APIGO-CONTRACT: azure-service-url-api-version-optional
|
||
// /openai/v1 是 Azure OpenAI v1 的 base_url,运行时需要继续拼接 OpenAI capability path;
|
||
// 若将其当作 full path,会把 /v1/chat/completions 错误覆盖成 /openai/v1。
|
||
serviceUrlType = azureServiceUrlTypeOpenAIV1Base
|
||
log.Debugf("azureProvider: using Azure OpenAI v1 base path: %s", serviceUrl.Path)
|
||
} else {
|
||
// If path doesn't match the /openai/deployments pattern,
|
||
// check if it's a custom full path or domain only
|
||
if serviceUrl.Path != "" && serviceUrl.Path != "/" {
|
||
serviceUrlType = azureServiceUrlTypeFull
|
||
log.Debugf("azureProvider: using custom full path: %s", serviceUrl.Path)
|
||
} else {
|
||
serviceUrlType = azureServiceUrlTypeDomainOnly
|
||
log.Debugf("azureProvider: no default model found in serviceUrl")
|
||
}
|
||
}
|
||
log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType)
|
||
|
||
if serviceUrlType == azureServiceUrlTypeOpenAIV1Base {
|
||
config.setDefaultCapabilities(m.DefaultOpenAIV1Capabilities(pathAzureOpenAIV1))
|
||
} else {
|
||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||
}
|
||
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
|
||
log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion)
|
||
return &azureProvider{
|
||
config: config,
|
||
serviceUrl: serviceUrl,
|
||
serviceUrlType: serviceUrlType,
|
||
serviceUrlFullPath: appendAzureServiceURLRawQuery(serviceUrl.Path, serviceUrl.RawQuery),
|
||
apiVersion: apiVersion,
|
||
defaultModel: defaultModel,
|
||
contextCache: createContextCache(&config),
|
||
}, nil
|
||
}
|
||
|
||
type azureProvider struct {
|
||
config ProviderConfig
|
||
|
||
contextCache *contextCache
|
||
serviceUrl *url.URL
|
||
serviceUrlFullPath string
|
||
serviceUrlType azureServiceUrlType
|
||
apiVersion string
|
||
defaultModel string
|
||
}
|
||
|
||
func (m *azureProvider) GetProviderType() string {
|
||
return providerTypeAzure
|
||
}
|
||
|
||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||
return nil
|
||
}
|
||
|
||
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||
}
|
||
|
||
func isAzureMultipartImageRequest(apiName ApiName, contentType string) bool {
|
||
if apiName != ApiNameImageEdit && apiName != ApiNameImageVariation {
|
||
return false
|
||
}
|
||
return isMultipartFormData(contentType)
|
||
}
|
||
|
||
func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) {
|
||
transformedBody = body
|
||
err = nil
|
||
|
||
contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType)
|
||
isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
|
||
|
||
transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||
if isMultipartImageRequest {
|
||
if err != nil {
|
||
log.Debugf("[azure multipart] body transform failed: api=%s, err=%v", apiName, err)
|
||
} else {
|
||
log.Debugf("[azure multipart] body transformed: api=%s, originalModel=%s, mappedModel=%s, bodyBytes=%d->%d",
|
||
apiName,
|
||
ctx.GetStringContext(ctxKeyOriginalRequestModel, ""),
|
||
ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||
len(body),
|
||
len(transformedBody),
|
||
)
|
||
}
|
||
}
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// This must be called after the body is transformed, because it uses the model from the context filled by that call.
|
||
if path := m.transformRequestPath(ctx, apiName); path != "" {
|
||
if isMultipartImageRequest {
|
||
log.Debugf("[azure multipart] body path overwrite: api=%s, path=%s, modelInContext=%s",
|
||
apiName, path, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||
}
|
||
err = util.OverwriteRequestPath(path)
|
||
if err == nil {
|
||
log.Debugf("azureProvider: overwrite request path to %s succeeded", path)
|
||
} else {
|
||
log.Errorf("azureProvider: overwrite request path to %s failed: %v", path, err)
|
||
}
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// transformRequestPath 根据插件协议和 Azure service URL 类型生成最终上游请求路径。
|
||
// 输入约束:ctx 中可能已经写入最终模型名;apiName 必须来自 ai-proxy 能识别的 API 枚举。
|
||
// 输出语义:original 协议返回空字符串表示不覆盖 path;其他协议返回包含必要 Azure path/query 的上游路径。
|
||
// 边界场景:v1 完整路径不应追加空 query,legacy domain-only 模式仍会把 api-version 追加到 capability 映射后的路径。
|
||
func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName ApiName) string {
|
||
// When using original protocol, don't overwrite the path.
|
||
// This ensures basePathHandling works correctly even in TransformRequestBody stage.
|
||
if m.config.IsOriginal() {
|
||
return ""
|
||
}
|
||
|
||
originalPath := util.GetOriginalRequestPath()
|
||
|
||
if m.serviceUrlType == azureServiceUrlTypeFull {
|
||
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
|
||
return m.serviceUrlFullPath
|
||
}
|
||
|
||
log.Debugf("azureProvider: original request path: %s", originalPath)
|
||
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
|
||
log.Debugf("azureProvider: path: %s", path)
|
||
if strings.Contains(path, pathAzureModelPlaceholder) {
|
||
log.Debugf("azureProvider: path contains placeholder: %s", path)
|
||
var model string
|
||
if m.serviceUrlType == azureServiceUrlTypeWithDeployment {
|
||
model = m.defaultModel
|
||
} else {
|
||
model = ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
||
log.Debugf("azureProvider: model from context: %s", model)
|
||
if model == "" {
|
||
model = m.defaultModel
|
||
log.Debugf("azureProvider: use default model: %s", model)
|
||
}
|
||
}
|
||
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
|
||
log.Debugf("azureProvider: model replaced path: %s", path)
|
||
}
|
||
path = appendAzureServiceURLRawQuery(path, m.serviceUrl.RawQuery)
|
||
log.Debugf("azureProvider: final path: %s", path)
|
||
|
||
return path
|
||
}
|
||
|
||
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||
contentType := headers.Get(util.HeaderContentType)
|
||
isMultipartImageRequest := isAzureMultipartImageRequest(apiName, contentType)
|
||
|
||
// We need to overwrite the request path in the request headers stage,
|
||
// because for some APIs, we don't read the request body and the path is model irrelevant.
|
||
if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" {
|
||
util.OverwriteRequestPathHeader(headers, overwrittenPath)
|
||
if isMultipartImageRequest {
|
||
log.Debugf("[azure multipart] header path overwrite: api=%s, path=%s, modelInContext=%s",
|
||
apiName, overwrittenPath, ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||
}
|
||
}
|
||
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
||
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
|
||
headers.Del("Content-Length")
|
||
|
||
supportedAPI := m.config.isSupportedAPI(apiName)
|
||
needProcessBody := m.config.needToProcessRequestBody(apiName)
|
||
if isMultipartImageRequest {
|
||
log.Debugf("[azure multipart] body processing decision: api=%s, supported=%t, needProcessBody=%t",
|
||
apiName, supportedAPI, needProcessBody)
|
||
}
|
||
|
||
if !supportedAPI || !needProcessBody {
|
||
// If the API is not supported or there is no need to process the body,
|
||
// we should not read the request body and keep it as it is.
|
||
ctx.DontReadRequestBody()
|
||
}
|
||
}
|