feat(vertex): 为 ai-proxy 插件的 Vertex AI Provider 添加 Express Mode 支持 || feat(vertex): Add Express Mode support to Vertex AI Provider of ai-proxy plug-in (#3301)

This commit is contained in:
woody
2026-01-13 20:00:05 +08:00
committed by GitHub
parent 72c87b3e15
commit 23fbe0e9e9
5 changed files with 726 additions and 19 deletions

View File

@@ -27,8 +27,11 @@ const (
vertexAuthDomain = "oauth2.googleapis.com"
vertexDomain = "aiplatform.googleapis.com"
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
// Express Mode 路径模板 (不含 project/location)
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexAnthropicMessageAction = "rawPredict"
@@ -42,6 +45,13 @@ const (
type vertexProviderInitializer struct{}
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// Express Mode: 如果配置了 apiTokens，则使用 API Key 认证
if len(config.apiTokens) > 0 {
// Express Mode 不需要其他配置
return nil
}
// 标准模式: 保持原有验证逻辑
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config")
}
@@ -63,19 +73,32 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(v.DefaultCapabilities())
return &vertexProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.DnsCluster{
Domain: vertexAuthDomain,
ServiceName: config.vertexAuthServiceName,
Port: 443,
}),
provider := &vertexProvider{
config: config,
contextCache: createContextCache(&config),
claude: &claudeProvider{
config: config,
contextCache: createContextCache(&config),
},
}, nil
}
// 仅标准模式需要 OAuth 客户端 (Express Mode 通过 apiTokens 配置)
if !provider.isExpressMode() {
provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
Domain: vertexAuthDomain,
ServiceName: config.vertexAuthServiceName,
Port: 443,
})
}
return provider, nil
}
// isExpressMode reports whether the provider operates in Express Mode.
// Express Mode (API key authentication) is selected whenever the provider
// configuration carries at least one entry in apiTokens.
func (v *vertexProvider) isExpressMode() bool {
	hasTokens := len(v.config.apiTokens) != 0
	return hasTokens
}
type vertexProvider struct {
@@ -106,11 +129,19 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
var finalVertexDomain string
if v.config.vertexRegion != vertexGlobalRegion {
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
} else {
if v.isExpressMode() {
// Express Mode: 固定域名,不带 region 前缀
finalVertexDomain = vertexDomain
} else {
// 标准模式: 带 region 前缀
if v.config.vertexRegion != vertexGlobalRegion {
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
} else {
finalVertexDomain = vertexDomain
}
}
util.OverwriteRequestHostHeader(headers, finalVertexDomain)
}
@@ -156,6 +187,16 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
headers := util.GetRequestHeaders()
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
if v.isExpressMode() {
// Express Mode: 不需要 Authorization header，API Key 已在 URL 中
headers.Del("Authorization")
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
return types.ActionContinue, err
}
// 标准模式: 需要获取 OAuth token
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
@@ -422,7 +463,23 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
} else {
action = vertexAnthropicMessageAction
}
return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?,使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
}
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
return path
}
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
@@ -434,7 +491,23 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
} else {
action = vertexChatCompletionAction
}
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse)，使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
}
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
return path
}
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {