mirror of
https://github.com/alibaba/higress.git
synced 2026-03-01 23:20:52 +08:00
597 lines
21 KiB
Go
597 lines
21 KiB
Go
// File generated by hgctl. Modify as required.
|
|
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
|
|
|
"github.com/higress-group/wasm-go/pkg/log"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// Plugin-wide constants.
const (
	// pluginName is the name passed to wrapper.SetCtx when registering the plugin.
	pluginName = "ai-proxy"

	// defaultMaxBodyBytes is the buffer limit applied to request and response
	// bodies (100 MiB).
	defaultMaxBodyBytes uint32 = 100 << 20

	// HTTP-context keys under which the inbound path, host (authority) and
	// Authorization header values are remembered before any rewriting.
	ctxOriginalPath = "original_path"
	ctxOriginalHost = "original_host"
	ctxOriginalAuth = "original_auth"
)
|
|
|
|
// pair is a minimal key/value tuple. Lookup tables built from slices of pair
// (rather than maps) are iterated in declaration order, which keeps
// first-match lookups deterministic.
type pair[K any, V any] struct {
	key   K
	value V
}
|
|
|
|
var (
|
|
headersCtxKeyMapping = map[string]string{
|
|
util.HeaderAuthority: ctxOriginalHost,
|
|
util.HeaderPath: ctxOriginalPath,
|
|
}
|
|
headerToOriginalHeaderMapping = map[string]string{
|
|
util.HeaderAuthority: util.HeaderOriginalHost,
|
|
util.HeaderPath: util.HeaderOriginalPath,
|
|
}
|
|
pathSuffixToApiName = []pair[string, provider.ApiName]{
|
|
// OpenAI style
|
|
{provider.PathOpenAIChatCompletions, provider.ApiNameChatCompletion},
|
|
{provider.PathOpenAICompletions, provider.ApiNameCompletion},
|
|
{provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings},
|
|
{provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech},
|
|
{provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration},
|
|
{provider.PathOpenAIImageVariation, provider.ApiNameImageVariation},
|
|
{provider.PathOpenAIImageEdit, provider.ApiNameImageEdit},
|
|
{provider.PathOpenAIBatches, provider.ApiNameBatches},
|
|
{provider.PathOpenAIFiles, provider.ApiNameFiles},
|
|
{provider.PathOpenAIModels, provider.ApiNameModels},
|
|
{provider.PathOpenAIFineTuningJobs, provider.ApiNameFineTuningJobs},
|
|
{provider.PathOpenAIResponses, provider.ApiNameResponses},
|
|
{provider.PathOpenAIVideos, provider.ApiNameVideos},
|
|
// Anthropic style
|
|
{provider.PathAnthropicMessages, provider.ApiNameAnthropicMessages},
|
|
{provider.PathAnthropicComplete, provider.ApiNameAnthropicComplete},
|
|
// Cohere style
|
|
{provider.PathCohereV1Rerank, provider.ApiNameCohereV1Rerank},
|
|
}
|
|
pathPatternToApiName = []pair[*regexp.Regexp, provider.ApiName]{
|
|
// OpenAI style
|
|
{util.RegRetrieveBatchPath, provider.ApiNameRetrieveBatch},
|
|
{util.RegCancelBatchPath, provider.ApiNameCancelBatch},
|
|
{util.RegRetrieveFilePath, provider.ApiNameRetrieveFile},
|
|
{util.RegRetrieveFileContentPath, provider.ApiNameRetrieveFileContent},
|
|
{util.RegRetrieveVideoPath, provider.ApiNameRetrieveVideo},
|
|
{util.RegRetrieveVideoContentPath, provider.ApiNameRetrieveVideoContent},
|
|
{util.RegVideoRemixPath, provider.ApiNameVideoRemix},
|
|
{util.RegRetrieveFineTuningJobPath, provider.ApiNameRetrieveFineTuningJob},
|
|
{util.RegRetrieveFineTuningJobEventsPath, provider.ApiNameFineTuningJobEvents},
|
|
{util.RegRetrieveFineTuningJobCheckpointsPath, provider.ApiNameFineTuningJobCheckpoints},
|
|
{util.RegCancelFineTuningJobPath, provider.ApiNameCancelFineTuningJob},
|
|
{util.RegResumeFineTuningJobPath, provider.ApiNameResumeFineTuningJob},
|
|
{util.RegPauseFineTuningJobPath, provider.ApiNamePauseFineTuningJob},
|
|
{util.RegFineTuningCheckpointPermissionPath, provider.ApiNameFineTuningCheckpointPermissions},
|
|
{util.RegDeleteFineTuningCheckpointPermissionPath, provider.ApiNameDeleteFineTuningCheckpointPermission},
|
|
// Gemini style
|
|
{util.RegGeminiGenerateContent, provider.ApiNameGeminiGenerateContent},
|
|
{util.RegGeminiStreamGenerateContent, provider.ApiNameGeminiStreamGenerateContent},
|
|
}
|
|
)
|
|
|
|
// main is intentionally empty. Go requires a main function in package main,
// but this plugin's callbacks are registered in init() via wrapper.SetCtx.
func main() {}
|
|
|
|
// init registers the ai-proxy plugin with the wasm-go wrapper. It wires up:
//   - config parsing (a global config plus per-rule overrides);
//   - the request/response phase callbacks;
//   - two rebuild options (see the note below).
func init() {
	wrapper.SetCtx(
		pluginName,
		wrapper.ParseOverrideConfig(parseGlobalConfig, parseOverrideRuleConfig),
		wrapper.ProcessRequestHeaders(onHttpRequestHeader),
		wrapper.ProcessRequestBody(onHttpRequestBody),
		wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
		wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
		wrapper.ProcessResponseBody(onHttpResponseBody),
		// NOTE(review): judging by their names, these presumably rebuild the plugin
		// instance after 1000 requests or once memory use exceeds 200 MiB; confirm
		// against the wrapper documentation.
		wrapper.WithRebuildAfterRequests[config.PluginConfig](1000),
		wrapper.WithRebuildMaxMemBytes[config.PluginConfig](200*1024*1024),
	)
}
|
|
|
|
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig) error {
|
|
log.Debugf("loading global config: %s", json.String())
|
|
|
|
pluginConfig.FromJson(json)
|
|
if err := pluginConfig.Validate(); err != nil {
|
|
log.Errorf("global rule config is invalid: %v", err)
|
|
return err
|
|
}
|
|
if err := pluginConfig.Complete(); err != nil {
|
|
log.Errorf("failed to apply global rule config: %v", err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error {
|
|
log.Debugf("loading override rule config: %s", json.String())
|
|
|
|
*pluginConfig = global
|
|
|
|
pluginConfig.FromJson(json)
|
|
if err := pluginConfig.Validate(); err != nil {
|
|
log.Errorf("overriden rule config is invalid: %v", err)
|
|
return err
|
|
}
|
|
if err := pluginConfig.Complete(); err != nil {
|
|
log.Errorf("failed to apply overriden rule config: %v", err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func initContext(ctx wrapper.HttpContext) {
|
|
for header, ctxKey := range headersCtxKeyMapping {
|
|
value, _ := proxywasm.GetHttpRequestHeader(header)
|
|
ctx.SetContext(ctxKey, value)
|
|
}
|
|
for _, originHeader := range headerToOriginalHeaderMapping {
|
|
_ = proxywasm.RemoveHttpRequestHeader(originHeader)
|
|
}
|
|
originalAuth, _ := proxywasm.GetHttpRequestHeader(util.HeaderOriginalAuth)
|
|
if originalAuth == "" {
|
|
value, _ := proxywasm.GetHttpRequestHeader(util.HeaderAuthorization)
|
|
ctx.SetContext(ctxOriginalAuth, value)
|
|
}
|
|
}
|
|
|
|
func saveContextsToHeaders(ctx wrapper.HttpContext) {
|
|
for header, ctxKey := range headersCtxKeyMapping {
|
|
originalValue := ctx.GetStringContext(ctxKey, "")
|
|
if originalValue == "" {
|
|
continue
|
|
}
|
|
currentValue, _ := proxywasm.GetHttpRequestHeader(header)
|
|
if currentValue == "" || originalValue == currentValue {
|
|
continue
|
|
}
|
|
originalHeader := headerToOriginalHeaderMapping[header]
|
|
if originalHeader != "" {
|
|
_ = proxywasm.ReplaceHttpRequestHeader(originalHeader, originalValue)
|
|
}
|
|
}
|
|
originalValue := ctx.GetStringContext(ctxOriginalAuth, "")
|
|
if originalValue != "" {
|
|
_ = proxywasm.ReplaceHttpRequestHeader(util.HeaderOriginalAuth, originalValue)
|
|
}
|
|
}
|
|
|
|
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpRequestHeader] no active provider, skip processing")
|
|
ctx.DontReadRequestBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
|
|
|
|
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
|
ctx.DisableReroute()
|
|
|
|
initContext(ctx)
|
|
|
|
rawPath := ctx.Path()
|
|
|
|
defer func() {
|
|
saveContextsToHeaders(ctx)
|
|
}()
|
|
|
|
path, _ := url.Parse(rawPath)
|
|
apiName := getApiName(path.Path)
|
|
providerConfig := pluginConfig.GetProviderConfig()
|
|
if providerConfig.IsOriginal() {
|
|
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
|
|
apiName = handler.GetApiName(path.Path)
|
|
}
|
|
} else {
|
|
// Only perform protocol conversion for non-original protocols.
|
|
// Auto-detect protocol based on request path and handle conversion if needed
|
|
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
|
|
// convert to OpenAI format (/v1/chat/completions)
|
|
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
|
|
// Provider doesn't support Claude protocol natively, convert to OpenAI format
|
|
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
|
|
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
|
|
// Update apiName to match the new path
|
|
apiName = provider.ApiNameChatCompletion
|
|
// Mark that we need to convert response back to Claude format
|
|
ctx.SetContext("needClaudeResponseConversion", true)
|
|
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
|
|
} else if apiName == provider.ApiNameAnthropicMessages {
|
|
// Provider supports Claude protocol natively, no conversion needed
|
|
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
|
|
}
|
|
}
|
|
|
|
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
|
ctx.DontReadRequestBody()
|
|
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
|
|
}
|
|
|
|
if apiName == "" {
|
|
ctx.DontReadRequestBody()
|
|
ctx.DontReadResponseBody()
|
|
log.Warnf("[onHttpRequestHeader] unsupported path: %s, will not process http path and body", path.Path)
|
|
}
|
|
|
|
ctx.SetContext(provider.CtxKeyApiName, apiName)
|
|
|
|
// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
|
|
// allowing plugins to inspect or modify the response correctly
|
|
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
|
|
|
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
|
// Set the apiToken for the current request.
|
|
providerConfig.SetApiTokenInUse(ctx)
|
|
// Set available apiTokens of current request in the context, will be used in the retryOnFailure
|
|
providerConfig.SetAvailableApiTokens(ctx)
|
|
|
|
// save the original request host and path in case they are needed for apiToken health check and retry
|
|
ctx.SetContext(provider.CtxRequestHost, ctx.Host())
|
|
ctx.SetContext(provider.CtxRequestPath, ctx.Path())
|
|
|
|
err := handler.OnRequestHeaders(ctx, apiName)
|
|
if err != nil {
|
|
_ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
|
return types.ActionContinue
|
|
}
|
|
|
|
hasRequestBody := ctx.HasRequestBody()
|
|
if hasRequestBody {
|
|
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
|
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
|
// Delay the header processing to allow changing in OnRequestBody
|
|
return types.HeaderStopIteration
|
|
}
|
|
ctx.DontReadRequestBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpRequestBody] no active provider, skip processing")
|
|
return types.ActionContinue
|
|
}
|
|
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
|
|
|
defer func() {
|
|
saveContextsToHeaders(ctx)
|
|
}()
|
|
|
|
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
providerConfig := pluginConfig.GetProviderConfig()
|
|
// If retryOnFailure is enabled, save the transformed body to the context in case of retry
|
|
if providerConfig.IsRetryOnFailureEnabled() {
|
|
ctx.SetContext(provider.CtxRequestBody, body)
|
|
}
|
|
newBody, settingErr := providerConfig.ReplaceByCustomSettings(body)
|
|
if settingErr != nil {
|
|
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
|
|
}
|
|
// 仅 /v1/chat/completions 和 /v1/completions 接口支持 stream_options 参数
|
|
if providerConfig.IsOpenAIProtocol() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
|
|
newBody = normalizeOpenAiRequestBody(newBody)
|
|
}
|
|
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
|
body = newBody
|
|
action, err := handler.OnRequestBody(ctx, apiName, body)
|
|
if err == nil {
|
|
return action
|
|
}
|
|
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
|
}
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
|
if !wrapper.IsResponseFromUpstream() {
|
|
// Response is not coming from the upstream. Let it pass through.
|
|
ctx.DontReadResponseBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpResponseHeaders] no active provider, skip processing")
|
|
ctx.DontReadResponseBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
|
|
|
|
providerConfig := pluginConfig.GetProviderConfig()
|
|
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
|
|
apiTokens := providerConfig.GetAvailableApiToken(ctx)
|
|
|
|
status, err := proxywasm.GetHttpResponseHeader(":status")
|
|
if err != nil || status != "200" {
|
|
if err != nil {
|
|
log.Errorf("unable to load :status header from response: %v", err)
|
|
}
|
|
ctx.DontReadResponseBody()
|
|
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status)
|
|
}
|
|
|
|
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
|
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
|
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)
|
|
|
|
headers := util.GetResponseHeaders()
|
|
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
handler.TransformResponseHeaders(ctx, apiName, headers)
|
|
} else {
|
|
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
|
|
}
|
|
util.ReplaceResponseHeaders(headers)
|
|
|
|
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
|
|
var needHandleStreamingBody bool
|
|
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
|
|
if !needHandleStreamingBody {
|
|
_, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler)
|
|
}
|
|
|
|
// Check if we need to read body for Claude response conversion
|
|
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
|
|
|
|
if !needHandleBody && !needHandleStreamingBody && !needClaudeConversion {
|
|
ctx.DontReadResponseBody()
|
|
} else {
|
|
checkStream(ctx)
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onStreamingResponseBody] no active provider, skip processing")
|
|
return chunk
|
|
}
|
|
|
|
log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
|
|
log.Debugf("[onStreamingResponseBody] isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
|
|
|
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
|
|
if err == nil && modifiedChunk != nil {
|
|
// Convert to Claude format if needed
|
|
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk)
|
|
if convertErr != nil {
|
|
return modifiedChunk
|
|
}
|
|
return claudeChunk
|
|
}
|
|
return chunk
|
|
}
|
|
if handler, ok := activeProvider.(provider.StreamingEventHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
events := provider.ExtractStreamingEvents(ctx, chunk)
|
|
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
|
|
if len(events) == 0 {
|
|
// No events are extracted, return empty bytes slice
|
|
return []byte("")
|
|
}
|
|
var responseBuilder strings.Builder
|
|
for _, event := range events {
|
|
log.Debugf("processing event: %v", event)
|
|
|
|
if event.IsEndData() {
|
|
responseBuilder.WriteString(event.ToHttpString())
|
|
continue
|
|
}
|
|
|
|
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event)
|
|
if err != nil {
|
|
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
|
|
return chunk
|
|
}
|
|
if len(outputEvents) == 0 {
|
|
// no need convert, keep original events
|
|
responseBuilder.WriteString(event.RawEvent)
|
|
} else {
|
|
for _, outputEvent := range outputEvents {
|
|
responseBuilder.WriteString(outputEvent.ToHttpString())
|
|
}
|
|
}
|
|
}
|
|
|
|
result := []byte(responseBuilder.String())
|
|
|
|
// Convert to Claude format if needed
|
|
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
|
|
if convertErr != nil {
|
|
return result
|
|
}
|
|
return claudeChunk
|
|
}
|
|
|
|
if !needsClaudeResponseConversion(ctx) {
|
|
return chunk
|
|
}
|
|
|
|
// If provider doesn't implement any streaming handlers but we need Claude conversion
|
|
// First extract complete events from the chunk
|
|
events := provider.ExtractStreamingEvents(ctx, chunk)
|
|
log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events))
|
|
if len(events) == 0 {
|
|
// No events are extracted, return empty bytes slice
|
|
return []byte("")
|
|
}
|
|
|
|
// Build response from extracted events (without handler processing)
|
|
var responseBuilder strings.Builder
|
|
for _, event := range events {
|
|
responseBuilder.WriteString(event.ToHttpString())
|
|
}
|
|
|
|
result := []byte(responseBuilder.String())
|
|
|
|
// Convert to Claude format if needed
|
|
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
|
|
if convertErr != nil {
|
|
return result
|
|
}
|
|
return claudeChunk
|
|
}
|
|
|
|
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpResponseBody] no active provider, skip processing")
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
|
|
|
|
var finalBody []byte
|
|
|
|
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
transformedBody, err := handler.TransformResponseBody(ctx, apiName, body)
|
|
if err != nil {
|
|
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
|
return types.ActionContinue
|
|
}
|
|
finalBody = transformedBody
|
|
} else {
|
|
finalBody = body
|
|
}
|
|
|
|
// Convert to Claude format if needed (applies to both branches)
|
|
convertedBody, err := convertResponseBodyToClaude(ctx, finalBody)
|
|
if err != nil {
|
|
_ = util.ErrorHandler("ai-proxy.convert_resp_to_claude_failed", err)
|
|
return types.ActionContinue
|
|
}
|
|
|
|
if err = provider.ReplaceResponseBody(convertedBody); err != nil {
|
|
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
|
}
|
|
return types.ActionContinue
|
|
}
|
|
|
|
// Helper function to check if Claude response conversion is needed
|
|
func needsClaudeResponseConversion(ctx wrapper.HttpContext) bool {
|
|
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
|
|
return needClaudeConversion
|
|
}
|
|
|
|
// Helper function to convert OpenAI streaming response to Claude format
|
|
func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]byte, error) {
|
|
if !needsClaudeResponseConversion(ctx) {
|
|
return data, nil
|
|
}
|
|
|
|
// Get or create converter instance from context to maintain state
|
|
const claudeConverterKey = "claudeConverter"
|
|
var converter *provider.ClaudeToOpenAIConverter
|
|
|
|
if converterData := ctx.GetContext(claudeConverterKey); converterData != nil {
|
|
if c, ok := converterData.(*provider.ClaudeToOpenAIConverter); ok {
|
|
converter = c
|
|
}
|
|
}
|
|
|
|
if converter == nil {
|
|
converter = &provider.ClaudeToOpenAIConverter{}
|
|
ctx.SetContext(claudeConverterKey, converter)
|
|
}
|
|
|
|
claudeChunk, err := converter.ConvertOpenAIStreamResponseToClaude(ctx, data)
|
|
if err != nil {
|
|
log.Errorf("failed to convert streaming response to claude format: %v", err)
|
|
return data, err
|
|
}
|
|
return claudeChunk, nil
|
|
}
|
|
|
|
// Helper function to convert OpenAI response body to Claude format
|
|
func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
|
if !needsClaudeResponseConversion(ctx) {
|
|
return body, nil
|
|
}
|
|
|
|
converter := &provider.ClaudeToOpenAIConverter{}
|
|
convertedBody, err := converter.ConvertOpenAIResponseToClaude(ctx, body)
|
|
if err != nil {
|
|
return body, fmt.Errorf("failed to convert response to claude format: %v", err)
|
|
}
|
|
return convertedBody, nil
|
|
}
|
|
|
|
func normalizeOpenAiRequestBody(body []byte) []byte {
|
|
var err error
|
|
// Default setting include_usage.
|
|
if gjson.GetBytes(body, "stream").Bool() && (!gjson.GetBytes(body, "stream_options").Exists() || !gjson.GetBytes(body, "stream_options.include_usage").Exists()) {
|
|
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
|
if err != nil {
|
|
log.Errorf("set include_usage failed, err:%s", err)
|
|
}
|
|
}
|
|
return body
|
|
}
|
|
|
|
func checkStream(ctx wrapper.HttpContext) {
|
|
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
|
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
|
if err != nil {
|
|
log.Errorf("unable to load content-type header from response: %v", err)
|
|
}
|
|
ctx.BufferResponseBody()
|
|
ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
|
|
}
|
|
}
|
|
|
|
func getApiName(path string) provider.ApiName {
|
|
// Check path suffix matches first
|
|
for _, p := range pathSuffixToApiName {
|
|
if strings.HasSuffix(path, p.key) {
|
|
return p.value
|
|
}
|
|
}
|
|
|
|
// Check path pattern matches
|
|
for _, p := range pathPatternToApiName {
|
|
if p.key.MatchString(path) {
|
|
return p.value
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|