mirror of
https://github.com/alibaba/higress.git
synced 2026-06-07 11:47:30 +08:00
feat(vertex): 为 ai-proxy 插件的 Vertex AI Provider 添加 Express Mode 支持 || feat(vertex): Add Express Mode support to Vertex AI Provider of ai-proxy plug-in (#3301)
This commit is contained in:
@@ -27,8 +27,11 @@ const (
|
||||
vertexAuthDomain = "oauth2.googleapis.com"
|
||||
vertexDomain = "aiplatform.googleapis.com"
|
||||
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||
// Express Mode 路径模板 (不含 project/location)
|
||||
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
|
||||
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexAnthropicMessageAction = "rawPredict"
|
||||
@@ -42,6 +45,13 @@ const (
|
||||
type vertexProviderInitializer struct{}
|
||||
|
||||
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// Express Mode: 如果配置了 apiTokens,则使用 API Key 认证
|
||||
if len(config.apiTokens) > 0 {
|
||||
// Express Mode 不需要其他配置
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标准模式: 保持原有验证逻辑
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config")
|
||||
}
|
||||
@@ -63,19 +73,32 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
|
||||
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(v.DefaultCapabilities())
|
||||
return &vertexProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
Domain: vertexAuthDomain,
|
||||
ServiceName: config.vertexAuthServiceName,
|
||||
Port: 443,
|
||||
}),
|
||||
|
||||
provider := &vertexProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
claude: &claudeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 仅标准模式需要 OAuth 客户端(Express Mode 通过 apiTokens 配置)
|
||||
if !provider.isExpressMode() {
|
||||
provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
Domain: vertexAuthDomain,
|
||||
ServiceName: config.vertexAuthServiceName,
|
||||
Port: 443,
|
||||
})
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// isExpressMode 检测是否启用 Express Mode
|
||||
// 如果配置了 apiTokens,则使用 Express Mode(API Key 认证)
|
||||
func (v *vertexProvider) isExpressMode() bool {
|
||||
return len(v.config.apiTokens) > 0
|
||||
}
|
||||
|
||||
type vertexProvider struct {
|
||||
@@ -106,11 +129,19 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
|
||||
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
var finalVertexDomain string
|
||||
if v.config.vertexRegion != vertexGlobalRegion {
|
||||
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
|
||||
} else {
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 固定域名,不带 region 前缀
|
||||
finalVertexDomain = vertexDomain
|
||||
} else {
|
||||
// 标准模式: 带 region 前缀
|
||||
if v.config.vertexRegion != vertexGlobalRegion {
|
||||
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
|
||||
} else {
|
||||
finalVertexDomain = vertexDomain
|
||||
}
|
||||
}
|
||||
|
||||
util.OverwriteRequestHostHeader(headers, finalVertexDomain)
|
||||
}
|
||||
|
||||
@@ -156,6 +187,16 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
headers := util.GetRequestHeaders()
|
||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 不需要 Authorization header,API Key 已在 URL 中
|
||||
headers.Del("Authorization")
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 标准模式: 需要获取 OAuth token
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
@@ -422,7 +463,23 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
|
||||
} else {
|
||||
action = vertexAnthropicMessageAction
|
||||
}
|
||||
return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?,使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
@@ -434,7 +491,23 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
} else {
|
||||
action = vertexChatCompletionAction
|
||||
}
|
||||
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
|
||||
|
||||
Reference in New Issue
Block a user