// File generated by hgctl. Modify as required. // See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 package main import ( "fmt" "net/url" "regexp" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) const ( pluginName = "ai-proxy" defaultMaxBodyBytes uint32 = 100 * 1024 * 1024 ctxOriginalPath = "original_path" ctxOriginalHost = "original_host" ctxOriginalAuth = "original_auth" ) type pair[K, V any] struct { key K value V } var ( headersCtxKeyMapping = map[string]string{ util.HeaderAuthority: ctxOriginalHost, util.HeaderPath: ctxOriginalPath, } headerToOriginalHeaderMapping = map[string]string{ util.HeaderAuthority: util.HeaderOriginalHost, util.HeaderPath: util.HeaderOriginalPath, } pathSuffixToApiName = []pair[string, provider.ApiName]{ // OpenAI style {provider.PathOpenAIChatCompletions, provider.ApiNameChatCompletion}, {provider.PathOpenAICompletions, provider.ApiNameCompletion}, {provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings}, {provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech}, {provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration}, {provider.PathOpenAIImageVariation, provider.ApiNameImageVariation}, {provider.PathOpenAIImageEdit, provider.ApiNameImageEdit}, {provider.PathOpenAIBatches, provider.ApiNameBatches}, {provider.PathOpenAIFiles, provider.ApiNameFiles}, {provider.PathOpenAIModels, provider.ApiNameModels}, {provider.PathOpenAIFineTuningJobs, provider.ApiNameFineTuningJobs}, {provider.PathOpenAIResponses, provider.ApiNameResponses}, {provider.PathOpenAIVideos, provider.ApiNameVideos}, // Anthropic style {provider.PathAnthropicMessages, provider.ApiNameAnthropicMessages}, {provider.PathAnthropicComplete, provider.ApiNameAnthropicComplete}, // Cohere style {provider.PathCohereV1Rerank, provider.ApiNameCohereV1Rerank}, } pathPatternToApiName = []pair[*regexp.Regexp, provider.ApiName]{ // OpenAI style {util.RegRetrieveBatchPath, provider.ApiNameRetrieveBatch}, {util.RegCancelBatchPath, provider.ApiNameCancelBatch}, {util.RegRetrieveFilePath, provider.ApiNameRetrieveFile}, {util.RegRetrieveFileContentPath, provider.ApiNameRetrieveFileContent}, {util.RegRetrieveVideoPath, provider.ApiNameRetrieveVideo}, {util.RegRetrieveVideoContentPath, provider.ApiNameRetrieveVideoContent}, {util.RegVideoRemixPath, provider.ApiNameVideoRemix}, {util.RegRetrieveFineTuningJobPath, provider.ApiNameRetrieveFineTuningJob}, {util.RegRetrieveFineTuningJobEventsPath, provider.ApiNameFineTuningJobEvents}, {util.RegRetrieveFineTuningJobCheckpointsPath, provider.ApiNameFineTuningJobCheckpoints}, {util.RegCancelFineTuningJobPath, provider.ApiNameCancelFineTuningJob}, {util.RegResumeFineTuningJobPath, provider.ApiNameResumeFineTuningJob}, {util.RegPauseFineTuningJobPath, provider.ApiNamePauseFineTuningJob}, {util.RegFineTuningCheckpointPermissionPath, provider.ApiNameFineTuningCheckpointPermissions}, {util.RegDeleteFineTuningCheckpointPermissionPath, provider.ApiNameDeleteFineTuningCheckpointPermission}, // Gemini style {util.RegGeminiGenerateContent, provider.ApiNameGeminiGenerateContent}, {util.RegGeminiStreamGenerateContent, provider.ApiNameGeminiStreamGenerateContent}, } ) func main() {} func init() { wrapper.SetCtx( pluginName, wrapper.ParseOverrideConfig(parseGlobalConfig, parseOverrideRuleConfig), wrapper.ProcessRequestHeaders(onHttpRequestHeader), wrapper.ProcessRequestBody(onHttpRequestBody), wrapper.ProcessResponseHeaders(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBody(onStreamingResponseBody), wrapper.ProcessResponseBody(onHttpResponseBody), wrapper.WithRebuildAfterRequests[config.PluginConfig](1000), wrapper.WithRebuildMaxMemBytes[config.PluginConfig](200*1024*1024), ) } func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig) error { log.Debugf("loading global config: %s", json.String()) pluginConfig.FromJson(json) if err := pluginConfig.Validate(); err != nil { log.Errorf("global rule config is invalid: %v", err) return err } if err := pluginConfig.Complete(); err != nil { log.Errorf("failed to apply global rule config: %v", err) return err } return nil } func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error { log.Debugf("loading override rule config: %s", json.String()) *pluginConfig = global pluginConfig.FromJson(json) if err := pluginConfig.Validate(); err != nil { log.Errorf("overriden rule config is invalid: %v", err) return err } if err := pluginConfig.Complete(); err != nil { log.Errorf("failed to apply overriden rule config: %v", err) return err } return nil } func initContext(ctx wrapper.HttpContext) { for header, ctxKey := range headersCtxKeyMapping { value, _ := proxywasm.GetHttpRequestHeader(header) ctx.SetContext(ctxKey, value) } for _, originHeader := range headerToOriginalHeaderMapping { _ = proxywasm.RemoveHttpRequestHeader(originHeader) } originalAuth, _ := proxywasm.GetHttpRequestHeader(util.HeaderOriginalAuth) if originalAuth == "" { value, _ := proxywasm.GetHttpRequestHeader(util.HeaderAuthorization) ctx.SetContext(ctxOriginalAuth, value) } } func saveContextsToHeaders(ctx wrapper.HttpContext) { for header, ctxKey := range headersCtxKeyMapping { originalValue := ctx.GetStringContext(ctxKey, "") if originalValue == "" { continue } currentValue, _ := proxywasm.GetHttpRequestHeader(header) if currentValue == "" || originalValue == currentValue { continue } originalHeader := headerToOriginalHeaderMapping[header] if originalHeader != "" { _ = proxywasm.ReplaceHttpRequestHeader(originalHeader, originalValue) } } originalValue := ctx.GetStringContext(ctxOriginalAuth, "") if originalValue != "" { _ = proxywasm.ReplaceHttpRequestHeader(util.HeaderOriginalAuth, originalValue) } } func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpRequestHeader] no active provider, skip processing") ctx.DontReadRequestBody() return types.ActionContinue } log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType()) // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() initContext(ctx) rawPath := ctx.Path() defer func() { saveContextsToHeaders(ctx) }() path, _ := url.Parse(rawPath) apiName := getApiName(path.Path) providerConfig := pluginConfig.GetProviderConfig() if providerConfig.IsOriginal() { if handler, ok := activeProvider.(provider.ApiNameHandler); ok { apiName = handler.GetApiName(path.Path) } } else { // Only perform protocol conversion for non-original protocols. // Auto-detect protocol based on request path and handle conversion if needed // If request is Claude format (/v1/messages) but provider doesn't support it natively, // convert to OpenAI format (/v1/chat/completions) if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) { // Provider doesn't support Claude protocol natively, convert to OpenAI format newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1) _ = proxywasm.ReplaceHttpRequestHeader(":path", newPath) // Update apiName to match the new path apiName = provider.ApiNameChatCompletion // Mark that we need to convert response back to Claude format ctx.SetContext("needClaudeResponseConversion", true) log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName) } else if apiName == provider.ApiNameAnthropicMessages { // Provider supports Claude protocol natively, no conversion needed log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName) } } if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) { ctx.DontReadRequestBody() log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType) } if apiName == "" { ctx.DontReadRequestBody() ctx.DontReadResponseBody() log.Warnf("[onHttpRequestHeader] unsupported path: %s, will not process http path and body", path.Path) } ctx.SetContext(provider.CtxKeyApiName, apiName) // Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses, // allowing plugins to inspect or modify the response correctly _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { // Set the apiToken for the current request. providerConfig.SetApiTokenInUse(ctx) // Set available apiTokens of current request in the context, will be used in the retryOnFailure providerConfig.SetAvailableApiTokens(ctx) // save the original request host and path in case they are needed for apiToken health check and retry ctx.SetContext(provider.CtxRequestHost, ctx.Host()) ctx.SetContext(provider.CtxRequestPath, ctx.Path()) err := handler.OnRequestHeaders(ctx, apiName) if err != nil { _ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) return types.ActionContinue } hasRequestBody := ctx.HasRequestBody() if hasRequestBody { _ = proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) // Delay the header processing to allow changing in OnRequestBody return types.HeaderStopIteration } ctx.DontReadRequestBody() return types.ActionContinue } return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpRequestBody] no active provider, skip processing") return types.ActionContinue } log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType()) defer func() { saveContextsToHeaders(ctx) }() if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) providerConfig := pluginConfig.GetProviderConfig() // If retryOnFailure is enabled, save the transformed body to the context in case of retry if providerConfig.IsRetryOnFailureEnabled() { ctx.SetContext(provider.CtxRequestBody, body) } newBody, settingErr := providerConfig.ReplaceByCustomSettings(body) if settingErr != nil { log.Errorf("failed to replace request body by custom settings: %v", settingErr) } // 仅 /v1/chat/completions 和 /v1/completions 接口支持 stream_options 参数 if providerConfig.IsOpenAIProtocol() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) { newBody = normalizeOpenAiRequestBody(newBody) } log.Debugf("[onHttpRequestBody] newBody=%s", newBody) body = newBody action, err := handler.OnRequestBody(ctx, apiName, body) if err == nil { return action } _ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err)) } return types.ActionContinue } func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action { if !wrapper.IsResponseFromUpstream() { // Response is not coming from the upstream. Let it pass through. ctx.DontReadResponseBody() return types.ActionContinue } activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpResponseHeaders] no active provider, skip processing") ctx.DontReadResponseBody() return types.ActionContinue } log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType()) providerConfig := pluginConfig.GetProviderConfig() apiTokenInUse := providerConfig.GetApiTokenInUse(ctx) apiTokens := providerConfig.GetAvailableApiToken(ctx) status, err := proxywasm.GetHttpResponseHeader(":status") if err != nil || status != "200" { if err != nil { log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status) } // Reset ctxApiTokenRequestFailureCount if the request is successful, // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse) headers := util.GetResponseHeaders() if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) handler.TransformResponseHeaders(ctx, apiName, headers) } else { providerConfig.DefaultTransformResponseHeaders(ctx, headers) } util.ReplaceResponseHeaders(headers) _, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler) var needHandleStreamingBody bool _, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler) if !needHandleStreamingBody { _, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler) } // Check if we need to read body for Claude response conversion needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool) if !needHandleBody && !needHandleStreamingBody && !needClaudeConversion { ctx.DontReadResponseBody() } else { checkStream(ctx) } return types.ActionContinue } func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onStreamingResponseBody] no active provider, skip processing") return chunk } log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType()) log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk) if err == nil && modifiedChunk != nil { // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk) if convertErr != nil { return modifiedChunk } return claudeChunk } return chunk } if handler, ok := activeProvider.(provider.StreamingEventHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) events := provider.ExtractStreamingEvents(ctx, chunk) log.Debugf("[onStreamingResponseBody] %d events received", len(events)) if len(events) == 0 { // No events are extracted, return empty bytes slice return []byte("") } var responseBuilder strings.Builder for _, event := range events { log.Debugf("processing event: %v", event) if event.IsEndData() { responseBuilder.WriteString(event.ToHttpString()) continue } outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event) if err != nil { log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk) return chunk } if len(outputEvents) == 0 { // no need convert, keep original events responseBuilder.WriteString(event.RawEvent) } else { for _, outputEvent := range outputEvents { responseBuilder.WriteString(outputEvent.ToHttpString()) } } } result := []byte(responseBuilder.String()) // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result) if convertErr != nil { return result } return claudeChunk } if !needsClaudeResponseConversion(ctx) { return chunk } // If provider doesn't implement any streaming handlers but we need Claude conversion // First extract complete events from the chunk events := provider.ExtractStreamingEvents(ctx, chunk) log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events)) if len(events) == 0 { // No events are extracted, return empty bytes slice return []byte("") } // Build response from extracted events (without handler processing) var responseBuilder strings.Builder for _, event := range events { responseBuilder.WriteString(event.ToHttpString()) } result := []byte(responseBuilder.String()) // Convert to Claude format if needed claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result) if convertErr != nil { return result } return claudeChunk } func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action { activeProvider := pluginConfig.GetProvider() if activeProvider == nil { log.Debugf("[onHttpResponseBody] no active provider, skip processing") return types.ActionContinue } log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType()) var finalBody []byte if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok { apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName) transformedBody, err := handler.TransformResponseBody(ctx, apiName, body) if err != nil { _ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err)) return types.ActionContinue } finalBody = transformedBody } else { finalBody = body } // Convert to Claude format if needed (applies to both branches) convertedBody, err := convertResponseBodyToClaude(ctx, finalBody) if err != nil { _ = util.ErrorHandler("ai-proxy.convert_resp_to_claude_failed", err) return types.ActionContinue } if err = provider.ReplaceResponseBody(convertedBody); err != nil { _ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err)) } return types.ActionContinue } // Helper function to check if Claude response conversion is needed func needsClaudeResponseConversion(ctx wrapper.HttpContext) bool { needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool) return needClaudeConversion } // Helper function to convert OpenAI streaming response to Claude format func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]byte, error) { if !needsClaudeResponseConversion(ctx) { return data, nil } // Get or create converter instance from context to maintain state const claudeConverterKey = "claudeConverter" var converter *provider.ClaudeToOpenAIConverter if converterData := ctx.GetContext(claudeConverterKey); converterData != nil { if c, ok := converterData.(*provider.ClaudeToOpenAIConverter); ok { converter = c } } if converter == nil { converter = &provider.ClaudeToOpenAIConverter{} ctx.SetContext(claudeConverterKey, converter) } claudeChunk, err := converter.ConvertOpenAIStreamResponseToClaude(ctx, data) if err != nil { log.Errorf("failed to convert streaming response to claude format: %v", err) return data, err } return claudeChunk, nil } // Helper function to convert OpenAI response body to Claude format func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) { if !needsClaudeResponseConversion(ctx) { return body, nil } converter := &provider.ClaudeToOpenAIConverter{} convertedBody, err := converter.ConvertOpenAIResponseToClaude(ctx, body) if err != nil { return body, fmt.Errorf("failed to convert response to claude format: %v", err) } return convertedBody, nil } func normalizeOpenAiRequestBody(body []byte) []byte { var err error // Default setting include_usage. if gjson.GetBytes(body, "stream").Bool() && (!gjson.GetBytes(body, "stream_options").Exists() || !gjson.GetBytes(body, "stream_options.include_usage").Exists()) { body, err = sjson.SetBytes(body, "stream_options.include_usage", true) if err != nil { log.Errorf("set include_usage failed, err:%s", err) } } return body } func checkStream(ctx wrapper.HttpContext) { contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { if err != nil { log.Errorf("unable to load content-type header from response: %v", err) } ctx.BufferResponseBody() ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes) } } func getApiName(path string) provider.ApiName { // Check path suffix matches first for _, p := range pathSuffixToApiName { if strings.HasSuffix(path, p.key) { return p.value } } // Check path pattern matches for _, p := range pathPatternToApiName { if p.key.MatchString(path) { return p.value } } return "" }