diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 9c3c8bf91..a62172372 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -331,6 +331,20 @@ Express Mode 是 Vertex AI 推出的简化访问模式,只需 API Key 即可 | `apiTokens` | array of string | 必填 | - | Express Mode 使用的 API Key,从 Google Cloud Console 的 API & Services > Credentials 获取 | | `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) | +**OpenAI 兼容模式**(使用 Vertex AI Chat Completions API): + +Vertex AI 提供了 OpenAI 兼容的 Chat Completions API 端点,可以直接使用 OpenAI 格式的请求和响应,无需进行协议转换。详见 [Vertex AI OpenAI 兼容性文档](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview)。 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------| +| `vertexOpenAICompatible` | boolean | 非必填 | false | 启用 OpenAI 兼容模式。启用后将使用 Vertex AI 的 OpenAI-compatible Chat Completions API | +| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key | +| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等) | +| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID | +| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称 | + +**注意**:OpenAI 兼容模式与 Express Mode 互斥,不能同时配置 `apiTokens` 和 `vertexOpenAICompatible`。 + #### AWS Bedrock AWS Bedrock 所对应的 type 为 bedrock。它支持两种认证方式: @@ -2082,6 +2096,74 @@ provider: } ``` +### 使用 OpenAI 协议代理 Google Vertex 服务(OpenAI 兼容模式) + +OpenAI 兼容模式使用 Vertex AI 的 OpenAI-compatible Chat Completions API,请求和响应都使用 OpenAI 格式,无需进行协议转换。 + +**配置信息** + +```yaml +provider: + type: vertex + vertexOpenAICompatible: true + vertexAuthKey: | + { + "type": "service_account", + "project_id": "your-project-id", + "private_key_id": "your-private-key-id", + 
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n", + "client_email": "your-service-account@your-project.iam.gserviceaccount.com", + "token_uri": "https://oauth2.googleapis.com/token" + } + vertexRegion: us-central1 + vertexProjectId: your-project-id + vertexAuthServiceName: your-auth-service-name + modelMapping: + "gpt-4": "gemini-2.0-flash" + "*": "gemini-1.5-flash" +``` + +**请求示例** + +```json +{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "你好,你是谁?" + } + ], + "stream": false +} +``` + +**响应示例** + +```json +{ + "id": "chatcmpl-abc123", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "你好!我是由 Google 开发的 Gemini 模型。我可以帮助回答问题、提供信息和进行对话。有什么我可以帮您的吗?" + }, + "finish_reason": "stop" + } + ], + "created": 1729986750, + "model": "gemini-2.0-flash", + "object": "chat.completion", + "usage": { + "prompt_tokens": 12, + "completion_tokens": 35, + "total_tokens": 47 + } +} +``` + ### 使用 OpenAI 协议代理 AWS Bedrock 服务 AWS Bedrock 支持两种认证方式: diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index 83583972c..b1831f9dc 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -277,6 +277,20 @@ Express Mode is a simplified access mode introduced by Vertex AI. You can quickl | `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under API & Services > Credentials | | `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. | +**OpenAI Compatible Mode** (using Vertex AI Chat Completions API): + +Vertex AI provides an OpenAI-compatible Chat Completions API endpoint, allowing you to use OpenAI format requests and responses directly without protocol conversion. 
See [Vertex AI OpenAI Compatibility documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview). + +| Name | Data Type | Requirement | Default | Description | +|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `vertexOpenAICompatible` | boolean | Optional | false | Enable OpenAI compatible mode. When enabled, uses Vertex AI's OpenAI-compatible Chat Completions API | +| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key for authentication | +| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) | +| `vertexProjectId` | string | Required | - | Google Cloud Project ID | +| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication | + +**Note**: OpenAI Compatible Mode and Express Mode are mutually exclusive. You cannot configure both `apiTokens` and `vertexOpenAICompatible` at the same time. + #### AWS Bedrock For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods: @@ -1848,6 +1862,71 @@ provider: } ``` +### Utilizing OpenAI Protocol Proxy for Google Vertex Services (OpenAI Compatible Mode) + +OpenAI Compatible Mode uses Vertex AI's OpenAI-compatible Chat Completions API. Both requests and responses use OpenAI format, requiring no protocol conversion. 
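For reference, the upstream host and path that this mode targets can be sketched as follows. This is an illustrative helper, not part of the plugin: the path mirrors `vertexOpenAICompatiblePathTemplate` from `vertex.go` (project ID first, then location), and the host is assumed to follow the regional `{region}-aiplatform.googleapis.com` pattern used by Vertex AI.

```go
package main

import "fmt"

// vertexOpenAICompatiblePath is a hypothetical helper mirroring the path
// template in vertex.go: the host carries the region prefix, and the path
// embeds the project ID and location.
func vertexOpenAICompatiblePath(region, projectID string) (host, path string) {
	host = fmt.Sprintf("%s-aiplatform.googleapis.com", region)
	path = fmt.Sprintf(
		"/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
		projectID, region)
	return
}

func main() {
	host, path := vertexOpenAICompatiblePath("us-central1", "your-project-id")
	fmt.Println(host + path)
}
```

With `vertexRegion: us-central1` and `vertexProjectId: your-project-id`, requests are rewritten to `us-central1-aiplatform.googleapis.com/v1beta1/projects/your-project-id/locations/us-central1/endpoints/openapi/chat/completions`.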
+ +**Configuration Information** +```yaml +provider: + type: vertex + vertexOpenAICompatible: true + vertexAuthKey: | + { + "type": "service_account", + "project_id": "your-project-id", + "private_key_id": "your-private-key-id", + "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n", + "client_email": "your-service-account@your-project.iam.gserviceaccount.com", + "token_uri": "https://oauth2.googleapis.com/token" + } + vertexRegion: us-central1 + vertexProjectId: your-project-id + vertexAuthServiceName: your-auth-service-name + modelMapping: + "gpt-4": "gemini-2.0-flash" + "*": "gemini-1.5-flash" +``` + +**Request Example** +```json +{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Hello, who are you?" + } + ], + "stream": false +} +``` + +**Response Example** +```json +{ + "id": "chatcmpl-abc123", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I am Gemini, an AI model developed by Google. I can help answer questions, provide information, and engage in conversations. How can I assist you today?" 
+ }, + "finish_reason": "stop" + } + ], + "created": 1729986750, + "model": "gemini-2.0-flash", + "object": "chat.completion", + "usage": { + "prompt_tokens": 12, + "completion_tokens": 35, + "total_tokens": 47 + } +} +``` + ### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services AWS Bedrock supports two authentication methods: diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index faf7cb125..563664f89 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -387,6 +387,9 @@ type ProviderConfig struct { // @Title zh-CN Vertex token刷新提前时间 // @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒 vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"` + // @Title zh-CN Vertex AI OpenAI兼容模式 + // @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API,请求和响应均使用OpenAI格式,无需协议转换。与Express Mode(apiTokens)互斥。 + vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"` // @Title zh-CN 翻译服务需指定的目标语种 // @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。 targetLang string `required:"false" yaml:"targetLang" json:"targetLang"` @@ -540,6 +543,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.vertexTokenRefreshAhead == 0 { c.vertexTokenRefreshAhead = 60 } + c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool() c.targetLang = json.Get("targetLang").String() if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index b947501e0..38bea82f7 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -21,6 +21,7 @@ import ( "github.com/higress-group/wasm-go/pkg/log" 
"github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) const ( @@ -32,6 +33,9 @@ const ( // Express Mode 路径模板 (不含 project/location) vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s" vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s" + // OpenAI-compatible endpoint 路径模板 + // /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions + vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions" vertexChatCompletionAction = "generateContent" vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse" vertexAnthropicMessageAction = "rawPredict" @@ -39,6 +43,7 @@ const ( vertexEmbeddingAction = "predict" vertexGlobalRegion = "global" contextClaudeMarker = "isClaudeRequest" + contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" vertexAnthropicVersion = "vertex-2023-10-16" ) @@ -47,10 +52,28 @@ type vertexProviderInitializer struct{} func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error { // Express Mode: 如果配置了 apiTokens,则使用 API Key 认证 if len(config.apiTokens) > 0 { + // Express Mode 与 OpenAI 兼容模式互斥 + if config.vertexOpenAICompatible { + return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)") + } // Express Mode 不需要其他配置 return nil } + // OpenAI 兼容模式: 需要 OAuth 认证配置 + if config.vertexOpenAICompatible { + if config.vertexAuthKey == "" { + return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode") + } + if config.vertexRegion == "" || config.vertexProjectId == "" { + return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode") + } + if config.vertexAuthServiceName == "" { + return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode") + } + return nil + } + // 标准模式: 保持原有验证逻辑 if config.vertexAuthKey == "" { 
return errors.New("missing vertexAuthKey in vertex provider config") @@ -101,6 +124,12 @@ func (v *vertexProvider) isExpressMode() bool { return len(v.config.apiTokens) > 0 } +// isOpenAICompatibleMode 检测是否启用 OpenAI 兼容模式 +// 使用 Vertex AI 的 OpenAI-compatible Chat Completions API +func (v *vertexProvider) isOpenAICompatibleMode() bool { + return v.config.vertexOpenAICompatible +} + type vertexProvider struct { client wrapper.HttpClient config ProviderConfig @@ -184,7 +213,30 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if v.config.IsOriginal() { return types.ActionContinue, nil } + headers := util.GetRequestHeaders() + + // OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射 + if v.isOpenAICompatibleMode() { + ctx.SetContext(contextOpenAICompatibleMarker, true) + body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers) + headers.Set("Content-Length", fmt.Sprint(len(body))) + util.ReplaceRequestHeaders(headers) + _ = proxywasm.ReplaceHttpRequestBody(body) + if err != nil { + return types.ActionContinue, err + } + // OpenAI 兼容模式需要 OAuth token + cached, err := v.getToken() + if cached { + return types.ActionContinue, nil + } + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers) headers.Set("Content-Length", fmt.Sprint(len(body))) @@ -220,6 +272,32 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap } } +// onOpenAICompatibleRequestBody 处理 OpenAI 兼容模式的请求 +// 不转换请求体格式,只进行模型映射和路径设置 +func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API") + } + + // 解析请求进行模型映射 + request := &chatCompletionRequest{} + if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil { + 
return nil, err + } + + // 设置 OpenAI 兼容端点路径 + path := v.getOpenAICompatibleRequestPath() + util.OverwriteRequestPathHeader(headers, path) + + // 如果模型被映射,需要更新请求体中的模型字段 + if request.Model != "" { + body, _ = sjson.SetBytes(body, "model", request.Model) + } + + // 保持 OpenAI 格式,直接返回(可能更新了模型字段) + return body, nil +} + func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { request := &chatCompletionRequest{} err := v.config.parseRequestAndMapModel(ctx, request, body) @@ -261,6 +339,12 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ } func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { + // OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列 + // Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX + if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) { + return util.DecodeUnicodeEscapesInSSE(chunk), nil + } + if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk) } @@ -301,6 +385,12 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A } func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + // OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列 + // Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX + if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) { + return util.DecodeUnicodeEscapes(body), nil + } + if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { return v.claude.TransformResponseBody(ctx, apiName, body) } @@ -510,6 +600,11 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId 
string, stream return path } +// getOpenAICompatibleRequestPath 获取 OpenAI 兼容模式的请求路径 +func (v *vertexProvider) getOpenAICompatibleRequestPath() string { + return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion) +} + func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest { safetySettings := make([]vertexChatSafetySetting, 0) for category, threshold := range v.config.geminiSafetySetting { diff --git a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go index e2ba84ade..eb0fabf8c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/vertex.go @@ -78,6 +78,52 @@ var invalidVertexStandardModeConfig = func() json.RawMessage { return data }() +// 测试配置:Vertex OpenAI 兼容模式配置 +var vertexOpenAICompatibleModeConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "vertexOpenAICompatible": true, + "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, + "vertexRegion": "us-central1", + "vertexProjectId": "test-project-id", + "vertexAuthServiceName": "test-auth-service", + }, + }) + return data +}() + +// 测试配置:Vertex OpenAI 兼容模式配置(含模型映射) +var vertexOpenAICompatibleModeWithModelMappingConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "vertexOpenAICompatible": true, + "vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END 
PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`, + "vertexRegion": "us-central1", + "vertexProjectId": "test-project-id", + "vertexAuthServiceName": "test-auth-service", + "modelMapping": map[string]string{ + "gpt-4": "gemini-2.0-flash", + "gpt-3.5-turbo": "gemini-1.5-flash", + }, + }, + }) + return data +}() + +// 测试配置:无效配置 - Express Mode 与 OpenAI 兼容模式互斥 +var invalidVertexExpressAndOpenAICompatibleConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "vertex", + "apiTokens": []string{"test-api-key"}, + "vertexOpenAICompatible": true, + }, + }) + return data +}() + func RunVertexParseConfigTests(t *testing.T) { test.RunGoTest(t, func(t *testing.T) { // 测试 Vertex 标准模式配置解析 @@ -130,6 +176,35 @@ func RunVertexParseConfigTests(t *testing.T) { require.NoError(t, err) require.NotNil(t, config) }) + + // 测试 Vertex OpenAI 兼容模式配置解析 + t.Run("vertex openai compatible mode config", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // 测试 Vertex OpenAI 兼容模式配置(含模型映射) + t.Run("vertex openai compatible mode with model mapping config", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeWithModelMappingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + // 测试无效配置 - Express Mode 与 OpenAI 兼容模式互斥 + t.Run("invalid config - express mode and openai compatible mode conflict", func(t *testing.T) { + host, status := test.NewTestHost(invalidVertexExpressAndOpenAICompatibleConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusFailed, status) + }) }) } @@ -446,6 +521,131 @@ func 
RunVertexExpressModeOnHttpResponseBodyTests(t *testing.T) { }) } +func RunVertexOpenAICompatibleModeOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex OpenAI 兼容模式请求头处理 + t.Run("vertex openai compatible mode request headers", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 设置请求头 + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 应该返回HeaderStopIteration,因为需要处理请求体 + require.Equal(t, types.HeaderStopIteration, action) + + // 验证请求头是否被正确处理 + requestHeaders := host.GetRequestHeaders() + require.NotNil(t, requestHeaders) + + // 验证Host是否被改为 vertex 域名(带 region 前缀) + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "us-central1-aiplatform.googleapis.com"), "Host header should be changed to vertex domain with region prefix") + }) + }) +} + +func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex OpenAI 兼容模式请求体处理(不转换格式,保持 OpenAI 格式) + t.Run("vertex openai compatible mode request body - no format conversion", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体(OpenAI 格式) + requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // OpenAI 兼容模式需要等待 OAuth token,所以返回 ActionPause + require.Equal(t, types.ActionPause, action) + + // 验证请求体保持 OpenAI 格式(不转换为 Vertex 
原生格式) + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // OpenAI 兼容模式应该保持 messages 字段,而不是转换为 contents + require.Contains(t, string(processedBody), "messages", "Request should keep OpenAI format with messages field") + require.NotContains(t, string(processedBody), "contents", "Request should NOT be converted to vertex native format") + + // 验证路径为 OpenAI 兼容端点 + requestHeaders := host.GetRequestHeaders() + pathHeader := "" + for _, header := range requestHeaders { + if header[0] == ":path" { + pathHeader = header[1] + break + } + } + require.Contains(t, pathHeader, "/v1beta1/projects/", "Path should use OpenAI compatible endpoint format") + require.Contains(t, pathHeader, "/endpoints/openapi/chat/completions", "Path should contain openapi chat completions endpoint") + }) + + // 测试 Vertex OpenAI 兼容模式请求体处理(含模型映射) + t.Run("vertex openai compatible mode with model mapping", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeWithModelMappingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体(使用 OpenAI 模型名) + requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + require.Equal(t, types.ActionPause, action) + + // 验证请求体中的模型名被映射 + processedBody := host.GetRequestBody() + require.NotNil(t, processedBody) + + // 模型名应该被映射为 gemini-2.0-flash + require.Contains(t, string(processedBody), "gemini-2.0-flash", "Model name should be mapped to gemini-2.0-flash") + }) + + // 测试 Vertex OpenAI 兼容模式不支持 Embeddings API + t.Run("vertex openai compatible mode - embeddings not supported", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + 
require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/embeddings"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"text-embedding-001","input":"test text"}` + action := host.CallOnHttpRequestBody([]byte(requestBody)) + + // OpenAI 兼容模式只支持 chat completions,embeddings 应该返回错误 + require.Equal(t, types.ActionContinue, action) + }) + }) +} + func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) { test.RunTest(t, func(t *testing.T) { // 测试 Vertex Express Mode 流式响应处理 @@ -497,3 +697,192 @@ func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) { }) }) } + +func RunVertexOpenAICompatibleModeOnHttpResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + // 测试 Vertex OpenAI 兼容模式响应体处理(直接透传,不转换格式) + t.Run("vertex openai compatible mode response body - passthrough", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应属性,确保IsResponseFromUpstream()返回true + host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream")) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 设置响应体(OpenAI 格式 - 因为 Vertex AI OpenAI-compatible API 返回的就是 OpenAI 格式) + responseBody := `{ + "id": "chatcmpl-abc123", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": 
"Hello! How can I help you today?" + }, + "finish_reason": "stop" + }], + "created": 1729986750, + "model": "gemini-2.0-flash", + "object": "chat.completion", + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + }` + action := host.CallOnHttpResponseBody([]byte(responseBody)) + + require.Equal(t, types.ActionContinue, action) + + // 验证响应体被直接透传(不进行格式转换) + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + // 响应应该保持原样 + responseStr := string(processedResponseBody) + require.Contains(t, responseStr, "chatcmpl-abc123", "Response should be passed through unchanged") + require.Contains(t, responseStr, "chat.completion", "Response should contain original object type") + }) + + // 测试 Vertex OpenAI 兼容模式流式响应处理(直接透传) + t.Run("vertex openai compatible mode streaming response - passthrough", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置流式请求体 + requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}],"stream":true}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置流式响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "text/event-stream"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 模拟 OpenAI 格式的流式响应(Vertex AI OpenAI-compatible API 返回) + chunk1 := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}` + chunk2 := `data: 
{"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":"stop"}]}` + + // 处理流式响应体 - 应该直接透传 + action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false) + require.Equal(t, types.ActionContinue, action1) + + action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true) + require.Equal(t, types.ActionContinue, action2) + }) + + // 测试 Vertex OpenAI 兼容模式流式响应处理(Unicode 转义解码) + t.Run("vertex openai compatible mode streaming response - unicode escape decoding", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置流式请求体 + requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}],"stream":true}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置流式响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "text/event-stream"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 模拟带有 Unicode 转义的流式响应(Vertex AI OpenAI-compatible API 可能返回的格式) + // \u4e2d\u6587 = 中文 + chunkWithUnicode := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4e2d\u6587\u6d4b\u8bd5"},"finish_reason":null}]}` + + // 处理流式响应体 - 应该解码 Unicode 转义 + action := host.CallOnHttpStreamingResponseBody([]byte(chunkWithUnicode), false) + require.Equal(t, types.ActionContinue, action) + + // 验证响应体中的 Unicode 转义已被解码 + responseBody := host.GetResponseBody() + require.NotNil(t, responseBody) + + responseStr := string(responseBody) + // 应该包含解码后的中文字符,而不是 \uXXXX 转义序列 + require.Contains(t, 
responseStr, "中文测试", "Unicode escapes should be decoded to Chinese characters") + require.NotContains(t, responseStr, `\u4e2d`, "Should not contain Unicode escape sequences") + }) + + // 测试 Vertex OpenAI 兼容模式非流式响应处理(Unicode 转义解码) + t.Run("vertex openai compatible mode response body - unicode escape decoding", func(t *testing.T) { + host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + // 先设置请求头 + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + // 设置请求体 + requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}` + host.CallOnHttpRequestBody([]byte(requestBody)) + + // 设置响应头 + responseHeaders := [][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + } + host.CallOnHttpResponseHeaders(responseHeaders) + + // 模拟带有 Unicode 转义的响应体 + // \u76c8\u5229\u80fd\u529b = 盈利能力 + responseBodyWithUnicode := `{"id":"chatcmpl-abc123","object":"chat.completion","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"message":{"role":"assistant","content":"\u76c8\u5229\u80fd\u529b\u5206\u6790"},"finish_reason":"stop"}]}` + + // 处理响应体 - 应该解码 Unicode 转义 + action := host.CallOnHttpResponseBody([]byte(responseBodyWithUnicode)) + require.Equal(t, types.ActionContinue, action) + + // 验证响应体中的 Unicode 转义已被解码 + processedResponseBody := host.GetResponseBody() + require.NotNil(t, processedResponseBody) + + responseStr := string(processedResponseBody) + // 应该包含解码后的中文字符 + require.Contains(t, responseStr, "盈利能力分析", "Unicode escapes should be decoded to Chinese characters") + require.NotContains(t, responseStr, `\u76c8`, "Should not contain Unicode escape sequences") + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/util/string.go b/plugins/wasm-go/extensions/ai-proxy/util/string.go index 
69d1fd469..0b21456fd 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/string.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/string.go @@ -1,6 +1,10 @@ package util -import "regexp" +import ( + "regexp" + "strconv" + "strings" +) func StripPrefix(s string, prefix string) string { if len(prefix) != 0 && len(s) >= len(prefix) && s[0:len(prefix)] == prefix { @@ -18,3 +22,43 @@ func MatchStatus(status string, patterns []string) bool { } return false } + +// unicodeEscapeRegex matches Unicode escape sequences like \uXXXX +var unicodeEscapeRegex = regexp.MustCompile(`\\u([0-9a-fA-F]{4})`) + +// DecodeUnicodeEscapes decodes Unicode escape sequences (\uXXXX) in a string to UTF-8 characters. +// This is useful when a JSON response contains ASCII-safe encoded non-ASCII characters. +func DecodeUnicodeEscapes(input []byte) []byte { + result := unicodeEscapeRegex.ReplaceAllFunc(input, func(match []byte) []byte { + // match is like \uXXXX, extract the hex part (XXXX) + hexStr := string(match[2:6]) + codePoint, err := strconv.ParseInt(hexStr, 16, 32) + if err != nil { + return match // return original if parse fails + } + return []byte(string(rune(codePoint))) + }) + return result +} + +// DecodeUnicodeEscapesInSSE decodes Unicode escape sequences in SSE formatted data. +// It processes each line that starts with "data: " and decodes Unicode escapes in the JSON payload. 
+func DecodeUnicodeEscapesInSSE(input []byte) []byte { + lines := strings.Split(string(input), "\n") + var result strings.Builder + for i, line := range lines { + if strings.HasPrefix(line, "data: ") { + // Decode Unicode escapes in the JSON payload + jsonData := line[6:] + decodedData := DecodeUnicodeEscapes([]byte(jsonData)) + result.WriteString("data: ") + result.Write(decodedData) + } else { + result.WriteString(line) + } + if i < len(lines)-1 { + result.WriteString("\n") + } + } + return []byte(result.String()) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/util/string_test.go b/plugins/wasm-go/extensions/ai-proxy/util/string_test.go new file mode 100644 index 000000000..4042baf03 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/util/string_test.go @@ -0,0 +1,108 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDecodeUnicodeEscapes(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "Chinese characters", + input: `\u4e2d\u6587\u6d4b\u8bd5`, + expected: `中文测试`, + }, + { + name: "Mixed content", + input: `Hello \u4e16\u754c World`, + expected: `Hello 世界 World`, + }, + { + name: "No escape sequences", + input: `Hello World`, + expected: `Hello World`, + }, + { + name: "JSON with Unicode escapes", + input: `{"content":"\u76c8\u5229\u80fd\u529b"}`, + expected: `{"content":"盈利能力"}`, + }, + { + name: "Full width parentheses", + input: `\uff08\u76c8\u5229\uff09`, + expected: `(盈利)`, + }, + { + name: "Empty string", + input: ``, + expected: ``, + }, + { + name: "Invalid escape sequence (not modified)", + input: `\u00GG`, + expected: `\u00GG`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DecodeUnicodeEscapes([]byte(tt.input)) + assert.Equal(t, tt.expected, string(result)) + }) + } +} + +func TestDecodeUnicodeEscapesInSSE(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + 
name: "SSE data with Unicode escapes", + input: `data: {"choices":[{"delta":{"content":"\u4e2d\u6587"}}]} + +`, + expected: `data: {"choices":[{"delta":{"content":"中文"}}]} + +`, + }, + { + name: "Multiple SSE data lines", + input: `data: {"content":"\u4e2d\u6587"} +data: {"content":"\u82f1\u6587"} +data: [DONE] +`, + expected: `data: {"content":"中文"} +data: {"content":"英文"} +data: [DONE] +`, + }, + { + name: "Non-data lines unchanged", + input: ": comment\nevent: message\ndata: test\n", + expected: ": comment\nevent: message\ndata: test\n", + }, + { + name: "Real Vertex AI response format", + input: `data: {"choices":[{"delta":{"content":"\uff08\u76c8\u5229\u80fd\u529b\uff09","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"} + +`, + expected: `data: {"choices":[{"delta":{"content":"(盈利能力)","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"} + +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DecodeUnicodeEscapesInSSE([]byte(tt.input)) + assert.Equal(t, tt.expected, string(result)) + }) + } +}
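One caveat worth noting about the regex-based `DecodeUnicodeEscapes` above: it decodes each `\uXXXX` escape independently, so astral-plane characters that JSON encodes as UTF-16 surrogate pairs (e.g. emoji, `\ud83d\ude00`) would decode to two invalid surrogate runes, which Go's string conversion turns into replacement characters. A possible surrogate-aware variant is sketched below; this is not part of the PR, just an illustration of the technique using `unicode/utf16`.

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
	"unicode/utf16"
)

// The alternation tries the surrogate-pair form first (high surrogate
// \uD800-\uDBFF followed by low surrogate \uDC00-\uDFFF), falling back to a
// single BMP escape.
var escRe = regexp.MustCompile(
	`\\u[dD][89abAB][0-9a-fA-F]{2}\\u[dD][c-fC-F][0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}`)

// decodeUnicodeEscapes is a sketch of a surrogate-aware variant of
// DecodeUnicodeEscapes: lone BMP escapes decode as before, while pairs such
// as \ud83d\ude00 combine into a single rune via utf16.DecodeRune.
func decodeUnicodeEscapes(input []byte) []byte {
	return escRe.ReplaceAllFunc(input, func(m []byte) []byte {
		hi, err := strconv.ParseUint(string(m[2:6]), 16, 32)
		if err != nil {
			return m // keep the original text if parsing fails
		}
		if len(m) == 12 { // matched a surrogate pair
			lo, err := strconv.ParseUint(string(m[8:12]), 16, 32)
			if err != nil {
				return m
			}
			return []byte(string(utf16.DecodeRune(rune(hi), rune(lo))))
		}
		return []byte(string(rune(hi)))
	})
}

func main() {
	fmt.Println(string(decodeUnicodeEscapes(
		[]byte(`{"content":"\u4e2d\u6587 \ud83d\ude00"}`))))
}
```

A lone, unpaired surrogate still falls through to the single-escape branch and decodes to the replacement character, matching the existing function's behavior for that edge case.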