feat/ai proxy vertex ai compatible (#3324)

This commit is contained in:
woody
2026-01-14 10:13:00 +08:00
committed by GitHub
parent e7010256fe
commit f1a5f18c78
7 changed files with 802 additions and 1 deletions

View File

@@ -387,6 +387,9 @@ type ProviderConfig struct {
// @Title zh-CN Vertex token刷新提前时间
// @Description zh-CN 用于Google服务账号认证access token过期时间判定提前刷新单位为秒默认值为60秒
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
// @Title zh-CN Vertex AI OpenAI兼容模式
// @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API请求和响应均使用OpenAI格式无需协议转换。与Express Mode(apiTokens)互斥。
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
// @Title zh-CN 翻译服务需指定的目标语种
// @Description zh-CN 翻译结果的语种目前仅适用于DeepL服务。
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
@@ -540,6 +543,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.vertexTokenRefreshAhead == 0 {
c.vertexTokenRefreshAhead = 60
}
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
c.targetLang = json.Get("targetLang").String()
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {

View File

@@ -21,6 +21,7 @@ import (
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -32,6 +33,9 @@ const (
// Express Mode 路径模板 (不含 project/location)
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
// OpenAI-compatible endpoint 路径模板
// /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions
vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexAnthropicMessageAction = "rawPredict"
@@ -39,6 +43,7 @@ const (
vertexEmbeddingAction = "predict"
vertexGlobalRegion = "global"
contextClaudeMarker = "isClaudeRequest"
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
vertexAnthropicVersion = "vertex-2023-10-16"
)
@@ -47,10 +52,28 @@ type vertexProviderInitializer struct{}
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// Express Mode: 如果配置了 apiTokens则使用 API Key 认证
if len(config.apiTokens) > 0 {
// Express Mode 与 OpenAI 兼容模式互斥
if config.vertexOpenAICompatible {
return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)")
}
// Express Mode 不需要其他配置
return nil
}
// OpenAI 兼容模式: 需要 OAuth 认证配置
if config.vertexOpenAICompatible {
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode")
}
if config.vertexRegion == "" || config.vertexProjectId == "" {
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode")
}
if config.vertexAuthServiceName == "" {
return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode")
}
return nil
}
// 标准模式: 保持原有验证逻辑
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config")
@@ -101,6 +124,12 @@ func (v *vertexProvider) isExpressMode() bool {
return len(v.config.apiTokens) > 0
}
// isOpenAICompatibleMode 检测是否启用 OpenAI 兼容模式
// 使用 Vertex AI 的 OpenAI-compatible Chat Completions API
func (v *vertexProvider) isOpenAICompatibleMode() bool {
return v.config.vertexOpenAICompatible
}
type vertexProvider struct {
client wrapper.HttpClient
config ProviderConfig
@@ -184,7 +213,30 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if v.config.IsOriginal() {
return types.ActionContinue, nil
}
headers := util.GetRequestHeaders()
// OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射
if v.isOpenAICompatibleMode() {
ctx.SetContext(contextOpenAICompatibleMarker, true)
body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
return types.ActionContinue, err
}
// OpenAI 兼容模式需要 OAuth token
cached, err := v.getToken()
if cached {
return types.ActionContinue, nil
}
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
@@ -220,6 +272,32 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
}
}
// onOpenAICompatibleRequestBody 处理 OpenAI 兼容模式的请求
// 不转换请求体格式,只进行模型映射和路径设置
func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API")
}
// 解析请求进行模型映射
request := &chatCompletionRequest{}
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
// 设置 OpenAI 兼容端点路径
path := v.getOpenAICompatibleRequestPath()
util.OverwriteRequestPathHeader(headers, path)
// 如果模型被映射,需要更新请求体中的模型字段
if request.Model != "" {
body, _ = sjson.SetBytes(body, "model", request.Model)
}
// 保持 OpenAI 格式,直接返回(可能更新了模型字段)
return body, nil
}
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
err := v.config.parseRequestAndMapModel(ctx, request, body)
@@ -261,6 +339,12 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
}
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
return util.DecodeUnicodeEscapesInSSE(chunk), nil
}
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
}
@@ -301,6 +385,12 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
}
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
return util.DecodeUnicodeEscapes(body), nil
}
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.TransformResponseBody(ctx, apiName, body)
}
@@ -510,6 +600,11 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
return path
}
// getOpenAICompatibleRequestPath 获取 OpenAI 兼容模式的请求路径
func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
}
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
safetySettings := make([]vertexChatSafetySetting, 0)
for category, threshold := range v.config.geminiSafetySetting {