feat(vertex): Add Express Mode support to the Vertex AI provider of the ai-proxy plugin (#3301)
@@ -309,7 +309,9 @@ The `type` corresponding to Dify is `dify`. Its unique configuration fields are as follows:

#### Google Vertex AI

-The `type` corresponding to Google Vertex AI is `vertex`. Its unique configuration fields are as follows:
+The `type` corresponding to Google Vertex AI is `vertex`. It supports two authentication modes:

**Standard Mode** (using a Service Account):

| Name | Data Type | Requirement | Default | Description |
|------|-----------|-------------|---------|-------------|
@@ -320,6 +322,15 @@ The `type` corresponding to Google Vertex AI is `vertex`. Its unique configuration fields are as follows:
| `geminiSafetySetting` | map of string | Optional | - | Content filtering and safety level settings for Gemini AI. See [Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
| `vertexTokenRefreshAhead` | number | Optional | - | Lead time, in seconds, for refreshing the Vertex access token |

**Express Mode** (using an API Key, simplified configuration):

Express Mode is a simplified access mode offered by Vertex AI: an API Key is all you need to get started, with no Service Account configuration required. See the [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview).

| Name | Data Type | Requirement | Default | Description |
|------|-----------|-------------|---------|-------------|
| `apiTokens` | array of string | Required | - | API Key used in Express Mode, obtained from Google Cloud Console under APIs & Services > Credentials |
| `geminiSafetySetting` | map of string | Optional | - | Content filtering and safety level settings for Gemini AI. See [Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
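
For reference, a sketch of the upstream request URLs the plugin builds in each mode (derived from the path templates and host rewriting added in this commit; the model name, project, region, and key values are placeholders):

```
# Standard Mode: regional host, full resource path, OAuth token sent in the Authorization header
https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-2.5-flash:generateContent

# Express Mode: fixed host, no project/location, API Key appended as a query parameter
https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.5-flash:generateContent?key=YOUR_API_KEY
```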

#### AWS Bedrock

The `type` corresponding to AWS Bedrock is `bedrock`. It supports two authentication methods:
@@ -1955,7 +1966,7 @@ provider:
}
```

-### Using the OpenAI Protocol to Proxy Google Vertex Services
+### Using the OpenAI Protocol to Proxy Google Vertex Services (Standard Mode)

**Configuration Information**

@@ -2017,6 +2028,60 @@ provider:
}
```

### Using the OpenAI Protocol to Proxy Google Vertex Services (Express Mode)

Express Mode is Vertex AI's simplified access mode: an API Key is all you need to get started quickly.

**Configuration Information**

```yaml
provider:
  type: vertex
  apiTokens:
    - "YOUR_API_KEY"
```

**Request Example**

```json
{
  "model": "gemini-2.5-flash",
  "messages": [
    {
      "role": "user",
      "content": "Hello, who are you?"
    }
  ],
  "stream": false
}
```

**Response Example**

```json
{
  "id": "chatcmpl-0000000000000",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you?"
      },
      "finish_reason": "stop"
    }
  ],
  "created": 1729986750,
  "model": "gemini-2.5-flash",
  "object": "chat.completion",
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 25,
    "total_tokens": 35
  }
}
```
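
Express Mode can be combined with the other Gemini-related options; a minimal sketch (field names and values taken from this commit's test configuration) that also sets content safety levels:

```yaml
provider:
  type: vertex
  apiTokens:
    - "YOUR_API_KEY"
  geminiSafetySetting:
    "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE"
    "HARM_CATEGORY_HATE_SPEECH": "BLOCK_LOW_AND_ABOVE"
    "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE"
```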

### Using the OpenAI Protocol to Proxy AWS Bedrock Services

AWS Bedrock supports two authentication methods:
@@ -255,7 +255,9 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field is:
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |

#### Google Vertex AI

-For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
+For Vertex, the corresponding `type` is `vertex`. It supports two authentication modes:

**Standard Mode** (using a Service Account):

| Name | Data Type | Requirement | Default | Description |
|------|-----------|-------------|---------|-------------|
@@ -266,6 +268,15 @@ For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
| `vertexTokenRefreshAhead` | number | Optional | - | Lead time, in seconds, for refreshing the Vertex access token |

**Express Mode** (using an API Key, simplified configuration):

Express Mode is a simplified access mode introduced by Vertex AI. You can get started quickly with just an API Key, without configuring a Service Account. See the [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview).

| Name | Data Type | Requirement | Default | Description |
|------|-----------|-------------|---------|-------------|
| `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under APIs & Services > Credentials |
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
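
Express Mode also works together with the plugin's general `modelMapping` option; a minimal sketch (values borrowed from this commit's test configuration) mapping OpenAI model names to Gemini models:

```yaml
provider:
  type: vertex
  apiTokens:
    - "YOUR_API_KEY"
  modelMapping:
    "gpt-4": "gemini-2.5-flash"
    "gpt-3.5-turbo": "gemini-2.5-flash-lite"
    "text-embedding-ada-002": "text-embedding-001"
```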

#### AWS Bedrock

For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods:
@@ -1728,7 +1739,7 @@ provider:
}
```

-### Utilizing OpenAI Protocol Proxy for Google Vertex Services
+### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Standard Mode)

**Configuration Information**

```yaml
provider:
@@ -1786,6 +1797,57 @@ provider:
}
```

### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Express Mode)

Express Mode is a simplified access mode for Vertex AI. You only need an API Key to get started quickly.

**Configuration Information**

```yaml
provider:
  type: vertex
  apiTokens:
    - "YOUR_API_KEY"
```

**Request Example**

```json
{
  "model": "gemini-2.5-flash",
  "messages": [
    {
      "role": "user",
      "content": "Who are you?"
    }
  ],
  "stream": false
}
```

**Response Example**

```json
{
  "id": "chatcmpl-0000000000000",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you today?"
      },
      "finish_reason": "stop"
    }
  ],
  "created": 1729986750,
  "model": "gemini-2.5-flash",
  "object": "chat.completion",
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 25,
    "total_tokens": 35
  }
}
```
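
Streaming also works in Express Mode: the commit maps it to Vertex's `streamGenerateContent?alt=sse` action, and on the client side the OpenAI-style request only needs `"stream": true`, for example:

```json
{
  "model": "gemini-2.5-flash",
  "messages": [
    {
      "role": "user",
      "content": "Who are you?"
    }
  ],
  "stream": true
}
```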

### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services

AWS Bedrock supports two authentication methods:
@@ -129,6 +129,14 @@ func TestGeneric(t *testing.T) {
	test.RunGenericOnHttpRequestBodyTests(t)
}

+func TestVertex(t *testing.T) {
+	test.RunVertexParseConfigTests(t)
+	test.RunVertexExpressModeOnHttpRequestHeadersTests(t)
+	test.RunVertexExpressModeOnHttpRequestBodyTests(t)
+	test.RunVertexExpressModeOnHttpResponseBodyTests(t)
+	test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
+}
+
func TestBedrock(t *testing.T) {
	test.RunBedrockParseConfigTests(t)
	test.RunBedrockOnHttpRequestHeadersTests(t)
@@ -27,8 +27,11 @@ const (
	vertexAuthDomain = "oauth2.googleapis.com"
	vertexDomain     = "aiplatform.googleapis.com"
	// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
-	vertexPathTemplate          = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
-	vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
+	vertexPathTemplate                 = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
+	vertexPathAnthropicTemplate        = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
+	// Express Mode path templates (no project/location)
+	vertexExpressPathTemplate          = "/v1/publishers/google/models/%s:%s"
+	vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
	vertexChatCompletionAction       = "generateContent"
	vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
	vertexAnthropicMessageAction     = "rawPredict"
@@ -42,6 +45,13 @@ const (
type vertexProviderInitializer struct{}

func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
+	// Express Mode: if apiTokens is configured, API Key authentication is used
+	if len(config.apiTokens) > 0 {
+		// Express Mode needs no other configuration
+		return nil
+	}
+
+	// Standard Mode: keep the original validation logic
	if config.vertexAuthKey == "" {
		return errors.New("missing vertexAuthKey in vertex provider config")
	}
@@ -63,19 +73,32 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {

func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	config.setDefaultCapabilities(v.DefaultCapabilities())
-	return &vertexProvider{
-		config: config,
-		client: wrapper.NewClusterClient(wrapper.DnsCluster{
-			Domain:      vertexAuthDomain,
-			ServiceName: config.vertexAuthServiceName,
-			Port:        443,
-		}),
-		contextCache: createContextCache(&config),
-		claude: &claudeProvider{
-			config:       config,
-			contextCache: createContextCache(&config),
-		},
-	}, nil
+	provider := &vertexProvider{
+		config:       config,
+		contextCache: createContextCache(&config),
+		claude: &claudeProvider{
+			config:       config,
+			contextCache: createContextCache(&config),
+		},
+	}
+
+	// Only Standard Mode needs the OAuth client (Express Mode is configured via apiTokens)
+	if !provider.isExpressMode() {
+		provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
+			Domain:      vertexAuthDomain,
+			ServiceName: config.vertexAuthServiceName,
+			Port:        443,
+		})
+	}
+
+	return provider, nil
}

+// isExpressMode reports whether Express Mode is enabled:
+// if apiTokens is configured, Express Mode (API Key authentication) is used.
+func (v *vertexProvider) isExpressMode() bool {
+	return len(v.config.apiTokens) > 0
+}
+
type vertexProvider struct {
@@ -106,11 +129,19 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa

func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	var finalVertexDomain string
-	if v.config.vertexRegion != vertexGlobalRegion {
-		finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
-	} else {
-		finalVertexDomain = vertexDomain
-	}
+
+	if v.isExpressMode() {
+		// Express Mode: fixed domain without a region prefix
+		finalVertexDomain = vertexDomain
+	} else {
+		// Standard Mode: domain carries the region prefix
+		if v.config.vertexRegion != vertexGlobalRegion {
+			finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
+		} else {
+			finalVertexDomain = vertexDomain
+		}
+	}
+
	util.OverwriteRequestHostHeader(headers, finalVertexDomain)
}
@@ -156,6 +187,16 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
	headers := util.GetRequestHeaders()
	body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
	headers.Set("Content-Length", fmt.Sprint(len(body)))
+
+	if v.isExpressMode() {
+		// Express Mode: no Authorization header is needed, the API Key is already in the URL
+		headers.Del("Authorization")
+		util.ReplaceRequestHeaders(headers)
+		_ = proxywasm.ReplaceHttpRequestBody(body)
+		return types.ActionContinue, err
+	}
+
+	// Standard Mode: an OAuth token must be fetched
	util.ReplaceRequestHeaders(headers)
	_ = proxywasm.ReplaceHttpRequestBody(body)
	if err != nil {
@@ -422,7 +463,23 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
	} else {
		action = vertexAnthropicMessageAction
	}
-	return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
+
+	if v.isExpressMode() {
+		// Express Mode: simplified path plus the API Key as a query parameter
+		basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
+		apiKey := v.config.GetRandomToken()
+		// If the action already contains "?", append with "&"
+		var fullPath string
+		if strings.Contains(action, "?") {
+			fullPath = basePath + "&key=" + apiKey
+		} else {
+			fullPath = basePath + "?key=" + apiKey
+		}
+		return fullPath
+	}
+
+	path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
+	return path
}

func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
@@ -434,7 +491,23 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
	} else {
		action = vertexChatCompletionAction
	}
-	return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
+
+	if v.isExpressMode() {
+		// Express Mode: simplified path plus the API Key as a query parameter
+		basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
+		apiKey := v.config.GetRandomToken()
+		// If the action already contains "?" (e.g. streamGenerateContent?alt=sse), append with "&"
+		var fullPath string
+		if strings.Contains(action, "?") {
+			fullPath = basePath + "&key=" + apiKey
+		} else {
+			fullPath = basePath + "?key=" + apiKey
+		}
+		return fullPath
+	}
+
+	path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
+	return path
}

func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
plugins/wasm-go/extensions/ai-proxy/test/vertex.go (new file, 499 lines)
@@ -0,0 +1,499 @@
package test

import (
	"encoding/json"
	"strings"
	"testing"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/test"
	"github.com/stretchr/testify/require"
)

// Test config: Vertex Standard Mode configuration
var basicVertexConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":                  "vertex",
			"vertexAuthKey":         `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
			"vertexRegion":          "us-central1",
			"vertexProjectId":       "test-project-id",
			"vertexAuthServiceName": "test-auth-service",
		},
	})
	return data
}()

// Test config: Vertex Express Mode configuration (using apiTokens)
var vertexExpressModeConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "vertex",
			"apiTokens": []string{"test-api-key-123456789"},
		},
	})
	return data
}()

// Test config: Vertex Express Mode configuration (with model mapping)
var vertexExpressModeWithModelMappingConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "vertex",
			"apiTokens": []string{"test-api-key-123456789"},
			"modelMapping": map[string]string{
				"gpt-4":                  "gemini-2.5-flash",
				"gpt-3.5-turbo":          "gemini-2.5-flash-lite",
				"text-embedding-ada-002": "text-embedding-001",
			},
		},
	})
	return data
}()

// Test config: Vertex Express Mode configuration (with safety settings)
var vertexExpressModeWithSafetyConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":      "vertex",
			"apiTokens": []string{"test-api-key-123456789"},
			"geminiSafetySetting": map[string]string{
				"HARM_CATEGORY_HARASSMENT":        "BLOCK_MEDIUM_AND_ABOVE",
				"HARM_CATEGORY_HATE_SPEECH":       "BLOCK_LOW_AND_ABOVE",
				"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
			},
		},
	})
	return data
}()

// Test config: invalid Vertex Standard Mode configuration (missing vertexAuthKey)
var invalidVertexStandardModeConfig = func() json.RawMessage {
	data, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type": "vertex",
			// The required Standard Mode fields are intentionally missing
		},
	})
	return data
}()
func RunVertexParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		// Parse the Vertex Standard Mode configuration
		t.Run("vertex standard mode config", func(t *testing.T) {
			host, status := test.NewTestHost(basicVertexConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})

		// Parse the Vertex Express Mode configuration
		t.Run("vertex express mode config", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})

		// Parse the Vertex Express Mode configuration with model mapping
		t.Run("vertex express mode with model mapping config", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})

		// Invalid Vertex Standard Mode configuration (missing vertexAuthKey)
		t.Run("invalid vertex standard mode config - missing auth key", func(t *testing.T) {
			host, status := test.NewTestHost(invalidVertexStandardModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})

		// Parse the Vertex Express Mode configuration with safety settings
		t.Run("vertex express mode with safety setting config", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithSafetyConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
	})
}
func RunVertexExpressModeOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Express Mode request-header handling (chat completions endpoint)
		t.Run("vertex express mode chat completion request headers", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// HeaderStopIteration is expected because the request body still needs processing
			require.Equal(t, types.HeaderStopIteration, action)

			// Verify that the request headers were processed
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			// Verify the Host was rewritten to the vertex domain (Express Mode uses the domain without a region prefix)
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain without region prefix")

			// Check for related processing logs
			debugLogs := host.GetDebugLogs()
			hasVertexLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "vertex") {
					hasVertexLogs = true
					break
				}
			}
			require.True(t, hasVertexLogs, "Should have vertex processing logs")
		})

		// Express Mode request-header handling (embeddings endpoint)
		t.Run("vertex express mode embeddings request headers", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/embeddings"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			require.Equal(t, types.HeaderStopIteration, action)

			// Verify request-header handling for the embeddings endpoint
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)

			// Verify the Host rewrite
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain")
		})
	})
}
func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Express Mode request-body handling (chat completions endpoint)
		t.Run("vertex express mode chat completion request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			// Express Mode does not need to pause and wait for an OAuth token
			require.Equal(t, types.ActionContinue, action)

			// Verify that the request body was processed
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			// Verify the request body was converted to the Vertex format
			require.Contains(t, string(processedBody), "contents", "Request should be converted to vertex format")
			require.Contains(t, string(processedBody), "generationConfig", "Request should contain vertex generation config")

			// Verify the path contains the API Key
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
			require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Express Mode format without project/location")

			// Verify there is no Authorization header (Express Mode uses the URL parameter instead)
			hasAuthHeader := false
			for _, header := range requestHeaders {
				if header[0] == "Authorization" && header[1] != "" {
					hasAuthHeader = true
					break
				}
			}
			require.False(t, hasAuthHeader, "Authorization header should be removed in Express Mode")

			// Check for related processing logs
			debugLogs := host.GetDebugLogs()
			hasVertexLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "vertex") {
					hasVertexLogs = true
					break
				}
			}
			require.True(t, hasVertexLogs, "Should have vertex processing logs")
		})

		// Express Mode request-body handling (embeddings endpoint)
		t.Run("vertex express mode embeddings request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/embeddings"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"text-embedding-001","input":"test text"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify request-body handling for the embeddings endpoint
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			// Verify the request body was converted to the Vertex format
			require.Contains(t, string(processedBody), "instances", "Request should be converted to vertex format")

			// Verify the path contains the API Key
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
		})

		// Express Mode request-body handling (streaming request)
		t.Run("vertex express mode streaming request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set a streaming request body
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the path contains the streaming action
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "streamGenerateContent", "Path should contain streaming action")
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
		})

		// Express Mode request-body handling (with model mapping)
		t.Run("vertex express mode with model mapping request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body (using an OpenAI model name)
			requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the path contains the mapped model name
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
		})
	})
}
func RunVertexExpressModeOnHttpResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Express Mode response-body handling (chat completions endpoint)
		t.Run("vertex express mode chat completion response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			// Set the response property so that IsResponseFromUpstream() returns true
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			// Set the response headers
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			// Set the response body (Vertex format)
			responseBody := `{
				"candidates": [{
					"content": {
						"parts": [{
							"text": "Hello! How can I help you today?"
						}]
					},
					"finishReason": "STOP",
					"index": 0
				}],
				"usageMetadata": {
					"promptTokenCount": 9,
					"candidatesTokenCount": 12,
					"totalTokenCount": 21
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify that the response body was processed
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)

			// Verify the response body content (converted to the OpenAI format)
			responseStr := string(processedResponseBody)

			// Check whether the response body was converted
			if strings.Contains(responseStr, "chat.completion") {
				require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
				require.Contains(t, responseStr, "usage", "Response should contain usage information")
			}

			// Check for related processing logs
			debugLogs := host.GetDebugLogs()
			hasResponseBodyLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "vertex") {
					hasResponseBodyLogs = true
					break
				}
			}
			require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
		})
	})
}
func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Express Mode streaming-response handling
		t.Run("vertex express mode streaming response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set a streaming request body
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			// Set the streaming response headers
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			// Simulate streaming response chunks
			chunk1 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":5,"totalTokenCount":14}}`
			chunk2 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":12,"totalTokenCount":21}}`

			// Process the streaming response body
			action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
			require.Equal(t, types.ActionContinue, action1)

			action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
			require.Equal(t, types.ActionContinue, action2)

			// Verify streaming response processing
			debugLogs := host.GetDebugLogs()
			hasStreamingLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "vertex") {
					hasStreamingLogs = true
					break
				}
			}
			require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
		})
	})
}