Files
higress/plugins/wasm-go/extensions/ai-proxy/provider/provider.go

1107 lines
46 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package provider
import (
"bytes"
"errors"
"fmt"
"math/rand"
"net/http"
"path"
"regexp"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
type (
// ApiName identifies an API capability; the values declared in this file
// follow the "{vendor}/{version}/{apitype}" format.
ApiName string
// Pointcut is a string-based type; it is not used within the visible part
// of this file. NOTE(review): confirm its usage elsewhere in the package.
Pointcut string
// basePathHandling selects how ProviderConfig.basePath is handled; see the
// basePathHandlingRemovePrefix and basePathHandlingPrepend constants.
basePathHandling string
)
const (
// ApiName values follow the format {vendor}/{version}/{apitype},
// i.e. vendor / version / API type.
// OpenAI is currently the de facto standard, but other vendors may define
// their own standards for other tasks, e.g. Cohere's rerank.
ApiNameCompletion ApiName = "openai/v1/completions"
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameImageEdit ApiName = "openai/v1/imageedit"
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
ApiNameFiles ApiName = "openai/v1/files"
ApiNameRetrieveFile ApiName = "openai/v1/retrievefile"
ApiNameRetrieveFileContent ApiName = "openai/v1/retrievefilecontent"
ApiNameBatches ApiName = "openai/v1/batches"
ApiNameRetrieveBatch ApiName = "openai/v1/retrievebatch"
ApiNameCancelBatch ApiName = "openai/v1/cancelbatch"
ApiNameModels ApiName = "openai/v1/models"
ApiNameResponses ApiName = "openai/v1/responses"
ApiNameFineTuningJobs ApiName = "openai/v1/fine-tuningjobs"
ApiNameRetrieveFineTuningJob ApiName = "openai/v1/retrievefine-tuningjob"
ApiNameFineTuningJobEvents ApiName = "openai/v1/fine-tuningjobsevents"
ApiNameFineTuningJobCheckpoints ApiName = "openai/v1/fine-tuningjobcheckpoints"
ApiNameCancelFineTuningJob ApiName = "openai/v1/cancelfine-tuningjob"
ApiNameResumeFineTuningJob ApiName = "openai/v1/resumefine-tuningjob"
ApiNamePauseFineTuningJob ApiName = "openai/v1/pausefine-tuningjob"
ApiNameFineTuningCheckpointPermissions ApiName = "openai/v1/fine-tuningjobcheckpointpermissions"
ApiNameDeleteFineTuningCheckpointPermission ApiName = "openai/v1/deletefine-tuningjobcheckpointpermission"
ApiNameVideos ApiName = "openai/v1/videos"
ApiNameRetrieveVideo ApiName = "openai/v1/retrievevideo"
ApiNameVideoRemix ApiName = "openai/v1/videoremix"
ApiNameRetrieveVideoContent ApiName = "openai/v1/retrievevideocontent"
// TODO: the API names below are non-standard; whether they are supported
// still needs to be confirmed.
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
ApiNameQwenAsyncAIGC ApiName = "qwen/v1/services/aigc"
ApiNameQwenAsyncTask ApiName = "qwen/v1/tasks"
ApiNameQwenV1Rerank ApiName = "qwen/v1/rerank"
ApiNameGeminiGenerateContent ApiName = "gemini/v1beta/generatecontent"
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
ApiNameVertexRaw ApiName = "vertex/raw"
// OpenAI
// URL path templates; segments in braces such as {file_id} are placeholders.
PathOpenAIPrefix = "/v1"
PathOpenAICompletions = "/v1/completions"
PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"
PathOpenAIFiles = "/v1/files"
PathOpenAIRetrieveFile = "/v1/files/{file_id}"
PathOpenAIRetrieveFileContent = "/v1/files/{file_id}/content"
PathOpenAIBatches = "/v1/batches"
PathOpenAIRetrieveBatch = "/v1/batches/{batch_id}"
PathOpenAICancelBatch = "/v1/batches/{batch_id}/cancel"
PathOpenAIModels = "/v1/models"
PathOpenAIImageGeneration = "/v1/images/generations"
PathOpenAIImageEdit = "/v1/images/edits"
PathOpenAIImageVariation = "/v1/images/variations"
PathOpenAIAudioSpeech = "/v1/audio/speech"
PathOpenAIResponses = "/v1/responses"
PathOpenAIFineTuningJobs = "/v1/fine_tuning/jobs"
PathOpenAIRetrieveFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}"
PathOpenAIFineTuningJobEvents = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/events"
PathOpenAIFineTuningJobCheckpoints = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/checkpoints"
PathOpenAICancelFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel"
PathOpenAIResumeFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/resume"
PathOpenAIPauseFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/pause"
PathOpenAIFineTuningCheckpointPermissions = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions"
PathOpenAIFineDeleteTuningCheckpointPermission = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions/{permission_id}"
PathOpenAIVideos = "/v1/videos"
PathOpenAIRetrieveVideo = "/v1/videos/{video_id}"
PathOpenAIVideoRemix = "/v1/videos/{video_id}/remix"
PathOpenAIRetrieveVideoContent = "/v1/videos/{video_id}/content"
// Anthropic
PathAnthropicMessages = "/v1/messages"
PathAnthropicComplete = "/v1/complete"
// Cohere
PathCohereV1Rerank = "/v1/rerank"
// Provider type identifiers; these are the keys of providerInitializers and
// the accepted values of the "type" config field.
providerTypeMoonshot = "moonshot"
providerTypeAzure = "azure"
providerTypeAi360 = "ai360"
providerTypeGithub = "github"
providerTypeQwen = "qwen"
providerTypeOpenAI = "openai"
providerTypeGroq = "groq"
providerTypeGrok = "grok"
providerTypeBaichuan = "baichuan"
providerTypeYi = "yi"
providerTypeDeepSeek = "deepseek"
providerTypeZhipuAi = "zhipuai"
providerTypeOllama = "ollama"
providerTypeClaude = "claude"
providerTypeBaidu = "baidu"
providerTypeHunyuan = "hunyuan"
providerTypeStepfun = "stepfun"
providerTypeMinimax = "minimax"
providerTypeCloudflare = "cloudflare"
providerTypeSpark = "spark"
providerTypeGemini = "gemini"
providerTypeDeepl = "deepl"
providerTypeMistral = "mistral"
providerTypeCohere = "cohere"
providerTypeDoubao = "doubao"
providerTypeCoze = "coze"
providerTypeTogetherAI = "together-ai"
providerTypeDify = "dify"
providerTypeBedrock = "bedrock"
providerTypeVertex = "vertex"
providerTypeTriton = "triton"
providerTypeOpenRouter = "openrouter"
providerTypeLongcat = "longcat"
providerTypeFireworks = "fireworks"
providerTypeVllm = "vllm"
providerTypeGeneric = "generic"
// Accepted values of ProviderConfig.protocol (checked in Validate).
protocolOpenAI = "openai"
protocolOriginal = "original"
// Role names (presumably chat message roles — confirm against the request models).
roleSystem = "system"
roleAssistant = "assistant"
roleUser = "user"
roleTool = "tool"
// Finish reason values (presumably for completion choices — confirm against the response models).
finishReasonStop = "stop"
finishReasonLength = "length"
finishReasonToolCall = "tool_calls"
// Keys used with wrapper.HttpContext GetContext/SetContext to keep
// per-request state.
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiKey = "apiKey"
CtxKeyApiName = "apiName"
ctxKeyIsStreaming = "isStreaming"
ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel"
ctxKeyPushedMessage = "pushedMessage"
ctxKeyContentPushed = "contentPushed"
ctxKeyReasoningContentPushed = "reasoningContentPushed"
objectChatCompletion = "chat.completion"
objectChatCompletionChunk = "chat.completion.chunk"
// Accepted values of ProviderConfig.reasoningContentMode; FromJson falls
// back to passthrough for an empty or unrecognized value.
reasoningBehaviorPassThrough = "passthrough"
reasoningBehaviorIgnore = "ignore"
reasoningBehaviorConcat = "concat"
// wildcard is the catch-all key, and the prefix-rule suffix, in modelMapping.
wildcard = "*"
// defaultTimeout is applied by FromJson when "timeout" is absent or zero.
defaultTimeout = 2 * 60 * 1000 // ms
// Accepted values of ProviderConfig.basePathHandling; removePrefix is the
// default when basePath is set without an explicit handling mode.
basePathHandlingRemovePrefix basePathHandling = "removePrefix"
basePathHandlingPrepend basePathHandling = "prepend"
)
// providerInitializer is the per-provider-type factory registered in
// providerInitializers.
type providerInitializer interface {
// ValidateConfig checks the provider-specific parts of the config; it is
// called last by ProviderConfig.Validate.
ValidateConfig(*ProviderConfig) error
// CreateProvider builds the Provider for the given config; it is called
// by the package-level CreateProvider.
CreateProvider(ProviderConfig) (Provider, error)
}
var (
// errUnsupportedApiName is a sentinel error; it is not referenced in the
// visible part of this file (presumably returned by provider
// implementations — confirm).
errUnsupportedApiName = errors.New("unsupported API name")
// providerInitializers maps each supported provider type string to its
// initializer. Validate and CreateProvider reject types missing from it.
providerInitializers = map[string]providerInitializer{
providerTypeMoonshot: &moonshotProviderInitializer{},
providerTypeAzure: &azureProviderInitializer{},
providerTypeAi360: &ai360ProviderInitializer{},
providerTypeGithub: &githubProviderInitializer{},
providerTypeQwen: &qwenProviderInitializer{},
providerTypeOpenAI: &openaiProviderInitializer{},
providerTypeGroq: &groqProviderInitializer{},
providerTypeGrok: &grokProviderInitializer{},
providerTypeBaichuan: &baichuanProviderInitializer{},
providerTypeYi: &yiProviderInitializer{},
providerTypeDeepSeek: &deepseekProviderInitializer{},
providerTypeZhipuAi: &zhipuAiProviderInitializer{},
providerTypeOllama: &ollamaProviderInitializer{},
providerTypeClaude: &claudeProviderInitializer{},
providerTypeBaidu: &baiduProviderInitializer{},
providerTypeHunyuan: &hunyuanProviderInitializer{},
providerTypeStepfun: &stepfunProviderInitializer{},
providerTypeMinimax: &minimaxProviderInitializer{},
providerTypeCloudflare: &cloudflareProviderInitializer{},
providerTypeSpark: &sparkProviderInitializer{},
providerTypeGemini: &geminiProviderInitializer{},
providerTypeDeepl: &deeplProviderInitializer{},
providerTypeMistral: &mistralProviderInitializer{},
providerTypeCohere: &cohereProviderInitializer{},
providerTypeDoubao: &doubaoProviderInitializer{},
providerTypeCoze: &cozeProviderInitializer{},
providerTypeTogetherAI: &togetherAIProviderInitializer{},
providerTypeDify: &difyProviderInitializer{},
providerTypeBedrock: &bedrockProviderInitializer{},
providerTypeVertex: &vertexProviderInitializer{},
providerTypeTriton: &tritonProviderInitializer{},
providerTypeOpenRouter: &openrouterProviderInitializer{},
providerTypeLongcat: &longcatProviderInitializer{},
providerTypeFireworks: &fireworksProviderInitializer{},
providerTypeVllm: &vllmProviderInitializer{},
providerTypeGeneric: &genericProviderInitializer{},
}
)
// Provider is the interface returned by CreateProvider. It only requires the
// provider type; the other *Handler interfaces in this file are separate
// (presumably optional capabilities checked by the caller — confirm).
type Provider interface {
// GetProviderType returns the provider type as a string (presumably one
// of the providerType* constants — confirm per implementation).
GetProviderType() string
}
// RequestHeadersHandler handles the request-headers phase for a given API.
type RequestHeadersHandler interface {
// OnRequestHeaders receives the HTTP context and the API name of the
// request and reports any failure as an error.
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error
}
// RequestBodyHandler handles the request-body phase for a given API.
type RequestBodyHandler interface {
// OnRequestBody receives the request body bytes and returns the
// proxy-wasm action to take, plus any error.
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error)
}
// StreamingResponseBodyHandler processes a streaming response chunk by chunk.
type StreamingResponseBodyHandler interface {
// OnStreamingResponseBody receives one response chunk, with isLastChunk
// flagging the final one, and returns bytes (presumably the chunk to
// forward downstream — confirm at the call site) or an error.
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error)
}
// StreamingEventHandler processes a streaming response one parsed event at a
// time.
type StreamingEventHandler interface {
// OnStreamingEvent receives a single StreamEvent and returns zero or more
// StreamEvents (presumably the events to emit in its place — confirm at
// the call site) or an error.
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error)
}
// ApiNameHandler resolves a request path to an ApiName.
type ApiNameHandler interface {
// GetApiName returns the ApiName for the given path. NOTE(review): the
// value returned for an unrecognized path is implementation-defined —
// confirm (likely the empty ApiName).
GetApiName(path string) ApiName
}
// TransformRequestHeadersHandler rewrites request headers.
type TransformRequestHeadersHandler interface {
// TransformRequestHeaders receives the request headers as an http.Header.
// It returns nothing, so changes are presumably applied by mutating
// headers in place — confirm at the call site.
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
}
// TransformRequestBodyHandler rewrites the request body.
type TransformRequestBodyHandler interface {
// TransformRequestBody receives the request body and returns the
// transformed body, or an error.
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
}
// TransformRequestBodyHeadersHandler allows transforming request headers based on the request body.
// Some providers (e.g. gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
// TransformRequestBodyHeaders receives both the request body and the
// request headers and returns the transformed body, or an error. Header
// changes are presumably made by mutating headers in place — confirm.
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error)
}
// TransformResponseHeadersHandler rewrites response headers.
type TransformResponseHeadersHandler interface {
// TransformResponseHeaders receives the response headers as an
// http.Header. It returns nothing, so changes are presumably applied by
// mutating headers in place — confirm at the call site.
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
}
// TransformResponseBodyHandler rewrites the response body.
type TransformResponseBodyHandler interface {
// TransformResponseBody receives the response body and returns the
// transformed body, or an error.
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
}
// ProviderConfig holds the configuration of one AI service provider. All
// fields are unexported and are populated from JSON by FromJson; the struct
// tags document the config keys.
type ProviderConfig struct {
// @Title en-US ID
// @Description en-US Identifier of the AI service provider.
id string `required:"true" yaml:"id" json:"id"`
// @Title en-US Type
// @Description en-US Type of the AI service provider.
typ string `required:"true" yaml:"type" json:"type"`
// @Title en-US API Tokens
// @Description en-US List of API tokens used for authentication when calling the AI service. Different providers may use different names for it. Some providers, such as Azure OpenAI, support only a single API token.
// NOTE(review): the yaml tag is "apiToken" while the json tag and FromJson use "apiTokens" — confirm which is intended.
apiTokens []string `required:"false" yaml:"apiToken" json:"apiTokens"`
// @Title en-US Request timeout
// @Description en-US Timeout for requests to the AI service, in milliseconds. Defaults to 120000 (2 minutes). This setting is currently only used when fetching context information and does not affect the forwarding of the actual LLM request.
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title en-US apiToken failover
// @Description en-US When an apiToken becomes unavailable it is removed from the apiTokens list and health-checked; once it is available again it is added back to the list.
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title en-US Retry failed requests
// @Description en-US Immediately retry requests that failed.
retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"`
// @Title en-US Reasoning content handling
// @Description en-US How to handle the reasoning content returned by the LLM service. Supported values: passthrough (output the reasoning content as is), ignore (do not output it), concat (prepend it to the regular output content). Defaults to passthrough. Only supported for the Qwen service.
reasoningContentMode string `required:"false" yaml:"reasoningContentMode" json:"reasoningContentMode"`
// @Title en-US Custom backend URL based on the OpenAI protocol
// @Description en-US Only applies to services that support the openai protocol.
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
// @Title en-US Moonshot File ID
// @Description en-US Moonshot AI only. ID of a file in the Moonshot AI service whose content is used to supplement the AI request context.
moonshotFileId string `required:"false" yaml:"moonshotFileId" json:"moonshotFileId"`
// @Title en-US Azure OpenAI Service URL
// @Description en-US Azure OpenAI only. Full URL of the OpenAI service to call, including parameters such as api-version.
azureServiceUrl string `required:"false" yaml:"azureServiceUrl" json:"azureServiceUrl"`
// @Title en-US Qwen File IDs
// @Description en-US Qwen only. IDs of files uploaded to Dashscope whose content is used to supplement the AI request context. Only the qwen-long model is supported.
qwenFileIds []string `required:"false" yaml:"qwenFileIds" json:"qwenFileIds"`
// @Title en-US Enable Qwen search
// @Description en-US Qwen only. Whether to enable Qwen's internet search feature.
qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
// @Title en-US Qwen service domain
// @Description en-US Qwen only. The default forwarding domain is dashscope.aliyuncs.com; when using the finance cloud service it can be set to dashscope-finance.aliyuncs.com
qwenDomain string `required:"false" yaml:"qwenDomain" json:"qwenDomain"`
// @Title en-US Enable Qwen compatible mode
// @Description en-US When enabled, Qwen's compatible-mode API is called and requests/responses are left unmodified.
qwenEnableCompatible bool `required:"false" yaml:"qwenEnableCompatible" json:"qwenEnableCompatible"`
// @Title en-US Ollama Server IP/Domain
// @Description en-US Ollama only. Host address of the Ollama server.
ollamaServerHost string `required:"false" yaml:"ollamaServerHost" json:"ollamaServerHost"`
// @Title en-US Ollama Server Port
// @Description en-US Ollama only. Port number of the Ollama server.
ollamaServerPort uint32 `required:"false" yaml:"ollamaServerPort" json:"ollamaServerPort"`
// @Title en-US hunyuan api key for authorization
// @Description en-US Hunyuan AI only. API key/id used for authentication. See https://cloud.tencent.com/document/api/1729/101843#Golang
hunyuanAuthKey string `required:"false" yaml:"hunyuanAuthKey" json:"hunyuanAuthKey"`
// @Title en-US hunyuan api id for authorization
// @Description en-US Hunyuan AI only. Used for authentication.
hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"`
// @Title en-US Amazon Bedrock AccessKey for authorization
// @Description en-US Amazon Bedrock only. API key/id used for authentication. See https://docs.aws.amazon.com/zh_cn/IAM/latest/UserGuide/reference_sigv.html
awsAccessKey string `required:"false" yaml:"awsAccessKey" json:"awsAccessKey"`
// @Title en-US Amazon Bedrock SecretKey for authorization
// @Description en-US Amazon Bedrock only. Used for authentication.
awsSecretKey string `required:"false" yaml:"awsSecretKey" json:"awsSecretKey"`
// @Title en-US Amazon Bedrock Region
// @Description en-US Amazon Bedrock only. Used when accessing the service.
awsRegion string `required:"false" yaml:"awsRegion" json:"awsRegion"`
// @Title en-US Amazon Bedrock additional model request fields
// @Description en-US Amazon Bedrock only. Used to set model-specific inference parameters.
bedrockAdditionalFields map[string]interface{} `required:"false" yaml:"bedrockAdditionalFields" json:"bedrockAdditionalFields"`
// @Title en-US minimax API type
// @Description en-US MiniMax only. The MiniMax API type: either v2 or pro; defaults to v2.
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
// @Title en-US minimax group id
// @Description en-US MiniMax only. Required when the MiniMax API type is pro.
minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"`
// @Title en-US Model name mapping
// @Description en-US Maps model names in requests to model names supported by the target AI provider. A global mapping can be configured with "*".
modelMapping map[string]string `required:"false" yaml:"modelMapping" json:"modelMapping"`
// @Title en-US External API protocol
// @Description en-US The AI service API protocol exposed by this plugin. Defaults to "openai", i.e. the OpenAI API protocol. Set it to "original" to keep the provider's native protocol.
protocol string `required:"false" yaml:"protocol" json:"protocol"`
// @Title en-US Model conversation context
// @Description en-US Configures an external file source from which conversation context is fetched, used to supplement the context of AI requests.
context *ContextConfig `required:"false" yaml:"context" json:"context"`
// @Title en-US Version
// @Description en-US Version of the AI service to call. Currently only applies to the Gemini and Claude AI services.
apiVersion string `required:"false" yaml:"apiVersion" json:"apiVersion"`
// @Title en-US Cloudflare Account ID
// @Description en-US Cloudflare Workers AI only. See https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
// @Title en-US Gemini content filtering and safety level settings
// @Description en-US Gemini AI only. See https://ai.google.dev/gemini-api/docs/safety-settings
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
// @Title en-US Gemini Thinking Budget setting
// @Description en-US Gemini AI only. Controls the thinking budget.
geminiThinkingBudget int64 `required:"false" yaml:"geminiThinkingBudget" json:"geminiThinkingBudget"`
// @Title en-US Vertex AI region
// @Description en-US Vertex AI only. For the full list of supported regions see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
vertexRegion string `required:"false" yaml:"vertexRegion" json:"vertexRegion"`
// @Title en-US Vertex AI project ID
// @Description en-US Vertex AI only. For creating and managing projects see https://cloud.google.com/resource-manager/docs/creating-managing-projects?hl=zh-cn#identifiers
vertexProjectId string `required:"false" yaml:"vertexProjectId" json:"vertexProjectId"`
// @Title en-US Vertex authentication key
// @Description en-US Full content of the JSON key file used for Google service account authentication. To obtain one see https://cloud.google.com/iam/docs/keys-create-delete?hl=zh-cn#iam-service-account-keys-create-console
vertexAuthKey string `required:"false" yaml:"vertexAuthKey" json:"vertexAuthKey"`
// @Title en-US Vertex authentication service name
// @Description en-US The service used for Google service account authentication; a DNS-type service name.
vertexAuthServiceName string `required:"false" yaml:"vertexAuthServiceName" json:"vertexAuthServiceName"`
// @Title en-US Vertex token refresh lead time
// @Description en-US For Google service account authentication: how long before expiry the access token is considered expired and refreshed, in seconds. Defaults to 60 seconds.
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
// @Title en-US Vertex AI OpenAI-compatible mode
// @Description en-US When enabled, Vertex AI's OpenAI-compatible API is used; requests and responses both use the OpenAI format and no protocol conversion is needed. Mutually exclusive with Express Mode (apiTokens).
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
// @Title en-US Target language for the translation service
// @Description en-US Language of the translation result. Currently only applies to the DeepL service.
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
// @Title en-US JSON Schema that the service response must satisfy
// @Description en-US Currently only applies to some OpenAI models. See https://platform.openai.com/docs/guides/structured-outputs
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
// @Title en-US Custom authentication header name
// @Description en-US Name of a custom header from which the authentication token is extracted from the request. If not configured, the x-api-key, x-authorization, anthropic-api-key and Authorization headers are checked in the default priority order.
authHeaderKey string `required:"false" yaml:"authHeaderKey" json:"authHeaderKey"`
// @Title en-US Custom LLM parameter settings
// @Description en-US Used to fill in or override parameters of LLM calls.
customSettings []CustomSetting
// @Title en-US URL of a privately deployed Dify
difyApiUrl string `required:"false" yaml:"difyApiUrl" json:"difyApiUrl"`
// @Title en-US Dify application type: Chat/Completion/Agent/Workflow
botType string `required:"false" yaml:"botType" json:"botType"`
// @Title en-US Input variable that must be set when the Dify application type is workflow; used together with botType workflow
inputVariable string `required:"false" yaml:"inputVariable" json:"inputVariable"`
// @Title en-US Output variable that must be set when the Dify application type is workflow; used together with botType workflow
outputVariable string `required:"false" yaml:"outputVariable" json:"outputVariable"`
// @Title en-US Additionally supported AI capabilities
// @Description en-US Mapping from exposed AI capabilities to URL paths, e.g. {"openai/v1/chatcompletions": "/v1/chat/completions"}
capabilities map[string]string
// @Title en-US If basePath is configured, it is either removed from the request path as a prefix or prepended to the request path; the default is to remove it
basePath string `required:"false" yaml:"basePath" json:"basePath"`
// @Title en-US basePathHandling specifies how basePath is handled; possible values: removePrefix, prepend
basePathHandling basePathHandling `required:"false" yaml:"basePathHandling" json:"basePathHandling"`
// @Title en-US Host for the generic provider
// @Description en-US Generic provider only. Overrides the target Host that requests are forwarded to.
genericHost string `required:"false" yaml:"genericHost" json:"genericHost"`
// @Title en-US Context cleanup commands
// @Description en-US List of cleanup command texts. When the request's messages contain a user message that exactly matches any of the commands, that message and all preceding non-system messages are removed, which actively clears the context.
contextCleanupCommands []string `required:"false" yaml:"contextCleanupCommands" json:"contextCleanupCommands"`
// @Title en-US First-byte timeout
// @Description en-US In streaming requests, the timeout for receiving the first response packet from the upstream service, in milliseconds. The default value 0 disables the first-byte timeout.
firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"`
// @Title en-US Triton Model Version
// @Description en-US NVIDIA Triton Inference Server only. The modelVersion used in :path. See "https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/protocol/extension_generate.html"
tritonModelVersion string `required:"false" yaml:"tritonModelVersion" json:"tritonModelVersion"`
// @Title en-US Domain where the Triton Server is deployed
// @Description en-US NVIDIA Triton Inference Server only. See "https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/protocol/extension_generate.html"
// NOTE(review): the original description was a copy of the tritonModelVersion one (it described modelVersion in :path) — confirm the intended wording for the domain.
tritonDomain string `required:"false" yaml:"tritonDomain" json:"tritonDomain"`
// @Title en-US vLLM custom backend URL
// @Description en-US vLLM only. Full URL of the vLLM service, including scheme, domain, port, etc.
vllmCustomUrl string `required:"false" yaml:"vllmCustomUrl" json:"vllmCustomUrl"`
// @Title en-US vLLM host address
// @Description en-US vLLM only. Host address of the vLLM server, e.g. vllm-service.cluster.local
vllmServerHost string `required:"false" yaml:"vllmServerHost" json:"vllmServerHost"`
// @Title en-US Doubao service domain
// @Description en-US Doubao only. The default forwarding domain is ark.cn-beijing.volces.com
doubaoDomain string `required:"false" yaml:"doubaoDomain" json:"doubaoDomain"`
}
// GetId returns the configured provider identifier.
func (c *ProviderConfig) GetId() string {
return c.id
}
// GetType returns the configured provider type string.
func (c *ProviderConfig) GetType() string {
return c.typ
}
// GetProtocol returns the configured protocol. After FromJson this is never
// empty, because an empty value is replaced by protocolOpenAI.
func (c *ProviderConfig) GetProtocol() string {
return c.protocol
}
// GetVllmCustomUrl returns the configured vLLM custom backend URL.
func (c *ProviderConfig) GetVllmCustomUrl() string {
return c.vllmCustomUrl
}
// GetVllmServerHost returns the configured vLLM server host.
func (c *ProviderConfig) GetVllmServerHost() string {
return c.vllmServerHost
}
// GetContextCleanupCommands returns the configured context cleanup commands.
// The internal slice is returned directly, not a copy.
func (c *ProviderConfig) GetContextCleanupCommands() []string {
return c.contextCleanupCommands
}
// IsOpenAIProtocol reports whether the configured protocol is protocolOpenAI.
func (c *ProviderConfig) IsOpenAIProtocol() bool {
return c.protocol == protocolOpenAI
}
// FromJson populates c from the given gjson value, applying defaults for
// missing settings (timeout, protocol, qwenEnableCompatible,
// vertexTokenRefreshAhead, reasoningContentMode, basePathHandling).
// It performs no validation and returns no error; call Validate afterwards.
//
// NOTE(review): authHeaderKey is declared on ProviderConfig but is not read
// here — confirm whether it is populated elsewhere or this is an omission.
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.id = json.Get("id").String()
c.typ = json.Get("type").String()
c.apiTokens = make([]string, 0)
for _, token := range json.Get("apiTokens").Array() {
c.apiTokens = append(c.apiTokens, token.String())
}
c.timeout = uint32(json.Get("timeout").Uint())
if c.timeout == 0 {
// A missing or zero timeout falls back to the default.
c.timeout = defaultTimeout
}
// first byte timeout
c.firstByteTimeout = uint32(json.Get("firstByteTimeout").Uint())
c.openaiCustomUrl = json.Get("openaiCustomUrl").String()
c.moonshotFileId = json.Get("moonshotFileId").String()
c.azureServiceUrl = json.Get("azureServiceUrl").String()
c.qwenFileIds = make([]string, 0)
for _, fileId := range json.Get("qwenFileIds").Array() {
c.qwenFileIds = append(c.qwenFileIds, fileId.String())
}
c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
if compatible := json.Get("qwenEnableCompatible"); compatible.Exists() {
c.qwenEnableCompatible = compatible.Bool()
} else {
// Default to the official compatible mode
c.qwenEnableCompatible = true
}
c.qwenDomain = json.Get("qwenDomain").String()
if c.qwenDomain != "" {
// TODO: validate the domain, if not valid, set to default
}
c.ollamaServerHost = json.Get("ollamaServerHost").String()
c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint())
c.modelMapping = make(map[string]string)
for k, v := range json.Get("modelMapping").Map() {
c.modelMapping[k] = v.String()
}
c.protocol = json.Get("protocol").String()
if c.protocol == "" {
c.protocol = protocolOpenAI
}
contextJson := json.Get("context")
if contextJson.Exists() {
c.context = &ContextConfig{}
c.context.FromJson(contextJson)
}
// The "claudeVersion" key is read here, which does not match the yaml/json
// tag declared on the struct field.
c.apiVersion = json.Get("claudeVersion").String()
if c.apiVersion == "" {
// Also read the "apiVersion" key, to support the configuration of other
// models and to stay consistent with the tag declared on the struct field.
c.apiVersion = json.Get("apiVersion").String()
}
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
c.awsAccessKey = json.Get("awsAccessKey").String()
c.awsSecretKey = json.Get("awsSecretKey").String()
c.awsRegion = json.Get("awsRegion").String()
// bedrockAdditionalFields is only parsed (and only non-nil) for bedrock.
if c.typ == providerTypeBedrock {
c.bedrockAdditionalFields = make(map[string]interface{})
for k, v := range json.Get("bedrockAdditionalFields").Map() {
c.bedrockAdditionalFields[k] = v.Value()
}
}
c.minimaxApiType = json.Get("minimaxApiType").String()
c.minimaxGroupId = json.Get("minimaxGroupId").String()
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
// geminiSafetySetting is only parsed for the gemini and vertex providers.
if c.typ == providerTypeGemini || c.typ == providerTypeVertex {
c.geminiSafetySetting = make(map[string]string)
for k, v := range json.Get("geminiSafetySetting").Map() {
c.geminiSafetySetting[k] = v.String()
}
}
c.geminiThinkingBudget = json.Get("geminiThinkingBudget").Int()
c.vertexRegion = json.Get("vertexRegion").String()
c.vertexProjectId = json.Get("vertexProjectId").String()
c.vertexAuthKey = json.Get("vertexAuthKey").String()
c.vertexAuthServiceName = json.Get("vertexAuthServiceName").String()
c.vertexTokenRefreshAhead = json.Get("vertexTokenRefreshAhead").Int()
if c.vertexTokenRefreshAhead == 0 {
// Default: refresh 60 seconds ahead of expiry.
c.vertexTokenRefreshAhead = 60
}
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
c.targetLang = json.Get("targetLang").String()
// Only a JSON object is accepted as a schema; anything else leaves it nil.
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
c.responseJsonSchema = schemaValue
} else {
c.responseJsonSchema = nil
}
c.customSettings = make([]CustomSetting, 0)
customSettingsJson := json.Get("customSettings")
if customSettingsJson.Exists() {
// c.protocol must already be resolved at this point.
protocol := protocolOpenAI
if c.protocol == protocolOriginal {
// use provider name to represent original protocol name
protocol = c.typ
}
for _, settingJson := range customSettingsJson.Array() {
setting := CustomSetting{}
setting.FromJson(settingJson)
// use protocol info to rewrite setting
setting.AdjustWithProtocol(protocol)
// Settings that fail validation are silently dropped.
if setting.Validate() {
c.customSettings = append(c.customSettings, setting)
}
}
}
c.reasoningContentMode = json.Get("reasoningContentMode").String()
if c.reasoningContentMode == "" {
c.reasoningContentMode = reasoningBehaviorPassThrough
} else {
// Matching is case-insensitive; unknown values fall back to passthrough.
c.reasoningContentMode = strings.ToLower(c.reasoningContentMode)
switch c.reasoningContentMode {
case reasoningBehaviorPassThrough, reasoningBehaviorIgnore, reasoningBehaviorConcat:
// valid values, no action needed
default:
c.reasoningContentMode = reasoningBehaviorPassThrough
}
}
// failover and retryOnFailure are always non-nil after FromJson; they start
// disabled and are overridden only when the corresponding key exists.
failoverJson := json.Get("failover")
c.failover = &failover{
enabled: false,
}
if failoverJson.Exists() {
c.failover.FromJson(failoverJson)
}
retryOnFailureJson := json.Get("retryOnFailure")
c.retryOnFailure = &retryOnFailure{
enabled: false,
}
if retryOnFailureJson.Exists() {
c.retryOnFailure.FromJson(retryOnFailureJson)
}
c.difyApiUrl = json.Get("difyApiUrl").String()
c.botType = json.Get("botType").String()
c.inputVariable = json.Get("inputVariable").String()
c.outputVariable = json.Get("outputVariable").String()
// NVIDIA triton
c.tritonModelVersion = json.Get("tritonModelVersion").String()
c.tritonDomain = json.Get("tritonDomain").String()
c.capabilities = make(map[string]string)
for capability, pathJson := range json.Get("capabilities").Map() {
// Filter out unsupported capabilities
switch capability {
case string(ApiNameChatCompletion),
string(ApiNameEmbeddings),
string(ApiNameImageGeneration),
string(ApiNameImageVariation),
string(ApiNameImageEdit),
string(ApiNameAudioSpeech),
string(ApiNameCohereV1Rerank),
string(ApiNameVideos),
string(ApiNameRetrieveVideo),
string(ApiNameRetrieveVideoContent),
string(ApiNameVideoRemix):
c.capabilities[capability] = pathJson.String()
}
}
c.basePath = json.Get("basePath").String()
c.basePathHandling = basePathHandling(json.Get("basePathHandling").String())
if c.basePath != "" && c.basePathHandling == "" {
// With a basePath but no explicit mode, default to removing the prefix.
c.basePathHandling = basePathHandlingRemovePrefix
}
c.genericHost = json.Get("genericHost").String()
c.vllmServerHost = json.Get("vllmServerHost").String()
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
c.doubaoDomain = json.Get("doubaoDomain").String()
c.contextCleanupCommands = make([]string, 0)
for _, cmd := range json.Get("contextCleanupCommands").Array() {
// Empty commands are skipped.
if cmd.String() != "" {
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
}
}
}
// Validate checks the configuration and returns the first problem found.
//
// Checks run in this order: protocol value, context config (when present),
// failover config (when enabled), presence of a type, whether the type is
// registered in providerInitializers, and finally the provider-specific
// ValidateConfig. It expects c.failover to be non-nil, which FromJson
// guarantees.
func (c *ProviderConfig) Validate() error {
	switch c.protocol {
	case protocolOpenAI, protocolOriginal:
		// supported protocols
	default:
		return errors.New("invalid protocol in config")
	}

	if c.context != nil {
		if ctxErr := c.context.Validate(); ctxErr != nil {
			return ctxErr
		}
	}

	if c.failover.enabled {
		if failoverErr := c.failover.Validate(); failoverErr != nil {
			return failoverErr
		}
	}

	if len(c.typ) == 0 {
		return errors.New("missing type in provider config")
	}

	init, registered := providerInitializers[c.typ]
	if !registered {
		return errors.New("unknown provider type: " + c.typ)
	}
	// Provider-specific validation decides the final result.
	return init.ValidateConfig(c)
}
// GetOrSetTokenWithContext returns the API token bound to this request.
//
// If a token is already stored in ctx under ctxKeyApiKey it is returned;
// otherwise a token is picked with GetRandomToken, stored in ctx, and
// returned, so later calls on the same context see the same token.
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
	if cached := ctx.GetContext(ctxKeyApiKey); cached != nil {
		return cached.(string)
	}
	token := c.GetRandomToken()
	ctx.SetContext(ctxKeyApiKey, token)
	return token
}
// GetRandomToken picks one of the configured API tokens.
//
// It returns "" when no token is configured, the only token when there is
// exactly one, and otherwise a token chosen uniformly with math/rand.
func (c *ProviderConfig) GetRandomToken() string {
	tokens := c.apiTokens
	n := len(tokens)
	if n == 0 {
		return ""
	}
	if n == 1 {
		// No need to consult the random source for a single token.
		return tokens[0]
	}
	return tokens[rand.Intn(n)]
}
// IsOriginal reports whether the configured protocol is protocolOriginal.
func (c *ProviderConfig) IsOriginal() bool {
return c.protocol == protocolOriginal
}
// ReplaceByCustomSettings delegates to the package-level
// ReplaceByCustomSettings with this config's customSettings and returns its
// result unchanged.
func (c *ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings)
}
// CreateProvider builds the Provider for pc using the initializer registered
// for pc's type in providerInitializers. It returns an error when the type
// is not registered; otherwise it returns whatever the initializer returns.
func CreateProvider(pc ProviderConfig) (Provider, error) {
	if init, registered := providerInitializers[pc.typ]; registered {
		return init.CreateProvider(pc)
	}
	return nil, errors.New("unknown provider type: " + pc.typ)
}
// parseRequestAndMapModel decodes body into request and then applies the
// configured model mapping to the request's model via setRequestModel.
//
// request must be a *chatCompletionRequest, *embeddingsRequest or
// *imageGenerationRequest; any other type yields an error. For chat
// completion requests the streaming flag is recorded in ctx under
// ctxKeyIsStreaming, and for streaming requests the Accept request header is
// replaced with "text/event-stream". Decode errors are returned as is.
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
	switch typed := request.(type) {
	case *chatCompletionRequest:
		if err := decodeChatCompletionRequest(body, typed); err != nil {
			return err
		}
		isStreaming := false
		if typed.Stream {
			isStreaming = true
			// Best effort: a failure to replace the header is deliberately ignored.
			_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
		}
		ctx.SetContext(ctxKeyIsStreaming, isStreaming)
	case *embeddingsRequest:
		if err := decodeEmbeddingsRequest(body, typed); err != nil {
			return err
		}
	case *imageGenerationRequest:
		if err := decodeImageGenerationRequest(body, typed); err != nil {
			return err
		}
	default:
		return errors.New("unsupported request type")
	}
	// request still holds the same typed pointer that was just decoded into.
	return c.setRequestModel(ctx, request)
}
// setRequestModel applies the configured model mapping to the Model field of
// request, rewriting it in place through mapModel.
//
// request must be a *chatCompletionRequest, *embeddingsRequest or
// *imageGenerationRequest; any other type yields an error.
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}) error {
	switch typed := request.(type) {
	case *chatCompletionRequest:
		return c.mapModel(ctx, &typed.Model)
	case *embeddingsRequest:
		return c.mapModel(ctx, &typed.Model)
	case *imageGenerationRequest:
		return c.mapModel(ctx, &typed.Model)
	}
	return errors.New("unsupported request type")
}
// mapModel rewrites *model according to c.modelMapping.
//
// The incoming name is stored in ctx under ctxKeyOriginalRequestModel and the
// resulting name under ctxKeyFinalRequestModel. An error is returned when the
// incoming model is empty or when the mapped model is empty; in the latter
// case *model is left untouched and only the original name has been recorded.
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string) error {
	requested := *model
	if requested == "" {
		return errors.New("missing model in request")
	}
	ctx.SetContext(ctxKeyOriginalRequestModel, requested)

	final := getMappedModel(requested, c.modelMapping)
	if final == "" {
		return errors.New("model becomes empty after applying the configured mapping")
	}

	*model = final
	ctx.SetContext(ctxKeyFinalRequestModel, final)
	return nil
}
// getMappedModel returns the name that model maps to under modelMapping, or
// model itself when doGetMappedModel yields an empty result (no rule matched
// or the matching rule maps to an empty string).
func getMappedModel(model string, modelMapping map[string]string) string {
	if mapped := doGetMappedModel(model, modelMapping); mapped != "" {
		return mapped
	}
	return model
}
// doGetMappedModel looks model up in modelMapping and returns the mapped name,
// or "" when nothing matches (including when modelMapping is empty).
//
// Matching precedence:
//  1. An exact key match.
//  2. Pattern keys, visited in map iteration order (which Go leaves
//     unspecified, so overlapping patterns have no defined winner):
//     - a key starting with "~" is a regular expression (the "~" is stripped);
//     on a match the mapped value is used as the replacement template of
//     ReplaceAllString, so it may reference capture groups such as "$1";
//     - any other key ending with the wildcard is a prefix match on the part
//     before the wildcard.
//  3. The bare wildcard key as a catch-all.
//
// Regex keys are handled exclusively as regular expressions. Previously a
// regex key that happened to end with the wildcard character (for example
// "~gpt-.*") went through the prefix branch first, which stripped its trailing
// character and silently changed the pattern. A regex key that does not
// compile is logged and skipped instead of panicking.
func doGetMappedModel(model string, modelMapping map[string]string) string {
	if len(modelMapping) == 0 {
		return ""
	}

	if v, ok := modelMapping[model]; ok {
		log.Debugf("model [%s] is mapped to [%s] explictly", model, v)
		return v
	}

	for k, v := range modelMapping {
		if k == wildcard {
			// The catch-all is only consulted after every other pattern.
			continue
		}

		if strings.HasPrefix(k, "~") {
			pattern := strings.TrimPrefix(k, "~")
			re, err := regexp.Compile(pattern)
			if err != nil {
				// The pattern comes from configuration; do not crash on a typo.
				log.Warnf("skip invalid regex [%s] in model mapping: %v", pattern, err)
				continue
			}
			if re.MatchString(model) {
				mapped := re.ReplaceAllString(model, v)
				log.Debugf("model [%s] is mapped to [%s] via regex [%s]", model, mapped, pattern)
				return mapped
			}
			continue
		}

		if strings.HasSuffix(k, wildcard) {
			prefix := strings.TrimSuffix(k, wildcard)
			if strings.HasPrefix(model, prefix) {
				log.Debugf("model [%s] is mapped to [%s] via prefix [%s]", model, v, prefix)
				return v
			}
		}
	}

	if v, ok := modelMapping[wildcard]; ok {
		log.Debugf("model [%s] is mapped to [%s] via wildcard", model, v)
		return v
	}

	return ""
}
// ExtractStreamingEvents splits a chunk of an event-stream response into the
// complete events it contains.
//
// Any unfinished tail that a previous call buffered in ctx under
// ctxKeyStreamingBody is prepended to chunk first, and "\r\n" as well as bare
// "\r" are normalized to "\n". Each non-blank line is split at its first ':'
// into a key (which includes the ':') and a value (leading spaces dropped),
// and the pair is handed to StreamEvent.SetValue. A blank line terminates the
// event in progress, whose raw text is stored in RawEvent. Whatever follows
// the last complete event is written back to ctxKeyStreamingBody for the next
// call; nil is stored when nothing is left over.
//
// Lines that contain no ':' and blank lines that arrive while no event is in
// progress (a leading newline, or several blank lines between events) are
// ignored. Both cases used to slice body with an index of -1, which panics.
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
	body := chunk
	if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
		body = append(bufferedStreamingBody, chunk...)
	}
	body = bytes.ReplaceAll(body, []byte("\r\n"), []byte("\n"))
	body = bytes.ReplaceAll(body, []byte("\r"), []byte("\n"))

	eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1

	defer func() {
		if eventStartIndex >= 0 && eventStartIndex < len(body) {
			// Just in case the received chunk is not a complete event.
			ctx.SetContext(ctxKeyStreamingBody, body[eventStartIndex:])
		} else {
			ctx.SetContext(ctxKeyStreamingBody, nil)
		}
	}()

	// Sample Qwen event response:
	//
	// event:result
	// :HTTP_STATUS/200
	// data:{"output":{"choices":[{"message":{"content":"Hello!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":116,"input_tokens":114,"output_tokens":2},"request_id":"71689cfc-1f42-9949-86e8-9563b7f832b1"}
	//
	// event:error
	// :HTTP_STATUS/400
	// data:{"code":"InvalidParameter","message":"Preprocessor error","request_id":"0cbe6006-faec-9854-bf8b-c906d75c3bd8"}
	//

	var events []StreamEvent

	currentKey := ""
	currentEvent := &StreamEvent{}
	i, length := 0, len(body)
	for i = 0; i < length; i++ {
		ch := body[i]
		if ch != '\n' {
			if lineStartIndex == -1 {
				if eventStartIndex == -1 {
					eventStartIndex = i
				}
				lineStartIndex = i
				valueStartIndex = -1
			}
			if valueStartIndex == -1 {
				if ch == ':' {
					valueStartIndex = i + 1
					currentKey = string(body[lineStartIndex:valueStartIndex])
				}
			} else if valueStartIndex == i && ch == ' ' {
				// Skip leading spaces in data.
				valueStartIndex = i + 1
			}
			continue
		}

		if lineStartIndex != -1 {
			// End of a non-blank line. Without a ':' there is no key/value pair
			// and valueStartIndex is still -1, so the line is skipped.
			if valueStartIndex != -1 {
				value := string(body[valueStartIndex:i])
				currentEvent.SetValue(currentKey, value)
			}
		} else if eventStartIndex != -1 {
			currentEvent.RawEvent = string(body[eventStartIndex : i+1])
			// Extra new line. The current event is complete.
			events = append(events, *currentEvent)
			// Reset event parsing state.
			eventStartIndex = -1
			currentEvent = &StreamEvent{}
		}
		// A blank line with no event in progress falls through and is ignored.

		// Reset line parsing state.
		lineStartIndex = -1
		valueStartIndex = -1
		currentKey = ""
	}

	return events
}
// isSupportedAPI reports whether apiName has an entry in c.capabilities.
func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
	key := string(apiName)
	_, found := c.capabilities[key]
	return found
}
// IsSupportedAPI reports whether apiName is supported by this provider
// configuration. It is the exported counterpart of isSupportedAPI and returns
// exactly what that method returns.
func (c *ProviderConfig) IsSupportedAPI(apiName ApiName) bool {
	supported := c.isSupportedAPI(apiName)
	return supported
}
// setDefaultCapabilities copies every entry of capabilities into
// c.capabilities, creating the destination map first when it is nil. Entries
// already present under the same key are overwritten.
func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) {
	if c.capabilities == nil {
		c.capabilities = make(map[string]string)
	}
	// The value is named apiPath so it does not shadow the imported path package.
	for name, apiPath := range capabilities {
		c.capabilities[name] = apiPath
	}
}
// handleRequestBody prepares the request body for the upstream provider and
// reports how request processing should proceed.
//
// Steps, in order:
//  1. When firstByteTimeout is non-zero and the request is a streaming one,
//     set the x-envoy-upstream-rq-first-byte-timeout-ms request header.
//  2. When the original protocol is in use (c.IsOriginal()), stop here and
//     continue with the body untouched.
//  3. When ctx carries a true "needClaudeResponseConversion" marker, convert
//     the body from the Claude protocol to the OpenAI protocol.
//  4. For chat completions with contextCleanupCommands configured, run
//     cleanupContextMessages. This step is best effort: on failure a warning
//     is logged and the body from before the cleanup is kept.
//  5. Transform the body through the provider's TransformRequestBody or
//     TransformRequestBodyHeaders handler, falling back to
//     defaultTransformRequestBody.
//  6. For chat completions with c.context configured, hand the body to
//     contextCache.GetContextFromFile and return ActionPause when that call
//     reports no error. In every other case the transformed body replaces the
//     request body and ActionContinue is returned.
//
// Any error from the conversion, the transformation, the context lookup or
// the body replacement is returned together with ActionContinue.
func (c *ProviderConfig) handleRequestBody(
	provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte,
) (types.Action, error) {
	// add the first byte timeout header to the request
	if c.firstByteTimeout != 0 && c.isStreamingAPI(apiName, body) {
		err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms", strconv.FormatUint(uint64(c.firstByteTimeout), 10))
		if err != nil {
			log.Errorf("failed to set x-envoy-upstream-rq-first-byte-timeout-ms header: %v", err)
		}
		log.Debugf("[firstByteTimeout] %d", c.firstByteTimeout)
	}

	// use original protocol
	if c.IsOriginal() {
		return types.ActionContinue, nil
	}

	var err error

	// handle claude protocol input - auto-detect based on conversion marker
	// If main.go detected a Claude request that needs conversion, convert the body
	needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
	if needClaudeConversion {
		// Convert Claude protocol to OpenAI protocol
		converter := &ClaudeToOpenAIConverter{}
		body, err = converter.ConvertClaudeRequestToOpenAI(body)
		if err != nil {
			// %w keeps the underlying error reachable through errors.Is / errors.As.
			return types.ActionContinue, fmt.Errorf("failed to convert claude request to openai: %w", err)
		}
		log.Debugf("[Auto Protocol] converted Claude request body to OpenAI format")
	}

	// handle context cleanup command for chat completion requests
	if apiName == ApiNameChatCompletion && len(c.contextCleanupCommands) > 0 {
		cleaned, cleanupErr := cleanupContextMessages(body, c.contextCleanupCommands)
		if cleanupErr != nil {
			// Continue processing even if cleanup fails. The value returned next
			// to the error is not trusted, so body keeps its previous content.
			log.Warnf("[contextCleanup] failed to cleanup context messages: %v", cleanupErr)
		} else {
			body = cleaned
		}
	}

	// use openai protocol (either original openai or converted from claude)
	if handler, ok := provider.(TransformRequestBodyHandler); ok {
		body, err = handler.TransformRequestBody(ctx, apiName, body)
	} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
		headers := util.GetRequestHeaders()
		body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
		util.ReplaceRequestHeaders(headers)
	} else {
		body, err = c.defaultTransformRequestBody(ctx, apiName, body)
	}

	if err != nil {
		return types.ActionContinue, err
	}

	if apiName == ApiNameChatCompletion {
		if c.context == nil {
			return types.ActionContinue, replaceRequestBody(body)
		}
		err = contextCache.GetContextFromFile(ctx, provider, body)
		if err == nil {
			// NOTE(review): presumably the request is resumed once the context
			// file has been fetched — confirm against GetContextFromFile.
			return types.ActionPause, nil
		}
		return types.ActionContinue, err
	}
	return types.ActionContinue, replaceRequestBody(body)
}
// handleRequestHeaders applies the base-path handling and the provider's
// header transformation to the current request headers, then writes the
// resulting header set back to the request.
//
// With basePathHandlingRemovePrefix the configured basePath is trimmed from
// the front of ":path" before the provider sees the headers. With
// basePathHandlingPrepend the basePath is joined in front of ":path" after the
// provider ran, unless the path already starts with it.
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
	headers := util.GetRequestHeaders()
	incomingPath := headers.Get(":path")

	// Path left after removePrefix processing; empty when that mode is not active.
	strippedPath := ""
	if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix {
		strippedPath = strings.TrimPrefix(incomingPath, c.basePath)
		headers.Set(":path", strippedPath)
	}

	if transformer, ok := provider.(TransformRequestHeadersHandler); ok {
		transformer.TransformRequestHeaders(ctx, apiName, headers)
	}

	// When using original protocol with removePrefix, restore the basePath-processed path.
	// This ensures basePathHandling works correctly even when TransformRequestHeaders
	// overwrites the path (which most providers do).
	//
	// TODO: Most providers (OpenAI, vLLM, DeepSeek, Claude, etc.) unconditionally overwrite
	// the path in TransformRequestHeaders without checking IsOriginal(). Ideally, each provider
	// should check IsOriginal() before overwriting the path (like Qwen does). Once all providers
	// are updated to handle protocol correctly, this workaround can be removed.
	// Affected providers: OpenAI, vLLM, ZhipuAI, Moonshot, Longcat, DeepSeek, Azure, Yi,
	// TogetherAI, Stepfun, Ollama, Hunyuan, GitHub, Doubao, Cohere, Baichuan, AI360, Claude,
	// Groq, Grok, Spark, Fireworks, Cloudflare, Baidu, OpenRouter, DeepL (24+ providers)
	if c.IsOriginal() && strippedPath != "" {
		headers.Set(":path", strippedPath)
	}

	if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend {
		if currentPath := headers.Get(":path"); !strings.HasPrefix(currentPath, c.basePath) {
			headers.Set(":path", path.Join(c.basePath, currentPath))
		}
	}

	util.ReplaceRequestHeaders(headers)
}
// defaultTransformRequestBody is the default request-body transformation. It
// only performs model mapping: the model name is replaced in place with sjson,
// without deserializing and re-serializing the whole body, for performance.
//
// For chat-completion and video requests the "stream" flag of the body is also
// recorded in ctx under ctxKeyIsStreaming, and streaming requests get their
// Accept header replaced with "text/event-stream". The original and the mapped
// model names are stored under ctxKeyOriginalRequestModel and
// ctxKeyFinalRequestModel. The returned body and error come straight from
// sjson.SetBytes.
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
	if apiName == ApiNameChatCompletion || apiName == ApiNameVideos || apiName == ApiNameVideoRemix {
		isStream := gjson.GetBytes(body, "stream").Bool()
		if isStream {
			// Best effort: a failure to replace the header is deliberately ignored.
			_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
		}
		ctx.SetContext(ctxKeyIsStreaming, isStream)
	}

	requested := gjson.GetBytes(body, "model").String()
	ctx.SetContext(ctxKeyOriginalRequestModel, requested)

	resolved := getMappedModel(requested, c.modelMapping)
	ctx.SetContext(ctxKeyFinalRequestModel, resolved)

	return sjson.SetBytes(body, "model", resolved)
}
// DefaultTransformResponseHeaders is the default response-header handling.
// With the original protocol it tells ctx not to read the response body; with
// any other protocol it removes the Content-Length header from headers.
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
	if c.protocol != protocolOriginal {
		headers.Del("Content-Length")
		return
	}
	ctx.DontReadResponseBody()
}
// isStreamingAPI reports whether the request for apiName is a streaming one.
// ApiNameGeminiStreamGenerateContent always is. For the other APIs listed
// below the answer is the boolean value of the body's top-level "stream"
// field. Every remaining API is reported as non-streaming.
func (c *ProviderConfig) isStreamingAPI(apiName ApiName, body []byte) bool {
	switch apiName {
	case ApiNameGeminiStreamGenerateContent:
		return true
	case ApiNameCompletion,
		ApiNameChatCompletion,
		ApiNameImageGeneration,
		ApiNameImageEdit,
		ApiNameResponses,
		ApiNameQwenAsyncAIGC,
		ApiNameAnthropicMessages,
		ApiNameAnthropicComplete:
		return gjson.GetBytes(body, "stream").Bool()
	default:
		return false
	}
}
// needToProcessRequestBody reports whether apiName is one of the APIs listed
// below; it returns false for every other API name.
func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
	switch apiName {
	case
		// Text APIs.
		ApiNameChatCompletion,
		ApiNameCompletion,
		ApiNameEmbeddings,
		ApiNameResponses,
		// Media APIs.
		ApiNameVideos,
		ApiNameVideoRemix,
		ApiNameImageGeneration,
		ApiNameImageEdit,
		ApiNameImageVariation,
		ApiNameAudioSpeech,
		// Fine-tuning.
		ApiNameFineTuningJobs,
		// Non-OpenAI protocols.
		ApiNameGeminiGenerateContent,
		ApiNameGeminiStreamGenerateContent,
		ApiNameAnthropicMessages:
		return true
	default:
		return false
	}
}