mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat: Support model mapping and more URL configuration formats for Azure OpenAI (#2649)
This commit is contained in:
@@ -25,6 +25,23 @@ const (
|
|||||||
pluginName = "ai-proxy"
|
pluginName = "ai-proxy"
|
||||||
|
|
||||||
defaultMaxBodyBytes uint32 = 100 * 1024 * 1024
|
defaultMaxBodyBytes uint32 = 100 * 1024 * 1024
|
||||||
|
|
||||||
|
ctxOriginalPath = "original_path"
|
||||||
|
ctxOriginalHost = "original_host"
|
||||||
|
ctxOriginalAuth = "original_auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
headersCtxKeyMapping = map[string]string{
|
||||||
|
util.HeaderAuthority: ctxOriginalHost,
|
||||||
|
util.HeaderPath: ctxOriginalPath,
|
||||||
|
util.HeaderAuthorization: ctxOriginalAuth,
|
||||||
|
}
|
||||||
|
headerToOriginalHeaderMapping = map[string]string{
|
||||||
|
util.HeaderAuthority: util.HeaderOriginalHost,
|
||||||
|
util.HeaderPath: util.HeaderOriginalPath,
|
||||||
|
util.HeaderAuthorization: util.HeaderOriginalAuth,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {}
|
func main() {}
|
||||||
@@ -75,6 +92,30 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initContext(ctx wrapper.HttpContext) {
|
||||||
|
for header, ctxKey := range headersCtxKeyMapping {
|
||||||
|
value, _ := proxywasm.GetHttpRequestHeader(header)
|
||||||
|
ctx.SetContext(ctxKey, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveContextsToHeaders(ctx wrapper.HttpContext) {
|
||||||
|
for header, ctxKey := range headersCtxKeyMapping {
|
||||||
|
originalValue := ctx.GetStringContext(ctxKey, "")
|
||||||
|
if originalValue == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentValue, _ := proxywasm.GetHttpRequestHeader(header)
|
||||||
|
if currentValue == "" || originalValue == currentValue {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
originalHeader := headerToOriginalHeaderMapping[header]
|
||||||
|
if originalHeader != "" {
|
||||||
|
_ = proxywasm.ReplaceHttpRequestHeader(originalHeader, originalValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
||||||
activeProvider := pluginConfig.GetProvider()
|
activeProvider := pluginConfig.GetProvider()
|
||||||
|
|
||||||
@@ -86,7 +127,14 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
|||||||
|
|
||||||
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
|
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
|
||||||
|
|
||||||
|
initContext(ctx)
|
||||||
|
|
||||||
rawPath := ctx.Path()
|
rawPath := ctx.Path()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
saveContextsToHeaders(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
path, _ := url.Parse(rawPath)
|
path, _ := url.Parse(rawPath)
|
||||||
apiName := getApiName(path.Path)
|
apiName := getApiName(path.Path)
|
||||||
providerConfig := pluginConfig.GetProviderConfig()
|
providerConfig := pluginConfig.GetProviderConfig()
|
||||||
@@ -154,6 +202,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
|||||||
}
|
}
|
||||||
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
saveContextsToHeaders(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
||||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
providerConfig := pluginConfig.GetProviderConfig()
|
providerConfig := pluginConfig.GetProviderConfig()
|
||||||
@@ -214,7 +266,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
|||||||
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
||||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)
|
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)
|
||||||
|
|
||||||
headers := util.GetOriginalResponseHeaders()
|
headers := util.GetResponseHeaders()
|
||||||
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
||||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||||
handler.TransformResponseHeaders(ctx, apiName, headers)
|
handler.TransformResponseHeaders(ctx, apiName, headers)
|
||||||
|
|||||||
@@ -5,17 +5,33 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
pathAzureFiles = "/openai/files"
|
pathAzurePrefix = "/openai"
|
||||||
pathAzureBatches = "/openai/batches"
|
pathAzureModelPlaceholder = "{model}"
|
||||||
|
pathAzureWithModelPrefix = "/openai/deployments/" + pathAzureModelPlaceholder
|
||||||
|
queryAzureApiVersion = "api-version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
azureModelIrrelevantApis = map[ApiName]bool{
|
||||||
|
ApiNameModels: true,
|
||||||
|
ApiNameBatches: true,
|
||||||
|
ApiNameRetrieveBatch: true,
|
||||||
|
ApiNameCancelBatch: true,
|
||||||
|
ApiNameFiles: true,
|
||||||
|
ApiNameRetrieveFile: true,
|
||||||
|
ApiNameRetrieveFileContent: true,
|
||||||
|
}
|
||||||
|
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)")
|
||||||
)
|
)
|
||||||
|
|
||||||
// azureProvider is the provider for Azure OpenAI service.
|
// azureProvider is the provider for Azure OpenAI service.
|
||||||
@@ -23,21 +39,32 @@ type azureProviderInitializer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
var capabilities = map[string]string{}
|
||||||
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
|
for k, v := range (&openaiProviderInitializer{}).DefaultCapabilities() {
|
||||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
if !strings.HasPrefix(v, PathOpenAIPrefix) {
|
||||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
log.Warnf("azureProviderInitializer: capability %s has an unexpected path %s, skipping", k, v)
|
||||||
string(ApiNameFiles): PathOpenAIFiles,
|
continue
|
||||||
string(ApiNameBatches): PathOpenAIBatches,
|
}
|
||||||
|
path := strings.TrimPrefix(v, PathOpenAIPrefix)
|
||||||
|
if azureModelIrrelevantApis[ApiName(k)] {
|
||||||
|
path = pathAzurePrefix + path
|
||||||
|
} else {
|
||||||
|
path = pathAzureWithModelPrefix + path
|
||||||
|
}
|
||||||
|
capabilities[k] = path
|
||||||
|
log.Debugf("azureProviderInitializer: capability %s -> %s", k, path)
|
||||||
}
|
}
|
||||||
|
return capabilities
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
if config.azureServiceUrl == "" {
|
if config.azureServiceUrl == "" {
|
||||||
return errors.New("missing azureServiceUrl in provider config")
|
return errors.New("missing azureServiceUrl in provider config")
|
||||||
}
|
}
|
||||||
if _, err := url.Parse(config.azureServiceUrl); err != nil {
|
if azureServiceUrl, err := url.Parse(config.azureServiceUrl); err != nil {
|
||||||
return fmt.Errorf("invalid azureServiceUrl: %w", err)
|
return fmt.Errorf("invalid azureServiceUrl: %w", err)
|
||||||
|
} else if !azureServiceUrl.Query().Has(queryAzureApiVersion) {
|
||||||
|
return fmt.Errorf("missing %s query parameter in azureServiceUrl: %s", queryAzureApiVersion, config.azureServiceUrl)
|
||||||
}
|
}
|
||||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||||
return errors.New("no apiToken found in provider config")
|
return errors.New("no apiToken found in provider config")
|
||||||
@@ -52,10 +79,24 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
|
|||||||
} else {
|
} else {
|
||||||
serviceUrl = u
|
serviceUrl = u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
|
||||||
|
defaultModel := "placeholder"
|
||||||
|
if modelSubMatch != nil {
|
||||||
|
defaultModel = modelSubMatch[1]
|
||||||
|
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
|
||||||
|
} else {
|
||||||
|
log.Debugf("azureProvider: no default model found in serviceUrl")
|
||||||
|
}
|
||||||
|
|
||||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||||
|
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
|
||||||
|
log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion)
|
||||||
return &azureProvider{
|
return &azureProvider{
|
||||||
config: config,
|
config: config,
|
||||||
serviceUrl: serviceUrl,
|
serviceUrl: serviceUrl,
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
defaultModel: defaultModel,
|
||||||
contextCache: createContextCache(&config),
|
contextCache: createContextCache(&config),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -65,6 +106,8 @@ type azureProvider struct {
|
|||||||
|
|
||||||
contextCache *contextCache
|
contextCache *contextCache
|
||||||
serviceUrl *url.URL
|
serviceUrl *url.URL
|
||||||
|
apiVersion string
|
||||||
|
defaultModel string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *azureProvider) GetProviderType() string {
|
func (m *azureProvider) GetProviderType() string {
|
||||||
@@ -80,44 +123,68 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) {
|
||||||
finalRequestUrl := *m.serviceUrl
|
transformedBody = body
|
||||||
if u, e := url.Parse(ctx.Path()); e == nil {
|
err = nil
|
||||||
if len(u.Query()) != 0 {
|
|
||||||
q := m.serviceUrl.Query()
|
|
||||||
for k, v := range u.Query() {
|
|
||||||
switch len(v) {
|
|
||||||
case 0:
|
|
||||||
break
|
|
||||||
case 1:
|
|
||||||
q.Set(k, v[0])
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
delete(q, k)
|
|
||||||
for _, vv := range v {
|
|
||||||
q.Add(k, vv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
finalRequestUrl.RawQuery = q.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
if filesIndex := strings.Index(u.Path, "/files"); filesIndex != -1 {
|
transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||||
finalRequestUrl.Path = pathAzureFiles + u.Path[filesIndex+len("/files"):]
|
if err != nil {
|
||||||
} else if batchesIndex := strings.Index(u.Path, "/batches"); batchesIndex != -1 {
|
return
|
||||||
finalRequestUrl.Path = pathAzureBatches + u.Path[batchesIndex+len("/batches"):]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Errorf("failed to parse request path: %v", e)
|
|
||||||
}
|
}
|
||||||
util.OverwriteRequestPathHeader(headers, finalRequestUrl.RequestURI())
|
|
||||||
|
|
||||||
|
// This must be called after the body is transformed, because it uses the model from the context filled by that call.
|
||||||
|
if path := m.transformRequestPath(ctx, apiName); path != "" {
|
||||||
|
err = util.OverwriteRequestPath(path)
|
||||||
|
if err == nil {
|
||||||
|
log.Debugf("azureProvider: overwrite request path to %s succeeded", path)
|
||||||
|
} else {
|
||||||
|
log.Errorf("azureProvider: overwrite request path to %s failed: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName ApiName) string {
|
||||||
|
originalPath := util.GetOriginalRequestPath()
|
||||||
|
|
||||||
|
if m.config.IsOriginal() {
|
||||||
|
return originalPath
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("azureProvider: original request path: %s", originalPath)
|
||||||
|
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
|
||||||
|
log.Debugf("azureProvider: path: %s", path)
|
||||||
|
if strings.Contains(path, pathAzureModelPlaceholder) {
|
||||||
|
log.Debugf("azureProvider: path contains placeholder: %s", path)
|
||||||
|
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
||||||
|
log.Debugf("azureProvider: model from context: %s", model)
|
||||||
|
if model == "" {
|
||||||
|
model = m.defaultModel
|
||||||
|
log.Debugf("azureProvider: use default model: %s", model)
|
||||||
|
}
|
||||||
|
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
|
||||||
|
log.Debugf("azureProvider: model replaced path: %s", path)
|
||||||
|
}
|
||||||
|
path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion)
|
||||||
|
log.Debugf("azureProvider: final path: %s", path)
|
||||||
|
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||||
|
// We need to overwrite the request path in the request headers stage,
|
||||||
|
// because for some APIs, we don't read the request body and the path is model irrelevant.
|
||||||
|
if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" {
|
||||||
|
util.OverwriteRequestPathHeader(headers, overwrittenPath)
|
||||||
|
}
|
||||||
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
||||||
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
|
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
|
||||||
headers.Del("Content-Length")
|
headers.Del("Content-Length")
|
||||||
|
|
||||||
if !m.config.isSupportedAPI(apiName) {
|
if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) {
|
||||||
// If the API is not supported, we should not read the request body and keep it as it is.
|
// If the API is not supported or there is no need to process the body,
|
||||||
|
// we should not read the request body and keep it as it is.
|
||||||
ctx.DontReadRequestBody()
|
ctx.DontReadRequestBody()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ const (
|
|||||||
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
||||||
|
|
||||||
// OpenAI
|
// OpenAI
|
||||||
|
PathOpenAIPrefix = "/v1"
|
||||||
PathOpenAICompletions = "/v1/completions"
|
PathOpenAICompletions = "/v1/completions"
|
||||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||||
@@ -851,7 +852,7 @@ func (c *ProviderConfig) handleRequestBody(
|
|||||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||||
headers := util.GetOriginalRequestHeaders()
|
headers := util.GetRequestHeaders()
|
||||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||||
util.ReplaceRequestHeaders(headers)
|
util.ReplaceRequestHeaders(headers)
|
||||||
} else {
|
} else {
|
||||||
@@ -877,7 +878,7 @@ func (c *ProviderConfig) handleRequestBody(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
|
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
|
||||||
headers := util.GetOriginalRequestHeaders()
|
headers := util.GetRequestHeaders()
|
||||||
originPath := headers.Get(":path")
|
originPath := headers.Get(":path")
|
||||||
if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix {
|
if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix {
|
||||||
headers.Set(":path", strings.TrimPrefix(originPath, c.basePath))
|
headers.Set(":path", strings.TrimPrefix(originPath, c.basePath))
|
||||||
@@ -888,9 +889,6 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
|
|||||||
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
|
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
|
||||||
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
|
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
|
||||||
}
|
}
|
||||||
if headers.Get(":path") != originPath {
|
|
||||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
|
||||||
}
|
|
||||||
util.ReplaceRequestHeaders(headers)
|
util.ReplaceRequestHeaders(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -908,7 +906,9 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
|
|||||||
}
|
}
|
||||||
model := gjson.GetBytes(body, "model").String()
|
model := gjson.GetBytes(body, "model").String()
|
||||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping))
|
mappedModel := getMappedModel(model, c.modelMapping)
|
||||||
|
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||||
|
return sjson.SetBytes(body, "model", mappedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
|
||||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
if v.config.IsOriginal() {
|
if v.config.IsOriginal() {
|
||||||
return types.ActionContinue, nil
|
return types.ActionContinue, nil
|
||||||
}
|
}
|
||||||
headers := util.GetOriginalRequestHeaders()
|
headers := util.GetRequestHeaders()
|
||||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||||
util.ReplaceRequestHeaders(headers)
|
util.ReplaceRequestHeaders(headers)
|
||||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||||
|
|||||||
@@ -10,7 +10,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HeaderContentType = "Content-Type"
|
HeaderContentType = "Content-Type"
|
||||||
|
HeaderPath = ":path"
|
||||||
|
HeaderAuthority = ":authority"
|
||||||
|
HeaderAuthorization = "Authorization"
|
||||||
|
|
||||||
|
HeaderOriginalPath = "X-ENVOY-ORIGINAL-PATH"
|
||||||
|
HeaderOriginalHost = "X-ENVOY-ORIGINAL-HOST"
|
||||||
|
HeaderOriginalAuth = "X-HI-ORIGINAL-AUTH"
|
||||||
|
|
||||||
MimeTypeTextPlain = "text/plain"
|
MimeTypeTextPlain = "text/plain"
|
||||||
MimeTypeApplicationJson = "application/json"
|
MimeTypeApplicationJson = "application/json"
|
||||||
@@ -48,49 +55,49 @@ func CreateHeaders(kvs ...string) [][2]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OverwriteRequestPath(path string) error {
|
func OverwriteRequestPath(path string) error {
|
||||||
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
|
return proxywasm.ReplaceHttpRequestHeader(HeaderPath, path)
|
||||||
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath)
|
|
||||||
}
|
|
||||||
return proxywasm.ReplaceHttpRequestHeader(":path", path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func OverwriteRequestAuthorization(credential string) error {
|
func OverwriteRequestAuthorization(credential string) error {
|
||||||
if exist, _ := proxywasm.GetHttpRequestHeader("X-HI-ORIGINAL-AUTH"); exist == "" {
|
if exist, _ := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth); exist == "" {
|
||||||
if originAuth, err := proxywasm.GetHttpRequestHeader("Authorization"); err == nil {
|
if originAuth, err := proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil {
|
||||||
_ = proxywasm.AddHttpRequestHeader("X-HI-ORIGINAL-AUTH", originAuth)
|
_ = proxywasm.AddHttpRequestHeader(HeaderOriginalPath, originAuth)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return proxywasm.ReplaceHttpRequestHeader("Authorization", credential)
|
return proxywasm.ReplaceHttpRequestHeader(HeaderAuthorization, credential)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OverwriteRequestHostHeader(headers http.Header, host string) {
|
func OverwriteRequestHostHeader(headers http.Header, host string) {
|
||||||
if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
|
if originHost, err := proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil {
|
||||||
headers.Set("X-ENVOY-ORIGINAL-HOST", originHost)
|
headers.Set(HeaderOriginalHost, originHost)
|
||||||
}
|
}
|
||||||
headers.Set(":authority", host)
|
headers.Set(HeaderAuthority, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OverwriteRequestPathHeader(headers http.Header, path string) {
|
func OverwriteRequestPathHeader(headers http.Header, path string) {
|
||||||
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
|
headers.Set(HeaderPath, path)
|
||||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
|
||||||
}
|
|
||||||
headers.Set(":path", path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) {
|
func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) {
|
||||||
mappedPath, exist := mapping[apiName]
|
originPath := GetOriginalRequestPath()
|
||||||
if !exist {
|
mappedPath := MapRequestPathByCapability(apiName, originPath, mapping)
|
||||||
|
if mappedPath == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
originPath, err := proxywasm.GetHttpRequestHeader(":path")
|
headers.Set(HeaderPath, mappedPath)
|
||||||
if err == nil {
|
log.Debugf("[OverwriteRequestPath] originPath=%s, mappedPath=%s", originPath, mappedPath)
|
||||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
}
|
||||||
}
|
|
||||||
|
func MapRequestPathByCapability(apiName string, originPath string, mapping map[string]string) string {
|
||||||
/**
|
/**
|
||||||
这里实现不太优雅,理应通过 apiName 来判断使用哪个正则替换
|
这里实现不太优雅,理应通过 apiName 来判断使用哪个正则替换
|
||||||
但 ApiName 定义在 provider 中, 而 provider 中又引用了 util
|
但 ApiName 定义在 provider 中, 而 provider 中又引用了 util
|
||||||
会导致循环引用
|
会导致循环引用
|
||||||
**/
|
**/
|
||||||
|
mappedPath, exist := mapping[apiName]
|
||||||
|
if !exist {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") {
|
if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") {
|
||||||
replacements := []struct {
|
replacements := []struct {
|
||||||
regx *regexp.Regexp
|
regx *regexp.Regexp
|
||||||
@@ -119,17 +126,61 @@ func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
headers.Set(":path", mappedPath)
|
return mappedPath
|
||||||
log.Debugf("[OverwriteRequestPath] originPath=%s, mappedPath=%s", originPath, mappedPath)
|
}
|
||||||
|
|
||||||
|
func GetOriginalRequestPath() string {
|
||||||
|
path, err := proxywasm.GetHttpRequestHeader(HeaderOriginalPath)
|
||||||
|
if path != "" && err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if path, err = proxywasm.GetHttpRequestHeader(HeaderPath); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetOriginalRequestPath(path string) {
|
||||||
|
_ = proxywasm.ReplaceHttpRequestHeader(HeaderOriginalPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOriginalRequestHost() string {
|
||||||
|
host, err := proxywasm.GetHttpRequestHeader(HeaderOriginalHost)
|
||||||
|
if host != "" && err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
if host, err = proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetOriginalRequestHost(host string) {
|
||||||
|
_ = proxywasm.ReplaceHttpRequestHeader(HeaderOriginalHost, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOriginalRequestAuth() string {
|
||||||
|
auth, err := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth)
|
||||||
|
if auth != "" && err == nil {
|
||||||
|
return auth
|
||||||
|
}
|
||||||
|
if auth, err = proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil {
|
||||||
|
return auth
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetOriginalRequestAuth(auth string) {
|
||||||
|
_ = proxywasm.ReplaceHttpRequestHeader(HeaderOriginalAuth, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
|
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
|
||||||
if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
|
if exist := headers.Get(HeaderOriginalAuth); exist == "" {
|
||||||
if originAuth := headers.Get("Authorization"); originAuth != "" {
|
if originAuth := headers.Get(HeaderAuthorization); originAuth != "" {
|
||||||
headers.Set("X-HI-ORIGINAL-AUTH", originAuth)
|
headers.Set(HeaderOriginalAuth, originAuth)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
headers.Set("Authorization", credential)
|
headers.Set(HeaderAuthorization, credential)
|
||||||
}
|
}
|
||||||
|
|
||||||
func HeaderToSlice(header http.Header) [][2]string {
|
func HeaderToSlice(header http.Header) [][2]string {
|
||||||
@@ -152,22 +203,22 @@ func SliceToHeader(slice [][2]string) http.Header {
|
|||||||
return header
|
return header
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOriginalRequestHeaders() http.Header {
|
func GetRequestHeaders() http.Header {
|
||||||
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
|
header, _ := proxywasm.GetHttpRequestHeaders()
|
||||||
return SliceToHeader(originalHeaders)
|
return SliceToHeader(header)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOriginalResponseHeaders() http.Header {
|
func GetResponseHeaders() http.Header {
|
||||||
originalHeaders, _ := proxywasm.GetHttpResponseHeaders()
|
headers, _ := proxywasm.GetHttpResponseHeaders()
|
||||||
return SliceToHeader(originalHeaders)
|
return SliceToHeader(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReplaceRequestHeaders(headers http.Header) {
|
func ReplaceRequestHeaders(headers http.Header) {
|
||||||
modifiedHeaders := HeaderToSlice(headers)
|
headerSlice := HeaderToSlice(headers)
|
||||||
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
|
_ = proxywasm.ReplaceHttpRequestHeaders(headerSlice)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReplaceResponseHeaders(headers http.Header) {
|
func ReplaceResponseHeaders(headers http.Header) {
|
||||||
modifiedHeaders := HeaderToSlice(headers)
|
headerSlice := HeaderToSlice(headers)
|
||||||
_ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders)
|
_ = proxywasm.ReplaceHttpResponseHeaders(headerSlice)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user