mirror of https://github.com/alibaba/higress.git
synced 2026-02-06 15:10:54 +08:00

add support for image generation in Vertex AI provider (#3335)
@@ -26,6 +26,8 @@ description: AI 代理插件配置参考

> When the request path suffix matches `/v1/embeddings`, it corresponds to the text embedding scenario. The request body will be parsed using OpenAI's text embedding protocol and then converted to the corresponding LLM vendor's text embedding protocol.

> When the request path suffix matches `/v1/images/generations`, it corresponds to the text-to-image scenario. The request body will be parsed using OpenAI's image generation protocol and then converted to the corresponding LLM vendor's image generation protocol.

## Execution Properties

Plugin execution phase: `Default Phase`
@@ -2164,6 +2166,108 @@ provider:
}
```

### Using the OpenAI Protocol to Proxy Google Vertex Image Generation

Vertex AI supports image generation using Gemini models. Through the ai-proxy plugin, you can use OpenAI's `/v1/images/generations` API to call Vertex AI's image generation capabilities.

**Configuration**

```yaml
provider:
  type: vertex
  apiTokens:
    - "YOUR_API_KEY"
  modelMapping:
    "dall-e-3": "gemini-2.0-flash-exp"
  geminiSafetySetting:
    HARM_CATEGORY_HARASSMENT: "OFF"
    HARM_CATEGORY_HATE_SPEECH: "OFF"
    HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF"
    HARM_CATEGORY_DANGEROUS_CONTENT: "OFF"
```

**Request with curl**

```bash
curl -X POST "http://your-gateway-address/v1/images/generations" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemini-2.0-flash-exp",
    "prompt": "A cute orange cat napping in the sunshine",
    "size": "1024x1024"
  }'
```

**Using the OpenAI Python SDK**

```python
from openai import OpenAI

client = OpenAI(
    api_key="any-value",  # Can be any value; authentication is handled by the gateway
    base_url="http://your-gateway-address/v1"
)

response = client.images.generate(
    model="gemini-2.0-flash-exp",
    prompt="A cute orange cat napping in the sunshine",
    size="1024x1024",
    n=1
)

# Get the generated image (base64 encoded)
image_data = response.data[0].b64_json
print(f"Generated image (base64): {image_data[:100]}...")
```

**Response Example**

```json
{
  "created": 1729986750,
  "data": [
    {
      "b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..."
    }
  ],
  "usage": {
    "total_tokens": 1356,
    "input_tokens": 13,
    "output_tokens": 1120
  }
}
```

**Supported Size Parameters**

Aspect ratios (aspectRatio) supported by Vertex AI: `1:1`, `3:2`, `2:3`, `3:4`, `4:3`, `4:5`, `5:4`, `9:16`, `16:9`, `21:9`

Resolutions (imageSize) supported by Vertex AI: `1k`, `2k`, `4k`

| OpenAI size parameter | Vertex AI aspectRatio | Vertex AI imageSize |
|-----------------------|-----------------------|---------------------|
| 256x256               | 1:1                   | 1k                  |
| 512x512               | 1:1                   | 1k                  |
| 1024x1024             | 1:1                   | 1k                  |
| 1792x1024             | 16:9                  | 2k                  |
| 1024x1792             | 9:16                  | 2k                  |
| 2048x2048             | 1:1                   | 2k                  |
| 4096x4096             | 1:1                   | 4k                  |
| 1536x1024             | 3:2                   | 2k                  |
| 1024x1536             | 2:3                   | 2k                  |
| 1024x768              | 4:3                   | 1k                  |
| 768x1024              | 3:4                   | 1k                  |
| 1280x1024             | 5:4                   | 1k                  |
| 1024x1280             | 4:5                   | 1k                  |
| 2560x1080             | 21:9                  | 2k                  |

**Notes**

- Image generation uses Gemini models (e.g., `gemini-2.0-flash-exp`, `gemini-3-pro-image-preview`); model availability may vary by region
- The returned image data is base64-encoded (`b64_json`)
- Content safety filtering levels can be configured via `geminiSafetySetting`
- If you need model mapping (e.g., mapping `dall-e-3` to a Gemini model), configure `modelMapping`

### Using the OpenAI Protocol to Proxy AWS Bedrock Services

AWS Bedrock supports two authentication methods:
@@ -25,6 +25,8 @@ The plugin now supports **automatic protocol detection**, allowing seamless comp

> When the request path suffix matches `/v1/embeddings`, it corresponds to text embedding scenarios. The request body will be parsed using OpenAI's text embedding protocol and then converted to the corresponding LLM vendor's text embedding protocol.

> When the request path suffix matches `/v1/images/generations`, it corresponds to text-to-image scenarios. The request body will be parsed using OpenAI's image generation protocol and then converted to the corresponding LLM vendor's image generation protocol.
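The suffix-based detection described above can be sketched as follows (a minimal standalone illustration, not the plugin's actual routing code; the function name `detectApi` is hypothetical):

```go
package main

import (
	"fmt"
	"strings"
)

// detectApi maps an incoming request path to the API scenario used for
// protocol conversion, based on the path suffix.
func detectApi(path string) string {
	// Strip query parameters before matching the suffix.
	if idx := strings.Index(path, "?"); idx != -1 {
		path = path[:idx]
	}
	switch {
	case strings.HasSuffix(path, "/v1/embeddings"):
		return "embeddings"
	case strings.HasSuffix(path, "/v1/images/generations"):
		return "image-generation"
	default:
		return "chat-completion"
	}
}

func main() {
	fmt.Println(detectApi("/v1/images/generations")) // image-generation
	fmt.Println(detectApi("/api/v1/embeddings?debug=1")) // embeddings
}
```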
## Execution Properties

Plugin execution phase: `Default Phase`
Plugin execution priority: `100`

@@ -1927,6 +1929,108 @@ provider:
}
```

### Utilizing OpenAI Protocol Proxy for Google Vertex Image Generation

Vertex AI supports image generation using Gemini models. Through the ai-proxy plugin, you can use OpenAI's `/v1/images/generations` API to call Vertex AI's image generation capabilities.

**Configuration Information**

```yaml
provider:
  type: vertex
  apiTokens:
    - "YOUR_API_KEY"
  modelMapping:
    "dall-e-3": "gemini-2.0-flash-exp"
  geminiSafetySetting:
    HARM_CATEGORY_HARASSMENT: "OFF"
    HARM_CATEGORY_HATE_SPEECH: "OFF"
    HARM_CATEGORY_SEXUALLY_EXPLICIT: "OFF"
    HARM_CATEGORY_DANGEROUS_CONTENT: "OFF"
```

**Using curl**

```bash
curl -X POST "http://your-gateway-address/v1/images/generations" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemini-2.0-flash-exp",
    "prompt": "A cute orange cat napping in the sunshine",
    "size": "1024x1024"
  }'
```

**Using OpenAI Python SDK**

```python
from openai import OpenAI

client = OpenAI(
    api_key="any-value",  # Can be any value, authentication is handled by the gateway
    base_url="http://your-gateway-address/v1"
)

response = client.images.generate(
    model="gemini-2.0-flash-exp",
    prompt="A cute orange cat napping in the sunshine",
    size="1024x1024",
    n=1
)

# Get the generated image (base64 encoded)
image_data = response.data[0].b64_json
print(f"Generated image (base64): {image_data[:100]}...")
```
**Response Example**

```json
{
  "created": 1729986750,
  "data": [
    {
      "b64_json": "iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAAAA..."
    }
  ],
  "usage": {
    "total_tokens": 1356,
    "input_tokens": 13,
    "output_tokens": 1120
  }
}
```

**Supported Size Parameters**

Vertex AI supported aspect ratios (aspectRatio): `1:1`, `3:2`, `2:3`, `3:4`, `4:3`, `4:5`, `5:4`, `9:16`, `16:9`, `21:9`

Vertex AI supported resolutions (imageSize): `1k`, `2k`, `4k`

| OpenAI size parameter | Vertex AI aspectRatio | Vertex AI imageSize |
|-----------------------|-----------------------|---------------------|
| 256x256               | 1:1                   | 1k                  |
| 512x512               | 1:1                   | 1k                  |
| 1024x1024             | 1:1                   | 1k                  |
| 1792x1024             | 16:9                  | 2k                  |
| 1024x1792             | 9:16                  | 2k                  |
| 2048x2048             | 1:1                   | 2k                  |
| 4096x4096             | 1:1                   | 4k                  |
| 1536x1024             | 3:2                   | 2k                  |
| 1024x1536             | 2:3                   | 2k                  |
| 1024x768              | 4:3                   | 1k                  |
| 768x1024              | 3:4                   | 1k                  |
| 1280x1024             | 5:4                   | 1k                  |
| 1024x1280             | 4:5                   | 1k                  |
| 2560x1080             | 21:9                  | 2k                  |
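The table above can be expressed as a lookup with a `1:1`/`1k` fallback. A standalone Go sketch (the function name `sizeToVertex` is hypothetical; the provider's own `parseImageSize` appears in the Go diff later in this commit):

```go
package main

import "fmt"

// sizeToVertex maps an OpenAI-style "WxH" size string to Vertex AI's
// aspectRatio and imageSize, falling back to 1:1 / 1k for unknown sizes.
func sizeToVertex(size string) (aspectRatio, imageSize string) {
	mapping := map[string][2]string{
		"256x256":   {"1:1", "1k"},
		"512x512":   {"1:1", "1k"},
		"1024x1024": {"1:1", "1k"},
		"1792x1024": {"16:9", "2k"},
		"1024x1792": {"9:16", "2k"},
		"2048x2048": {"1:1", "2k"},
		"4096x4096": {"1:1", "4k"},
		"1536x1024": {"3:2", "2k"},
		"1024x1536": {"2:3", "2k"},
		"1024x768":  {"4:3", "1k"},
		"768x1024":  {"3:4", "1k"},
		"1280x1024": {"5:4", "1k"},
		"1024x1280": {"4:5", "1k"},
		"2560x1080": {"21:9", "2k"},
	}
	if m, ok := mapping[size]; ok {
		return m[0], m[1]
	}
	return "1:1", "1k"
}

func main() {
	ar, is := sizeToVertex("1792x1024")
	fmt.Println(ar, is) // 16:9 2k
	ar, is = sizeToVertex("640x480")
	fmt.Println(ar, is) // fallback: 1:1 1k
}
```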
**Notes**

- Image generation uses Gemini models (e.g., `gemini-2.0-flash-exp`, `gemini-3-pro-image-preview`); model availability may vary by region
- The returned image data is base64-encoded (`b64_json`)
- Content safety filtering levels can be configured via `geminiSafetySetting`
- If you need model mapping (e.g., mapping `dall-e-3` to a Gemini model), configure `modelMapping`

### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services

AWS Bedrock supports two authentication methods:
@@ -135,6 +135,8 @@ func TestVertex(t *testing.T) {
	test.RunVertexExpressModeOnHttpRequestBodyTests(t)
	test.RunVertexExpressModeOnHttpResponseBodyTests(t)
	test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
	test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
	test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
}

func TestBedrock(t *testing.T) {
@@ -89,8 +89,9 @@ func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error

func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
	return map[string]string{
		string(ApiNameChatCompletion):  vertexPathTemplate,
		string(ApiNameEmbeddings):      vertexPathTemplate,
		string(ApiNameImageGeneration): vertexPathTemplate,
	}
}
@@ -265,10 +266,15 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}

func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
	switch apiName {
	case ApiNameChatCompletion:
		return v.onChatCompletionRequestBody(ctx, body, headers)
	case ApiNameEmbeddings:
		return v.onEmbeddingsRequestBody(ctx, body, headers)
	case ApiNameImageGeneration:
		return v.onImageGenerationRequestBody(ctx, body, headers)
	default:
		return body, nil
	}
}
@@ -338,6 +344,119 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
	return json.Marshal(vertexRequest)
}

func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := &imageGenerationRequest{}
	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}
	// Image generation does not use the streaming endpoint; a complete response is required
	path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
	util.OverwriteRequestPathHeader(headers, path)

	vertexRequest := v.buildVertexImageGenerationRequest(request)
	return json.Marshal(vertexRequest)
}
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
	// Build safety settings
	safetySettings := make([]vertexChatSafetySetting, 0)
	for category, threshold := range v.config.geminiSafetySetting {
		safetySettings = append(safetySettings, vertexChatSafetySetting{
			Category:  category,
			Threshold: threshold,
		})
	}

	// Parse the size parameter
	aspectRatio, imageSize := v.parseImageSize(request.Size)

	// Determine the output MIME type
	mimeType := "image/png"
	if request.OutputFormat != "" {
		switch request.OutputFormat {
		case "jpeg", "jpg":
			mimeType = "image/jpeg"
		case "webp":
			mimeType = "image/webp"
		default:
			mimeType = "image/png"
		}
	}

	vertexRequest := &vertexChatRequest{
		Contents: []vertexChatContent{{
			Role: roleUser,
			Parts: []vertexPart{{
				Text: request.Prompt,
			}},
		}},
		SafetySettings: safetySettings,
		GenerationConfig: vertexChatGenerationConfig{
			Temperature:        1.0,
			MaxOutputTokens:    32768,
			ResponseModalities: []string{"TEXT", "IMAGE"},
			ImageConfig: &vertexImageConfig{
				AspectRatio: aspectRatio,
				ImageSize:   imageSize,
				ImageOutputOptions: &vertexImageOutputOptions{
					MimeType: mimeType,
				},
				PersonGeneration: "ALLOW_ALL",
			},
		},
	}

	return vertexRequest
}
// parseImageSize parses an OpenAI-style size string (e.g. "1024x1024") into Vertex AI's aspectRatio and imageSize
// Vertex AI supported aspectRatio values: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9
// Vertex AI supported imageSize values: 1k, 2k, 4k
func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) {
	// Defaults
	aspectRatio = "1:1"
	imageSize = "1k"

	if size == "" {
		return
	}

	// Predefined size mapping (OpenAI standard sizes)
	sizeMapping := map[string]struct {
		aspectRatio string
		imageSize   string
	}{
		// OpenAI DALL-E standard sizes
		"256x256":   {"1:1", "1k"},
		"512x512":   {"1:1", "1k"},
		"1024x1024": {"1:1", "1k"},
		"1792x1024": {"16:9", "2k"},
		"1024x1792": {"9:16", "2k"},
		// Extended sizes
		"2048x2048": {"1:1", "2k"},
		"4096x4096": {"1:1", "4k"},
		// 3:2 and 2:3 ratios
		"1536x1024": {"3:2", "2k"},
		"1024x1536": {"2:3", "2k"},
		// 4:3 and 3:4 ratios
		"1024x768":  {"4:3", "1k"},
		"768x1024":  {"3:4", "1k"},
		"1365x1024": {"4:3", "1k"},
		"1024x1365": {"3:4", "1k"},
		// 5:4 and 4:5 ratios
		"1280x1024": {"5:4", "1k"},
		"1024x1280": {"4:5", "1k"},
		// 21:9 ultrawide
		"2560x1080": {"21:9", "2k"},
	}

	if mapping, ok := sizeMapping[size]; ok {
		return mapping.aspectRatio, mapping.imageSize
	}

	return
}
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
	// OpenAI-compatible mode: pass the response through, but Unicode escape sequences must be decoded
	// The Vertex AI OpenAI-compatible API returns ASCII-safe JSON, encoding non-ASCII characters as \uXXXX
@@ -394,10 +513,16 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
	if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
		return v.claude.TransformResponseBody(ctx, apiName, body)
	}

	switch apiName {
	case ApiNameChatCompletion:
		return v.onChatCompletionResponseBody(ctx, body)
	case ApiNameEmbeddings:
		return v.onEmbeddingsResponseBody(ctx, body)
	case ApiNameImageGeneration:
		return v.onImageGenerationResponseBody(ctx, body)
	default:
		return body, nil
	}
}
@@ -490,6 +615,54 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
	return &response
}

func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
	// Use gjson to extract fields directly, avoiding full deserialization of large base64 payloads
	// This significantly reduces memory allocations and copies
	response := v.buildImageGenerationResponseFromJSON(body)
	return json.Marshal(response)
}

// buildImageGenerationResponseFromJSON extracts the image generation response from raw JSON using gjson
// Compared with a full json.Unmarshal, this approach is more memory-efficient
func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse {
	result := gjson.ParseBytes(body)
	data := make([]imageGenerationData, 0)

	// Iterate over all candidates and extract image data
	candidates := result.Get("candidates")
	candidates.ForEach(func(_, candidate gjson.Result) bool {
		parts := candidate.Get("content.parts")
		parts.ForEach(func(_, part gjson.Result) bool {
			// Skip thinking output (thought: true)
			if part.Get("thought").Bool() {
				return true
			}
			// Extract image data
			inlineData := part.Get("inlineData.data")
			if inlineData.Exists() && inlineData.String() != "" {
				data = append(data, imageGenerationData{
					B64: inlineData.String(),
				})
			}
			return true
		})
		return true
	})

	// Extract usage information
	usage := result.Get("usageMetadata")

	return &imageGenerationResponse{
		Created: time.Now().Unix(),
		Data:    data,
		Usage: &imageGenerationUsage{
			TotalTokens:  int(usage.Get("totalTokenCount").Int()),
			InputTokens:  int(usage.Get("promptTokenCount").Int()),
			OutputTokens: int(usage.Get("candidatesTokenCount").Int()),
		},
	}
}
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
	var choice chatCompletionChoice
	choice.Delta = &chatMessage{}
@@ -574,12 +747,18 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string

func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
	action := ""
	switch apiName {
	case ApiNameEmbeddings:
		action = vertexEmbeddingAction
	case ApiNameImageGeneration:
		// Image generation uses the non-streaming endpoint; a complete response is required
		action = vertexChatCompletionAction
	default:
		if stream {
			action = vertexChatCompletionStreamAction
		} else {
			action = vertexChatCompletionAction
		}
	}

	if v.isExpressMode() {
@@ -689,7 +868,7 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
			})
		}
	case contentTypeImageUrl:
		vpart, err := convertMediaContent(part.ImageUrl.Url)
		if err != nil {
			log.Errorf("unable to convert image content: %v", err)
		} else {
@@ -804,12 +983,25 @@ type vertexChatSafetySetting struct {
}

type vertexChatGenerationConfig struct {
	Temperature        float64              `json:"temperature,omitempty"`
	TopP               float64              `json:"topP,omitempty"`
	TopK               int                  `json:"topK,omitempty"`
	CandidateCount     int                  `json:"candidateCount,omitempty"`
	MaxOutputTokens    int                  `json:"maxOutputTokens,omitempty"`
	ThinkingConfig     vertexThinkingConfig `json:"thinkingConfig,omitempty"`
	ResponseModalities []string             `json:"responseModalities,omitempty"`
	ImageConfig        *vertexImageConfig   `json:"imageConfig,omitempty"`
}

type vertexImageConfig struct {
	AspectRatio        string                    `json:"aspectRatio,omitempty"`
	ImageSize          string                    `json:"imageSize,omitempty"`
	ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"`
	PersonGeneration   string                    `json:"personGeneration,omitempty"`
}

type vertexImageOutputOptions struct {
	MimeType string `json:"mimeType,omitempty"`
}

type vertexThinkingConfig struct {
@@ -1020,32 +1212,106 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
	return proxywasm.SetSharedData(key, data, cas)
}

// convertMediaContent converts an OpenAI-style media URL into Vertex AI format
// Supports multiple media types such as images, video, and audio
func convertMediaContent(mediaUrl string) (vertexPart, error) {
	part := vertexPart{}
	if strings.HasPrefix(mediaUrl, "http") {
		mimeType := detectMimeTypeFromURL(mediaUrl)
		part.FileData = &fileData{
			MimeType: mimeType,
			FileUri:  mediaUrl,
		}
		return part, nil
	} else {
		// Base64 data URL format: data:<mimeType>;base64,<data>
		re := regexp.MustCompile(`^data:([^;]+);base64,`)
		matches := re.FindStringSubmatch(mediaUrl)
		if len(matches) < 2 {
			return part, fmt.Errorf("invalid base64 format, expected data:<mimeType>;base64,<data>")
		}

		mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3
		parts := strings.Split(mimeType, "/")
		if len(parts) < 2 {
			return part, fmt.Errorf("invalid mimeType: %s", mimeType)
		}
		part.InlineData = &blob{
			MimeType: mimeType,
			Data:     strings.TrimPrefix(mediaUrl, matches[0]),
		}
		return part, nil
	}
}
// detectMimeTypeFromURL detects the MIME type from a URL's file extension
// Supports image, video, audio, and document types
func detectMimeTypeFromURL(url string) string {
	// Remove query parameters and fragment identifiers
	if idx := strings.Index(url, "?"); idx != -1 {
		url = url[:idx]
	}
	if idx := strings.Index(url, "#"); idx != -1 {
		url = url[:idx]
	}

	// Take the last path segment
	lastSlash := strings.LastIndex(url, "/")
	if lastSlash != -1 {
		url = url[lastSlash+1:]
	}

	// Take the extension
	lastDot := strings.LastIndex(url, ".")
	if lastDot == -1 || lastDot == len(url)-1 {
		return "application/octet-stream"
	}
	ext := strings.ToLower(url[lastDot+1:])

	// Extension-to-MIME-type mapping
	mimeTypes := map[string]string{
		// Image formats
		"jpg":  "image/jpeg",
		"jpeg": "image/jpeg",
		"png":  "image/png",
		"gif":  "image/gif",
		"webp": "image/webp",
		"bmp":  "image/bmp",
		"svg":  "image/svg+xml",
		"ico":  "image/x-icon",
		"heic": "image/heic",
		"heif": "image/heif",
		"tiff": "image/tiff",
		"tif":  "image/tiff",
		// Video formats
		"mp4":  "video/mp4",
		"mpeg": "video/mpeg",
		"mpg":  "video/mpeg",
		"mov":  "video/quicktime",
		"avi":  "video/x-msvideo",
		"wmv":  "video/x-ms-wmv",
		"webm": "video/webm",
		"mkv":  "video/x-matroska",
		"flv":  "video/x-flv",
		"3gp":  "video/3gpp",
		"3g2":  "video/3gpp2",
		"m4v":  "video/x-m4v",
		// Audio formats
		"mp3":  "audio/mpeg",
		"wav":  "audio/wav",
		"ogg":  "audio/ogg",
		"flac": "audio/flac",
		"aac":  "audio/aac",
		"m4a":  "audio/mp4",
		"wma":  "audio/x-ms-wma",
		"opus": "audio/opus",
		// Document formats
		"pdf": "application/pdf",
	}

	if mimeType, ok := mimeTypes[ext]; ok {
		return mimeType
	}

	return "application/octet-stream"
}
@@ -886,3 +886,348 @@ func RunVertexOpenAICompatibleModeOnHttpResponseBodyTests(t *testing.T) {
		})
	})
}

// ==================== Image generation tests ====================

func RunVertexExpressModeImageGenerationRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Vertex Express Mode image generation request body handling
		t.Run("vertex express mode image generation request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body (OpenAI image generation format)
			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A cute orange cat napping in the sunshine","size":"1024x1024"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			// Express Mode does not need to pause while waiting for an OAuth token
			require.Equal(t, types.ActionContinue, action)

			// Verify the request body was processed correctly
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			// Verify the request body was converted to Vertex format
			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "contents", "Request should be converted to vertex format with contents")
			require.Contains(t, bodyStr, "generationConfig", "Request should contain generationConfig")
			require.Contains(t, bodyStr, "responseModalities", "Request should contain responseModalities for image generation")
			require.Contains(t, bodyStr, "IMAGE", "Request should specify IMAGE in responseModalities")
			require.Contains(t, bodyStr, "imageConfig", "Request should contain imageConfig")

			// Verify the path contains the API key and the correct model
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
			require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Express Mode format")
			require.Contains(t, pathHeader, "generateContent", "Path should use generateContent action for image generation")
			require.NotContains(t, pathHeader, "streamGenerateContent", "Path should NOT use streaming for image generation")
		})

		// Test Vertex Express Mode image generation request body handling (custom size)
		t.Run("vertex express mode image generation with custom size", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body (widescreen size)
			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A beautiful sunset over the ocean","size":"1792x1024"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the size mapping was handled correctly
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			bodyStr := string(processedBody)
			// 1792x1024 should map to a 16:9 aspect ratio
			require.Contains(t, bodyStr, "aspectRatio", "Request should contain aspectRatio in imageConfig")
			require.Contains(t, bodyStr, "16:9", "Request should map 1792x1024 to 16:9 aspect ratio")
		})

		// Test Vertex Express Mode image generation request body handling (with safety settings)
		t.Run("vertex express mode image generation with safety settings", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithSafetyConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A mountain landscape"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the request body contains the safety settings
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)

			bodyStr := string(processedBody)
			require.Contains(t, bodyStr, "safetySettings", "Request should contain safetySettings")
		})

		// Test Vertex Express Mode image generation request body handling (with model mapping)
		t.Run("vertex express mode image generation with model mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body (using the pre-mapping model name)
			requestBody := `{"model":"gpt-4","prompt":"A futuristic city"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the path uses the mapped model name
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
		})
	})
}
func RunVertexExpressModeImageGenerationResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Vertex Express Mode image generation response body handling
		t.Run("vertex express mode image generation response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"A cute cat"}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			// Set response properties so IsResponseFromUpstream() returns true
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			// Set the response headers
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			// Set the response body (Vertex image generation format)
			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [{
							"inlineData": {
								"mimeType": "image/png",
								"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
							}
						}]
					},
					"finishReason": "STOP"
				}],
				"usageMetadata": {
					"promptTokenCount": 10,
					"candidatesTokenCount": 1024,
					"totalTokenCount": 1034
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the response body was processed correctly
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)

			responseStr := string(processedResponseBody)

			// Verify the response body was converted to the OpenAI image generation format
			require.Contains(t, responseStr, "created", "Response should contain created field")
			require.Contains(t, responseStr, "data", "Response should contain data array")
			require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field with base64 image data")
			require.Contains(t, responseStr, "usage", "Response should contain usage information")
			require.Contains(t, responseStr, "total_tokens", "Response should contain total_tokens in usage")
		})

		// Test Vertex Express Mode image generation response body handling (skip thinking output)
		t.Run("vertex express mode image generation response body - skip thinking", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"gemini-3-pro-image-preview","prompt":"An Eiffel tower"}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			// Set response properties
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			// Set the response headers
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			// Set the response body (contains thinking output and an image)
			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [
							{
								"text": "Considering visual elements...",
								"thought": true
							},
							{
								"inlineData": {
									"mimeType": "image/png",
									"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
								}
							}
						]
					},
					"finishReason": "STOP"
				}],
				"usageMetadata": {
					"promptTokenCount": 13,
					"candidatesTokenCount": 1120,
					"totalTokenCount": 1356,
					"thoughtsTokenCount": 223
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the response body was processed correctly
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)

			responseStr := string(processedResponseBody)

			// Verify the response contains only image data, without the thinking text
			require.Contains(t, responseStr, "b64_json", "Response should contain b64_json field")
			require.NotContains(t, responseStr, "Considering visual elements", "Response should NOT contain thinking text")
			require.NotContains(t, responseStr, "thought", "Response should NOT contain thought field")
		})

		// Test Vertex Express Mode image generation response body handling (no image data)
		t.Run("vertex express mode image generation response body - no image", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Set the request headers first
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/images/generations"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			// Set the request body
			requestBody := `{"model":"gemini-2.0-flash-exp","prompt":"test"}`
			host.CallOnHttpRequestBody([]byte(requestBody))

			// Set response properties
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))

			// Set the response headers
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)

			// Set the response body (text only, no image)
			responseBody := `{
				"candidates": [{
					"content": {
						"role": "model",
						"parts": [{
							"text": "I cannot generate that image."
						}]
					},
					"finishReason": "SAFETY"
				}],
				"usageMetadata": {
					"promptTokenCount": 5,
					"candidatesTokenCount": 10,
					"totalTokenCount": 15
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))

			require.Equal(t, types.ActionContinue, action)

			// Verify the response body was processed correctly (even without an image)
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)

			responseStr := string(processedResponseBody)

			// Verify the response structure is correct, with an empty data array
			require.Contains(t, responseStr, "created", "Response should contain created field")
			require.Contains(t, responseStr, "data", "Response should contain data array")
		})
	})
}