feat: Support model mapping and more URL configuration formats for Azure OpenAI (#2649)

This commit is contained in:
Kent Dong
2025-07-25 11:28:02 +08:00
committed by GitHub
parent ea0bf7c1b7
commit 7348c265b5
5 changed files with 259 additions and 89 deletions

View File

@@ -5,17 +5,33 @@ import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// Request path fragments and query parameter names used when talking to Azure OpenAI.
const (
	// queryAzureApiVersion is the query parameter that carries the Azure API version.
	queryAzureApiVersion = "api-version"

	// pathAzurePrefix is the root under which all Azure OpenAI paths live.
	pathAzurePrefix = "/openai"
	// pathAzureFiles is the Azure path of the files API.
	pathAzureFiles = pathAzurePrefix + "/files"
	// pathAzureBatches is the Azure path of the batches API.
	pathAzureBatches = pathAzurePrefix + "/batches"

	// pathAzureModelPlaceholder marks where the model name is substituted into a path template.
	pathAzureModelPlaceholder = "{model}"
	// pathAzureWithModelPrefix is the path template prefix for APIs addressed per model.
	pathAzureWithModelPrefix = pathAzurePrefix + "/deployments/" + pathAzureModelPlaceholder
)
// azureModelIrrelevantApis is the set of APIs whose Azure path is built from
// pathAzurePrefix alone, i.e. without the per-model path segment.
var azureModelIrrelevantApis = map[ApiName]bool{
	ApiNameModels: true,

	ApiNameFiles:               true,
	ApiNameRetrieveFile:        true,
	ApiNameRetrieveFileContent: true,

	ApiNameBatches:       true,
	ApiNameRetrieveBatch: true,
	ApiNameCancelBatch:   true,
}

// regexAzureModelWithPath matches a path containing "/openai/deployments/<name>".
// Capture group 1 is the name segment, group 2 is whatever follows it (possibly empty).
var regexAzureModelWithPath = regexp.MustCompile(`/openai/deployments/(.+?)(/.*|$)`)
// azureProvider is the provider for Azure OpenAI service.
@@ -23,21 +39,32 @@ type azureProviderInitializer struct {
}
// DefaultCapabilities returns the capability map (API name -> request path template)
// used by the Azure provider.
// NOTE(review): this view shows two implementations back to back — a static map literal
// and a loop deriving the map from the OpenAI defaults. The static literal looks like the
// removed side of the diff; confirm against the committed file.
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
string(ApiNameFiles): PathOpenAIFiles,
string(ApiNameBatches): PathOpenAIBatches,
var capabilities = map[string]string{}
// Walk the OpenAI provider's default capabilities and re-root each path for Azure.
for k, v := range (&openaiProviderInitializer{}).DefaultCapabilities() {
// Only paths starting with PathOpenAIPrefix can be re-rooted; anything else is skipped.
if !strings.HasPrefix(v, PathOpenAIPrefix) {
log.Warnf("azureProviderInitializer: capability %s has an unexpected path %s, skipping", k, v)
continue
}
path := strings.TrimPrefix(v, PathOpenAIPrefix)
// APIs listed in azureModelIrrelevantApis get the plain prefix; all others get the
// prefix that contains the model placeholder.
if azureModelIrrelevantApis[ApiName(k)] {
path = pathAzurePrefix + path
} else {
path = pathAzureWithModelPrefix + path
}
capabilities[k] = path
log.Debugf("azureProviderInitializer: capability %s -> %s", k, path)
}
return capabilities
}
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.azureServiceUrl == "" {
return errors.New("missing azureServiceUrl in provider config")
}
if _, err := url.Parse(config.azureServiceUrl); err != nil {
if azureServiceUrl, err := url.Parse(config.azureServiceUrl); err != nil {
return fmt.Errorf("invalid azureServiceUrl: %w", err)
} else if !azureServiceUrl.Query().Has(queryAzureApiVersion) {
return fmt.Errorf("missing %s query parameter in azureServiceUrl: %s", queryAzureApiVersion, config.azureServiceUrl)
}
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
@@ -52,10 +79,24 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
} else {
serviceUrl = u
}
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
defaultModel := "placeholder"
if modelSubMatch != nil {
defaultModel = modelSubMatch[1]
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
} else {
log.Debugf("azureProvider: no default model found in serviceUrl")
}
config.setDefaultCapabilities(m.DefaultCapabilities())
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion)
return &azureProvider{
config: config,
serviceUrl: serviceUrl,
apiVersion: apiVersion,
defaultModel: defaultModel,
contextCache: createContextCache(&config),
}, nil
}
@@ -65,6 +106,8 @@ type azureProvider struct {
contextCache *contextCache
serviceUrl *url.URL
apiVersion string
defaultModel string
}
func (m *azureProvider) GetProviderType() string {
@@ -80,44 +123,68 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
finalRequestUrl := *m.serviceUrl
if u, e := url.Parse(ctx.Path()); e == nil {
if len(u.Query()) != 0 {
q := m.serviceUrl.Query()
for k, v := range u.Query() {
switch len(v) {
case 0:
break
case 1:
q.Set(k, v[0])
break
default:
delete(q, k)
for _, vv := range v {
q.Add(k, vv)
}
}
}
finalRequestUrl.RawQuery = q.Encode()
}
// TransformRequestBody runs the default request body transformation and then overwrites
// the request path with the result of transformRequestPath.
// It returns the transformed body, or the error from either step.
// NOTE(review): the statements below that reference u, finalRequestUrl, e and headers use
// names that are not declared in this function; they look like removed lines of the previous
// TransformRequestHeaders implementation left in by the diff view — confirm against the
// committed file.
func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) {
transformedBody = body
err = nil
if filesIndex := strings.Index(u.Path, "/files"); filesIndex != -1 {
finalRequestUrl.Path = pathAzureFiles + u.Path[filesIndex+len("/files"):]
} else if batchesIndex := strings.Index(u.Path, "/batches"); batchesIndex != -1 {
finalRequestUrl.Path = pathAzureBatches + u.Path[batchesIndex+len("/batches"):]
}
} else {
log.Errorf("failed to parse request path: %v", e)
transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body)
if err != nil {
return
}
util.OverwriteRequestPathHeader(headers, finalRequestUrl.RequestURI())
// This must be called after the body is transformed, because it uses the model from the context filled by that call.
if path := m.transformRequestPath(ctx, apiName); path != "" {
// An empty path means there is nothing to overwrite.
err = util.OverwriteRequestPath(path)
if err == nil {
log.Debugf("azureProvider: overwrite request path to %s succeeded", path)
} else {
log.Errorf("azureProvider: overwrite request path to %s failed: %v", path, err)
}
}
return
}
// transformRequestPath computes the request path to send to Azure OpenAI.
//
// When the provider is configured for the original protocol (m.config.IsOriginal()),
// the original request path is returned unchanged. Otherwise the path is mapped through
// the provider capabilities, any model placeholder is replaced with the model stored in
// the request context under ctxKeyFinalRequestModel (falling back to m.defaultModel when
// the context holds none), and the api-version query parameter is appended.
//
// An empty string is returned when the capability mapping yields no path; callers compare
// the result against "" to decide whether to overwrite the request path at all.
func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName ApiName) string {
	originalPath := util.GetOriginalRequestPath()
	if m.config.IsOriginal() {
		return originalPath
	}
	log.Debugf("azureProvider: original request path: %s", originalPath)
	path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
	log.Debugf("azureProvider: path: %s", path)
	if path == "" {
		// Without this guard the query string appended below would turn "no mapping" into
		// the bogus non-empty path "?api-version=...", defeating the callers' empty check.
		return ""
	}
	if strings.Contains(path, pathAzureModelPlaceholder) {
		log.Debugf("azureProvider: path contains placeholder: %s", path)
		model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
		log.Debugf("azureProvider: model from context: %s", model)
		if model == "" {
			model = m.defaultModel
			log.Debugf("azureProvider: use default model: %s", model)
		}
		path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
		log.Debugf("azureProvider: model replaced path: %s", path)
	}
	// Pick the separator so the result stays a well-formed URI even when the mapped path
	// already carries a query string.
	var separator string
	switch {
	case !strings.Contains(path, "?"):
		separator = "?"
	case strings.HasSuffix(path, "?") || strings.HasSuffix(path, "&"):
		separator = ""
	default:
		separator = "&"
	}
	// m.apiVersion holds the decoded query value, so it has to be escaped again here.
	path = path + separator + queryAzureApiVersion + "=" + url.QueryEscape(m.apiVersion)
	log.Debugf("azureProvider: final path: %s", path)
	return path
}
// TransformRequestHeaders rewrites the request headers for the Azure upstream: the path
// (via transformRequestPath), the host (from m.serviceUrl) and the api-key header.
// It also tells the context not to read the request body when the body needs no processing.
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
// We need to overwrite the request path in the request headers stage,
// because for some APIs, we don't read the request body and the path is model irrelevant.
if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" {
util.OverwriteRequestPathHeader(headers, overwrittenPath)
}
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
// Send the API token currently in use through the api-key header.
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
// NOTE(review): Content-Length is presumably dropped because the body may be rewritten
// later and its length would no longer match — confirm.
headers.Del("Content-Length")
// NOTE(review): the next two lines look like the removed version of the condition that
// follows them (diff view residue) — confirm against the committed file.
if !m.config.isSupportedAPI(apiName) {
// If the API is not supported, we should not read the request body and keep it as it is.
if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) {
// If the API is not supported or there is no need to process the body,
// we should not read the request body and keep it as it is.
ctx.DontReadRequestBody()
}
}

View File

@@ -67,6 +67,7 @@ const (
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
// OpenAI
PathOpenAIPrefix = "/v1"
PathOpenAICompletions = "/v1/completions"
PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"
@@ -851,7 +852,7 @@ func (c *ProviderConfig) handleRequestBody(
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalRequestHeaders()
headers := util.GetRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
util.ReplaceRequestHeaders(headers)
} else {
@@ -877,7 +878,7 @@ func (c *ProviderConfig) handleRequestBody(
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
headers := util.GetOriginalRequestHeaders()
headers := util.GetRequestHeaders()
originPath := headers.Get(":path")
if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix {
headers.Set(":path", strings.TrimPrefix(originPath, c.basePath))
@@ -888,9 +889,6 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
}
if headers.Get(":path") != originPath {
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
}
util.ReplaceRequestHeaders(headers)
}
@@ -908,7 +906,9 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
}
model := gjson.GetBytes(body, "model").String()
ctx.SetContext(ctxKeyOriginalRequestModel, model)
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping))
mappedModel := getMappedModel(model, c.modelMapping)
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
return sjson.SetBytes(body, "model", mappedModel)
}
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {

View File

@@ -15,10 +15,10 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
@@ -136,7 +136,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if v.config.IsOriginal() {
return types.ActionContinue, nil
}
headers := util.GetOriginalRequestHeaders()
headers := util.GetRequestHeaders()
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)