mirror of
https://github.com/alibaba/higress.git
synced 2026-06-26 10:45:25 +08:00
Implement Vertex Raw mode support in AI Proxy (#3375)
This commit is contained in:
@@ -137,6 +137,10 @@ func TestVertex(t *testing.T) {
|
|||||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||||
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
|
||||||
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
|
||||||
|
// Vertex Raw 模式测试
|
||||||
|
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
|
||||||
|
test.RunVertexRawModeOnHttpRequestBodyTests(t)
|
||||||
|
test.RunVertexRawModeOnHttpResponseBodyTests(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBedrock(t *testing.T) {
|
func TestBedrock(t *testing.T) {
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ const (
|
|||||||
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
|
ApiNameGeminiStreamGenerateContent ApiName = "gemini/v1beta/streamgeneratecontent"
|
||||||
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
|
ApiNameAnthropicMessages ApiName = "anthropic/v1/messages"
|
||||||
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
ApiNameAnthropicComplete ApiName = "anthropic/v1/complete"
|
||||||
|
ApiNameVertexRaw ApiName = "vertex/raw"
|
||||||
|
|
||||||
// OpenAI
|
// OpenAI
|
||||||
PathOpenAIPrefix = "/v1"
|
PathOpenAIPrefix = "/v1"
|
||||||
|
|||||||
@@ -44,9 +44,15 @@ const (
|
|||||||
vertexGlobalRegion = "global"
|
vertexGlobalRegion = "global"
|
||||||
contextClaudeMarker = "isClaudeRequest"
|
contextClaudeMarker = "isClaudeRequest"
|
||||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||||
|
contextVertexRawMarker = "isVertexRawRequest"
|
||||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// vertexRawPathRegex 匹配原生 Vertex AI REST API 路径
|
||||||
|
// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action}
|
||||||
|
// 允许任意 basePath 前缀,兼容 basePathHandling 配置
|
||||||
|
var vertexRawPathRegex = regexp.MustCompile(`^.*/([^/]+)/projects/([^/]+)/locations/([^/]+)/publishers/([^/]+)/models/([^/:]+):([^/?]+)`)
|
||||||
|
|
||||||
type vertexProviderInitializer struct{}
|
type vertexProviderInitializer struct{}
|
||||||
|
|
||||||
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
@@ -92,6 +98,7 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
|||||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||||
|
string(ApiNameVertexRaw): "", // 空字符串表示保持原路径,不做路径转换
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,6 +150,12 @@ func (v *vertexProvider) GetProviderType() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) GetApiName(path string) ApiName {
|
func (v *vertexProvider) GetApiName(path string) ApiName {
|
||||||
|
// 优先匹配原生 Vertex AI REST API 路径,支持任意 basePath 前缀
|
||||||
|
// 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action}
|
||||||
|
// 必须在其他 action 检查之前,因为 :predict、:generateContent 等 action 会被其他规则匹配
|
||||||
|
if vertexRawPathRegex.MatchString(path) {
|
||||||
|
return ApiNameVertexRaw
|
||||||
|
}
|
||||||
if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) {
|
if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) {
|
||||||
return ApiNameChatCompletion
|
return ApiNameChatCompletion
|
||||||
}
|
}
|
||||||
@@ -211,6 +224,27 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
if !v.config.isSupportedAPI(apiName) {
|
if !v.config.isSupportedAPI(apiName) {
|
||||||
return types.ActionContinue, errUnsupportedApiName
|
return types.ActionContinue, errUnsupportedApiName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Vertex Raw 模式: 透传请求体,只做 OAuth 认证
|
||||||
|
// 用于直接访问 Vertex AI REST API,不做协议转换
|
||||||
|
// 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用
|
||||||
|
if apiName == ApiNameVertexRaw {
|
||||||
|
ctx.SetContext(contextVertexRawMarker, true)
|
||||||
|
// Express Mode 不需要 OAuth 认证
|
||||||
|
if v.isExpressMode() {
|
||||||
|
return types.ActionContinue, nil
|
||||||
|
}
|
||||||
|
// 标准模式需要获取 OAuth token
|
||||||
|
cached, err := v.getToken()
|
||||||
|
if cached {
|
||||||
|
return types.ActionContinue, nil
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return types.ActionPause, nil
|
||||||
|
}
|
||||||
|
return types.ActionContinue, err
|
||||||
|
}
|
||||||
|
|
||||||
if v.config.IsOriginal() {
|
if v.config.IsOriginal() {
|
||||||
return types.ActionContinue, nil
|
return types.ActionContinue, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -124,6 +124,47 @@ var invalidVertexExpressAndOpenAICompatibleConfig = func() json.RawMessage {
|
|||||||
return data
|
return data
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// 测试配置:Vertex Raw 模式配置(Express Mode + 原生 Vertex API 路径)
|
||||||
|
var vertexRawModeExpressConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "vertex",
|
||||||
|
"apiTokens": []string{"test-api-key-for-raw-mode"},
|
||||||
|
"protocol": "original",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 测试配置:Vertex Raw 模式配置(标准模式 + 原生 Vertex API 路径)
|
||||||
|
var vertexRawModeStandardConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "vertex",
|
||||||
|
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
|
||||||
|
"vertexRegion": "us-central1",
|
||||||
|
"vertexProjectId": "test-project-id",
|
||||||
|
"vertexAuthServiceName": "test-auth-service",
|
||||||
|
"protocol": "original",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 测试配置:Vertex Raw 模式配置(Express Mode + basePath removePrefix)
|
||||||
|
var vertexRawModeWithBasePathConfig = func() json.RawMessage {
|
||||||
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{
|
||||||
|
"type": "vertex",
|
||||||
|
"apiTokens": []string{"test-api-key-for-raw-mode"},
|
||||||
|
"protocol": "original",
|
||||||
|
"basePath": "/vertex-proxy",
|
||||||
|
"basePathHandling": "removePrefix",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return data
|
||||||
|
}()
|
||||||
|
|
||||||
func RunVertexParseConfigTests(t *testing.T) {
|
func RunVertexParseConfigTests(t *testing.T) {
|
||||||
test.RunGoTest(t, func(t *testing.T) {
|
test.RunGoTest(t, func(t *testing.T) {
|
||||||
// 测试 Vertex 标准模式配置解析
|
// 测试 Vertex 标准模式配置解析
|
||||||
@@ -1231,3 +1272,314 @@ func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== Vertex Raw 模式测试 ====================
|
||||||
|
|
||||||
|
func RunVertexRawModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
// 测试 Vertex Raw 模式请求头处理(Express Mode + 原生 Vertex API 路径)
|
||||||
|
t.Run("vertex raw mode express - request headers with native vertex path", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 使用原生 Vertex AI REST API 路径
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 应该返回 HeaderStopIteration,因为需要处理请求体
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
// 验证请求头是否被正确处理
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
// 验证 Host 是否被改为 vertex 域名(Express Mode 使用不带 region 前缀的域名)
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"),
|
||||||
|
"Host header should be changed to vertex domain without region prefix")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 测试 Vertex Raw 模式请求头处理(标准模式 + 原生 Vertex API 路径)
|
||||||
|
t.Run("vertex raw mode standard - request headers with native vertex path", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeStandardConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 使用原生 Vertex AI REST API 路径
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
// 验证请求头是否被正确处理
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
// 验证 Host 是否被改为 vertex 域名(标准模式使用带 region 前缀的域名)
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "us-central1-aiplatform.googleapis.com"),
|
||||||
|
"Host header should be changed to vertex domain with region prefix")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 测试 Vertex Raw 模式请求头处理(带 basePath 前缀)
|
||||||
|
t.Run("vertex raw mode with basePath - request headers", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeWithBasePathConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 使用带 basePath 前缀的原生 Vertex AI REST API 路径
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
// 验证请求头是否被正确处理
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
// 验证 Host 是否被改为 vertex 域名
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"),
|
||||||
|
"Host header should be changed to vertex domain")
|
||||||
|
|
||||||
|
// 验证路径是否移除了 basePath 前缀
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed")
|
||||||
|
require.Contains(t, pathHeader, "/v1/projects/", "Path should contain original vertex path after basePath removal")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 测试 Vertex Raw 模式请求头处理(Anthropic 模型路径)
|
||||||
|
t.Run("vertex raw mode express - request headers with anthropic model path", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 使用 Anthropic 模型的原生 Vertex AI REST API 路径
|
||||||
|
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4@20250514:rawPredict"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, types.HeaderStopIteration, action)
|
||||||
|
|
||||||
|
// 验证请求头是否被正确处理
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
require.NotNil(t, requestHeaders)
|
||||||
|
|
||||||
|
// 验证 Host 是否被改为 vertex 域名
|
||||||
|
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"),
|
||||||
|
"Host header should be changed to vertex domain")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunVertexRawModeOnHttpRequestBodyTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
// 测试 Vertex Raw 模式请求体处理(Express Mode - 透传请求体)
|
||||||
|
t.Run("vertex raw mode express - request body passthrough", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 先设置请求头
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 设置原生 Vertex 格式的请求体
|
||||||
|
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello, world!"}]}],"generationConfig":{"temperature":0.7}}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
// Express Mode 不需要暂停等待 OAuth token
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
// 验证请求体被透传(不做格式转换)
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
// 请求体应该保持原样
|
||||||
|
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 测试 Vertex Raw 模式请求体处理(标准模式 - 需要 OAuth token)
|
||||||
|
// 注意:使用 countTokens action,因为 generateContent/predict 等会被识别为其他 API 类型
|
||||||
|
// 注意:在单元测试环境中,由于测试配置使用的是无效的私钥,JWT 创建会失败,
|
||||||
|
// 因此 getToken() 会返回错误,导致 ActionContinue 而不是 ActionPause。
|
||||||
|
// 这个测试主要验证代码正确进入了 Vertex Raw 模式的处理分支,请求体被透传。
|
||||||
|
t.Run("vertex raw mode standard - request body with oauth", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeStandardConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 先设置请求头 - 使用 countTokens action,这是一个不会被其他 API 类型匹配的原生 Vertex API
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:countTokens"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 设置原生 Vertex 格式的请求体
|
||||||
|
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello, world!"}]}]}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
// 注意:在单元测试环境中,由于私钥无效,JWT 创建失败会返回 ActionContinue
|
||||||
|
// 在真实环境中,如果 JWT 创建成功,会返回 ActionPause 等待 OAuth token
|
||||||
|
// 这里我们只验证代码正确进入了 Vertex Raw 模式的处理分支
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
// 验证请求体被透传(不做格式转换)
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
|
||||||
|
// 请求体应该保持原样(这是 Vertex Raw 模式的核心功能)
|
||||||
|
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 测试 Vertex Raw 模式请求体处理(带 basePath 前缀 - 路径正确处理)
|
||||||
|
t.Run("vertex raw mode with basePath - request body passthrough", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeWithBasePathConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 先设置请求头(带 basePath 前缀)
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/vertex-proxy/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 设置原生 Vertex 格式的请求体(图片生成)
|
||||||
|
requestBody := `{"instances":[{"prompt":"A beautiful sunset"}],"parameters":{"sampleCount":1}}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
// Express Mode 不需要暂停等待 OAuth token
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
// 验证请求体被透传
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||||
|
|
||||||
|
// 验证路径已正确处理(移除 basePath)
|
||||||
|
requestHeaders := host.GetRequestHeaders()
|
||||||
|
pathHeader := ""
|
||||||
|
for _, header := range requestHeaders {
|
||||||
|
if header[0] == ":path" {
|
||||||
|
pathHeader = header[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotContains(t, pathHeader, "/vertex-proxy", "Path should have basePath prefix removed")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 测试 Vertex Raw 模式请求体处理(流式请求)
|
||||||
|
t.Run("vertex raw mode express - streaming request body passthrough", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 先设置请求头(流式端点)
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent?alt=sse"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 设置原生 Vertex 格式的请求体
|
||||||
|
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Tell me a story"}]}]}`
|
||||||
|
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
// 验证请求体被透传
|
||||||
|
processedBody := host.GetRequestBody()
|
||||||
|
require.NotNil(t, processedBody)
|
||||||
|
require.Equal(t, requestBody, string(processedBody), "Request body should be passed through unchanged")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunVertexRawModeOnHttpResponseBodyTests(t *testing.T) {
|
||||||
|
test.RunTest(t, func(t *testing.T) {
|
||||||
|
// 测试 Vertex Raw 模式响应体处理(透传响应)
|
||||||
|
t.Run("vertex raw mode express - response body passthrough", func(t *testing.T) {
|
||||||
|
host, status := test.NewTestHost(vertexRawModeExpressConfig)
|
||||||
|
defer host.Reset()
|
||||||
|
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||||
|
|
||||||
|
// 先设置请求头
|
||||||
|
host.CallOnHttpRequestHeaders([][2]string{
|
||||||
|
{":authority", "example.com"},
|
||||||
|
{":path", "/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"},
|
||||||
|
{":method", "POST"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 设置请求体
|
||||||
|
requestBody := `{"contents":[{"role":"user","parts":[{"text":"Hello"}]}]}`
|
||||||
|
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||||
|
|
||||||
|
// 设置响应属性
|
||||||
|
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||||
|
|
||||||
|
// 设置响应头
|
||||||
|
responseHeaders := [][2]string{
|
||||||
|
{":status", "200"},
|
||||||
|
{"Content-Type", "application/json"},
|
||||||
|
}
|
||||||
|
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||||
|
|
||||||
|
// 设置原生 Vertex 格式的响应体
|
||||||
|
responseBody := `{
|
||||||
|
"candidates": [{
|
||||||
|
"content": {
|
||||||
|
"role": "model",
|
||||||
|
"parts": [{"text": "Hello! How can I help you?"}]
|
||||||
|
},
|
||||||
|
"finishReason": "STOP"
|
||||||
|
}],
|
||||||
|
"usageMetadata": {
|
||||||
|
"promptTokenCount": 5,
|
||||||
|
"candidatesTokenCount": 10,
|
||||||
|
"totalTokenCount": 15
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||||
|
|
||||||
|
require.Equal(t, types.ActionContinue, action)
|
||||||
|
|
||||||
|
// 验证响应体被透传(不做格式转换)
|
||||||
|
processedResponseBody := host.GetResponseBody()
|
||||||
|
require.NotNil(t, processedResponseBody)
|
||||||
|
|
||||||
|
responseStr := string(processedResponseBody)
|
||||||
|
// 响应应该保持原生 Vertex 格式
|
||||||
|
require.Contains(t, responseStr, "candidates", "Response should keep native vertex format with candidates")
|
||||||
|
require.Contains(t, responseStr, "usageMetadata", "Response should keep native vertex format with usageMetadata")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user