diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index fe3a539e7..68b0b7efb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -25,6 +25,23 @@ const ( pluginName = "ai-proxy" defaultMaxBodyBytes uint32 = 100 * 1024 * 1024 + + ctxOriginalPath = "original_path" + ctxOriginalHost = "original_host" + ctxOriginalAuth = "original_auth" +) + +var ( + headersCtxKeyMapping = map[string]string{ + util.HeaderAuthority: ctxOriginalHost, + util.HeaderPath: ctxOriginalPath, + util.HeaderAuthorization: ctxOriginalAuth, + } + headerToOriginalHeaderMapping = map[string]string{ + util.HeaderAuthority: util.HeaderOriginalHost, + util.HeaderPath: util.HeaderOriginalPath, + util.HeaderAuthorization: util.HeaderOriginalAuth, + } ) func main() {} @@ -75,6 +92,30 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug return nil } +func initContext(ctx wrapper.HttpContext) { + for header, ctxKey := range headersCtxKeyMapping { + value, _ := proxywasm.GetHttpRequestHeader(header) + ctx.SetContext(ctxKey, value) + } +} + +func saveContextsToHeaders(ctx wrapper.HttpContext) { + for header, ctxKey := range headersCtxKeyMapping { + originalValue := ctx.GetStringContext(ctxKey, "") + if originalValue == "" { + continue + } + currentValue, _ := proxywasm.GetHttpRequestHeader(header) + if currentValue == "" || originalValue == currentValue { + continue + } + originalHeader := headerToOriginalHeaderMapping[header] + if originalHeader != "" { + _ = proxywasm.ReplaceHttpRequestHeader(originalHeader, originalValue) + } + } +} + func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action { activeProvider := pluginConfig.GetProvider() @@ -86,7 +127,14 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType()) + 
initContext(ctx) + rawPath := ctx.Path() + + defer func() { + saveContextsToHeaders(ctx) + }() + path, _ := url.Parse(rawPath) apiName := getApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() @@ -154,6 +202,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig } log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType()) + defer func() { + saveContextsToHeaders(ctx) + }() + if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) providerConfig := pluginConfig.GetProviderConfig() @@ -214,7 +266,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse) - headers := util.GetOriginalResponseHeaders() + headers := util.GetResponseHeaders() if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) handler.TransformResponseHeaders(ctx, apiName, headers) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index e2668c430..22cdbf583 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -5,17 +5,33 @@ import ( "fmt" "net/http" "net/url" + "regexp" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) const ( - pathAzureFiles = "/openai/files" - pathAzureBatches = "/openai/batches" + pathAzurePrefix = "/openai" + pathAzureModelPlaceholder = "{model}" + 
pathAzureWithModelPrefix = "/openai/deployments/" + pathAzureModelPlaceholder + queryAzureApiVersion = "api-version" +) + +var ( + azureModelIrrelevantApis = map[ApiName]bool{ + ApiNameModels: true, + ApiNameBatches: true, + ApiNameRetrieveBatch: true, + ApiNameCancelBatch: true, + ApiNameFiles: true, + ApiNameRetrieveFile: true, + ApiNameRetrieveFileContent: true, + } + regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)") ) // azureProvider is the provider for Azure OpenAI service. @@ -23,21 +39,32 @@ type azureProviderInitializer struct { } func (m *azureProviderInitializer) DefaultCapabilities() map[string]string { - return map[string]string{ - // TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities - string(ApiNameChatCompletion): PathOpenAIChatCompletions, - string(ApiNameEmbeddings): PathOpenAIEmbeddings, - string(ApiNameFiles): PathOpenAIFiles, - string(ApiNameBatches): PathOpenAIBatches, + var capabilities = map[string]string{} + for k, v := range (&openaiProviderInitializer{}).DefaultCapabilities() { + if !strings.HasPrefix(v, PathOpenAIPrefix) { + log.Warnf("azureProviderInitializer: capability %s has an unexpected path %s, skipping", k, v) + continue + } + path := strings.TrimPrefix(v, PathOpenAIPrefix) + if azureModelIrrelevantApis[ApiName(k)] { + path = pathAzurePrefix + path + } else { + path = pathAzureWithModelPrefix + path + } + capabilities[k] = path + log.Debugf("azureProviderInitializer: capability %s -> %s", k, path) } + return capabilities } func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error { if config.azureServiceUrl == "" { return errors.New("missing azureServiceUrl in provider config") } - if _, err := url.Parse(config.azureServiceUrl); err != nil { + if azureServiceUrl, err := url.Parse(config.azureServiceUrl); err != nil { return fmt.Errorf("invalid azureServiceUrl: %w", err) + } else if 
!azureServiceUrl.Query().Has(queryAzureApiVersion) { + return fmt.Errorf("missing %s query parameter in azureServiceUrl: %s", queryAzureApiVersion, config.azureServiceUrl) } if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") @@ -52,10 +79,24 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid } else { serviceUrl = u } + + modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path) + defaultModel := "placeholder" + if modelSubMatch != nil { + defaultModel = modelSubMatch[1] + log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel) + } else { + log.Debugf("azureProvider: no default model found in serviceUrl") + } + config.setDefaultCapabilities(m.DefaultCapabilities()) + apiVersion := serviceUrl.Query().Get(queryAzureApiVersion) + log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion) return &azureProvider{ config: config, serviceUrl: serviceUrl, + apiVersion: apiVersion, + defaultModel: defaultModel, contextCache: createContextCache(&config), }, nil } @@ -65,6 +106,8 @@ type azureProvider struct { contextCache *contextCache serviceUrl *url.URL + apiVersion string + defaultModel string } func (m *azureProvider) GetProviderType() string { @@ -80,44 +123,72 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body) } -func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { - finalRequestUrl := *m.serviceUrl - if u, e := url.Parse(ctx.Path()); e == nil { - if len(u.Query()) != 0 { - q := m.serviceUrl.Query() - for k, v := range u.Query() { - switch len(v) { - case 0: - break - case 1: - q.Set(k, v[0]) - break - default: - delete(q, k) - for _, vv := range v { - q.Add(k, vv) - } - } - } - finalRequestUrl.RawQuery = q.Encode() - } +func (m *azureProvider) 
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (transformedBody []byte, err error) { + transformedBody = body + err = nil - if filesIndex := strings.Index(u.Path, "/files"); filesIndex != -1 { - finalRequestUrl.Path = pathAzureFiles + u.Path[filesIndex+len("/files"):] - } else if batchesIndex := strings.Index(u.Path, "/batches"); batchesIndex != -1 { - finalRequestUrl.Path = pathAzureBatches + u.Path[batchesIndex+len("/batches"):] - } - } else { - log.Errorf("failed to parse request path: %v", e) + transformedBody, err = m.config.defaultTransformRequestBody(ctx, apiName, body) + if err != nil { + return } - util.OverwriteRequestPathHeader(headers, finalRequestUrl.RequestURI()) + // This must be called after the body is transformed, because it uses the model from the context filled by that call. + if path := m.transformRequestPath(ctx, apiName); path != "" { + err = util.OverwriteRequestPath(path) + if err == nil { + log.Debugf("azureProvider: overwrite request path to %s succeeded", path) + } else { + log.Errorf("azureProvider: overwrite request path to %s failed: %v", path, err) + } + } + + return +} + +func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName ApiName) string { + originalPath := util.GetOriginalRequestPath() + + if m.config.IsOriginal() { + return originalPath + } + + log.Debugf("azureProvider: original request path: %s", originalPath) + path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities) + log.Debugf("azureProvider: path: %s", path) + if path == "" { + // No capability is mapped for this API: return empty so callers leave the request path untouched instead of overwriting it with a bare "?api-version=..." query. 
+ return "" + } + if strings.Contains(path, pathAzureModelPlaceholder) { + log.Debugf("azureProvider: path contains placeholder: %s", path) + model := ctx.GetStringContext(ctxKeyFinalRequestModel, "") + log.Debugf("azureProvider: model from context: %s", model) + if model == "" { + model = m.defaultModel + log.Debugf("azureProvider: use default model: %s", model) + } + path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model) +
log.Debugf("azureProvider: model replaced path: %s", path) + } + path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion) + log.Debugf("azureProvider: final path: %s", path) + + return path +} + +func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + // We need to overwrite the request path in the request headers stage, + // because for some APIs, we don't read the request body and the path is model irrelevant. + if overwrittenPath := m.transformRequestPath(ctx, apiName); overwrittenPath != "" { + util.OverwriteRequestPathHeader(headers, overwrittenPath) + } util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) headers.Set("api-key", m.config.GetApiTokenInUse(ctx)) headers.Del("Content-Length") - if !m.config.isSupportedAPI(apiName) { - // If the API is not supported, we should not read the request body and keep it as it is. + if !m.config.isSupportedAPI(apiName) || !m.config.needToProcessRequestBody(apiName) { + // If the API is not supported or there is no need to process the body, + // we should not read the request body and keep it as it is. 
ctx.DontReadRequestBody() } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 8aeb72b0b..a36d9f029 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -67,6 +67,7 @@ const ( ApiNameAnthropicComplete ApiName = "anthropic/v1/complete" // OpenAI + PathOpenAIPrefix = "/v1" PathOpenAICompletions = "/v1/completions" PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIEmbeddings = "/v1/embeddings" @@ -851,7 +852,7 @@ func (c *ProviderConfig) handleRequestBody( if handler, ok := provider.(TransformRequestBodyHandler); ok { body, err = handler.TransformRequestBody(ctx, apiName, body) } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok { - headers := util.GetOriginalRequestHeaders() + headers := util.GetRequestHeaders() body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers) util.ReplaceRequestHeaders(headers) } else { @@ -877,7 +878,7 @@ func (c *ProviderConfig) handleRequestBody( } func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) { - headers := util.GetOriginalRequestHeaders() + headers := util.GetRequestHeaders() originPath := headers.Get(":path") if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix { headers.Set(":path", strings.TrimPrefix(originPath, c.basePath)) @@ -888,9 +889,6 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) { headers.Set(":path", path.Join(c.basePath, headers.Get(":path"))) } - if headers.Get(":path") != originPath { - headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) - } util.ReplaceRequestHeaders(headers) } @@ -908,7 +906,9 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx 
wrapper.HttpContext, ap } model := gjson.GetBytes(body, "model").String() ctx.SetContext(ctxKeyOriginalRequestModel, model) - return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping)) + mappedModel := getMappedModel(model, c.modelMapping) + ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) + return sjson.SetBytes(body, "model", mappedModel) } func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 397a13879..f27060c9e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -15,10 +15,10 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" - "github.com/higress-group/wasm-go/pkg/log" - "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -136,7 +136,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if v.config.IsOriginal() { return types.ActionContinue, nil } - headers := util.GetOriginalRequestHeaders() + headers := util.GetRequestHeaders() body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers) util.ReplaceRequestHeaders(headers) _ = proxywasm.ReplaceHttpRequestBody(body) diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 37fda2ed5..70db42fb0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -10,7 +10,14 @@ import ( ) const ( - HeaderContentType = "Content-Type" + HeaderContentType = "Content-Type" + HeaderPath = ":path" + HeaderAuthority 
= ":authority" + HeaderAuthorization = "Authorization" + + HeaderOriginalPath = "X-ENVOY-ORIGINAL-PATH" + HeaderOriginalHost = "X-ENVOY-ORIGINAL-HOST" + HeaderOriginalAuth = "X-HI-ORIGINAL-AUTH" MimeTypeTextPlain = "text/plain" MimeTypeApplicationJson = "application/json" @@ -48,49 +55,49 @@ func CreateHeaders(kvs ...string) [][2]string { } func OverwriteRequestPath(path string) error { - if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { - _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath) - } - return proxywasm.ReplaceHttpRequestHeader(":path", path) + return proxywasm.ReplaceHttpRequestHeader(HeaderPath, path) } func OverwriteRequestAuthorization(credential string) error { - if exist, _ := proxywasm.GetHttpRequestHeader("X-HI-ORIGINAL-AUTH"); exist == "" { - if originAuth, err := proxywasm.GetHttpRequestHeader("Authorization"); err == nil { - _ = proxywasm.AddHttpRequestHeader("X-HI-ORIGINAL-AUTH", originAuth) + if exist, _ := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth); exist == "" { + if originAuth, err := proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil { + _ = proxywasm.AddHttpRequestHeader(HeaderOriginalAuth, originAuth) } } - return proxywasm.ReplaceHttpRequestHeader("Authorization", credential) + return proxywasm.ReplaceHttpRequestHeader(HeaderAuthorization, credential) } func OverwriteRequestHostHeader(headers http.Header, host string) { - if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { - headers.Set("X-ENVOY-ORIGINAL-HOST", originHost) + if originHost, err := proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil { + headers.Set(HeaderOriginalHost, originHost) } - headers.Set(":authority", host) + headers.Set(HeaderAuthority, host) } func OverwriteRequestPathHeader(headers http.Header, path string) { - if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { - headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) - } -
headers.Set(":path", path) + headers.Set(HeaderPath, path) } func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, mapping map[string]string) { - mappedPath, exist := mapping[apiName] - if !exist { + originPath := GetOriginalRequestPath() + mappedPath := MapRequestPathByCapability(apiName, originPath, mapping) + if mappedPath == "" { return } - originPath, err := proxywasm.GetHttpRequestHeader(":path") - if err == nil { - headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) - } + headers.Set(HeaderPath, mappedPath) + log.Debugf("[OverwriteRequestPath] originPath=%s, mappedPath=%s", originPath, mappedPath) +} + +func MapRequestPathByCapability(apiName string, originPath string, mapping map[string]string) string { /** 这里实现不太优雅,理应通过 apiName 来判断使用哪个正则替换 但 ApiName 定义在 provider 中, 而 provider 中又引用了 util 会导致循环引用 **/ + mappedPath, exist := mapping[apiName] + if !exist { + return "" + } if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") { replacements := []struct { regx *regexp.Regexp @@ -119,17 +126,61 @@ func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, } } } - headers.Set(":path", mappedPath) - log.Debugf("[OverwriteRequestPath] originPath=%s, mappedPath=%s", originPath, mappedPath) + return mappedPath +} + +func GetOriginalRequestPath() string { + path, err := proxywasm.GetHttpRequestHeader(HeaderOriginalPath) + if path != "" && err == nil { + return path + } + if path, err = proxywasm.GetHttpRequestHeader(HeaderPath); err == nil { + return path + } + return "" +} + +func SetOriginalRequestPath(path string) { + _ = proxywasm.ReplaceHttpRequestHeader(HeaderOriginalPath, path) +} + +func GetOriginalRequestHost() string { + host, err := proxywasm.GetHttpRequestHeader(HeaderOriginalHost) + if host != "" && err == nil { + return host + } + if host, err = proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil { + return host + } + return "" +} + +func SetOriginalRequestHost(host string) { + _ = 
proxywasm.ReplaceHttpRequestHeader(HeaderOriginalHost, host) +} + +func GetOriginalRequestAuth() string { + auth, err := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth) + if auth != "" && err == nil { + return auth + } + if auth, err = proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil { + return auth + } + return "" +} + +func SetOriginalRequestAuth(auth string) { + _ = proxywasm.ReplaceHttpRequestHeader(HeaderOriginalAuth, auth) } func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { - if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { - if originAuth := headers.Get("Authorization"); originAuth != "" { - headers.Set("X-HI-ORIGINAL-AUTH", originAuth) + if exist := headers.Get(HeaderOriginalAuth); exist == "" { + if originAuth := headers.Get(HeaderAuthorization); originAuth != "" { + headers.Set(HeaderOriginalAuth, originAuth) } } - headers.Set("Authorization", credential) + headers.Set(HeaderAuthorization, credential) } func HeaderToSlice(header http.Header) [][2]string { @@ -152,22 +203,22 @@ func SliceToHeader(slice [][2]string) http.Header { return header } -func GetOriginalRequestHeaders() http.Header { - originalHeaders, _ := proxywasm.GetHttpRequestHeaders() - return SliceToHeader(originalHeaders) +func GetRequestHeaders() http.Header { + header, _ := proxywasm.GetHttpRequestHeaders() + return SliceToHeader(header) } -func GetOriginalResponseHeaders() http.Header { - originalHeaders, _ := proxywasm.GetHttpResponseHeaders() - return SliceToHeader(originalHeaders) +func GetResponseHeaders() http.Header { + headers, _ := proxywasm.GetHttpResponseHeaders() + return SliceToHeader(headers) } func ReplaceRequestHeaders(headers http.Header) { - modifiedHeaders := HeaderToSlice(headers) - _ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders) + headerSlice := HeaderToSlice(headers) + _ = proxywasm.ReplaceHttpRequestHeaders(headerSlice) } func ReplaceResponseHeaders(headers http.Header) { - modifiedHeaders := 
HeaderToSlice(headers) - _ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders) + headerSlice := HeaderToSlice(headers) + _ = proxywasm.ReplaceHttpResponseHeaders(headerSlice) }