feat: Support model mapping and more URL configuration formats for Azure OpenAI (#2649)

This commit is contained in:
Kent Dong
2025-07-25 11:28:02 +08:00
committed by GitHub
parent ea0bf7c1b7
commit 7348c265b5
5 changed files with 259 additions and 89 deletions

View File

@@ -25,6 +25,23 @@ const (
pluginName = "ai-proxy"
defaultMaxBodyBytes uint32 = 100 * 1024 * 1024
ctxOriginalPath = "original_path"
ctxOriginalHost = "original_host"
ctxOriginalAuth = "original_auth"
)
var (
headersCtxKeyMapping = map[string]string{
util.HeaderAuthority: ctxOriginalHost,
util.HeaderPath: ctxOriginalPath,
util.HeaderAuthorization: ctxOriginalAuth,
}
headerToOriginalHeaderMapping = map[string]string{
util.HeaderAuthority: util.HeaderOriginalHost,
util.HeaderPath: util.HeaderOriginalPath,
util.HeaderAuthorization: util.HeaderOriginalAuth,
}
)
func main() {}
@@ -75,6 +92,30 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
return nil
}
func initContext(ctx wrapper.HttpContext) {
for header, ctxKey := range headersCtxKeyMapping {
value, _ := proxywasm.GetHttpRequestHeader(header)
ctx.SetContext(ctxKey, value)
}
}
func saveContextsToHeaders(ctx wrapper.HttpContext) {
for header, ctxKey := range headersCtxKeyMapping {
originalValue := ctx.GetStringContext(ctxKey, "")
if originalValue == "" {
continue
}
currentValue, _ := proxywasm.GetHttpRequestHeader(header)
if currentValue == "" || originalValue == currentValue {
continue
}
originalHeader := headerToOriginalHeaderMapping[header]
if originalHeader != "" {
_ = proxywasm.ReplaceHttpRequestHeader(originalHeader, originalValue)
}
}
}
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
activeProvider := pluginConfig.GetProvider()
@@ -86,7 +127,14 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
initContext(ctx)
rawPath := ctx.Path()
defer func() {
saveContextsToHeaders(ctx)
}()
path, _ := url.Parse(rawPath)
apiName := getApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
@@ -154,6 +202,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
}
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
defer func() {
saveContextsToHeaders(ctx)
}()
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
providerConfig := pluginConfig.GetProviderConfig()
@@ -214,7 +266,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)
headers := util.GetOriginalResponseHeaders()
headers := util.GetResponseHeaders()
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
handler.TransformResponseHeaders(ctx, apiName, headers)