add support for image generation in Vertex AI provider (#3335)

This commit is contained in:
woody
2026-01-19 16:40:29 +08:00
committed by GitHub
parent ac69eb5b27
commit 399d2f372e
5 changed files with 848 additions and 27 deletions

View File

@@ -89,8 +89,9 @@ func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameChatCompletion): vertexPathTemplate,
string(ApiNameEmbeddings): vertexPathTemplate,
string(ApiNameImageGeneration): vertexPathTemplate,
}
}
@@ -265,10 +266,15 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return v.onChatCompletionRequestBody(ctx, body, headers)
} else {
case ApiNameEmbeddings:
return v.onEmbeddingsRequestBody(ctx, body, headers)
case ApiNameImageGeneration:
return v.onImageGenerationRequestBody(ctx, body, headers)
default:
return body, nil
}
}
@@ -338,6 +344,119 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(vertexRequest)
}
func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &imageGenerationRequest{}
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
// 图片生成不使用流式端点,需要完整响应
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildVertexImageGenerationRequest(request)
return json.Marshal(vertexRequest)
}
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
// 构建安全设置
safetySettings := make([]vertexChatSafetySetting, 0)
for category, threshold := range v.config.geminiSafetySetting {
safetySettings = append(safetySettings, vertexChatSafetySetting{
Category: category,
Threshold: threshold,
})
}
// 解析尺寸参数
aspectRatio, imageSize := v.parseImageSize(request.Size)
// 确定输出 MIME 类型
mimeType := "image/png"
if request.OutputFormat != "" {
switch request.OutputFormat {
case "jpeg", "jpg":
mimeType = "image/jpeg"
case "webp":
mimeType = "image/webp"
default:
mimeType = "image/png"
}
}
vertexRequest := &vertexChatRequest{
Contents: []vertexChatContent{{
Role: roleUser,
Parts: []vertexPart{{
Text: request.Prompt,
}},
}},
SafetySettings: safetySettings,
GenerationConfig: vertexChatGenerationConfig{
Temperature: 1.0,
MaxOutputTokens: 32768,
ResponseModalities: []string{"TEXT", "IMAGE"},
ImageConfig: &vertexImageConfig{
AspectRatio: aspectRatio,
ImageSize: imageSize,
ImageOutputOptions: &vertexImageOutputOptions{
MimeType: mimeType,
},
PersonGeneration: "ALLOW_ALL",
},
},
}
return vertexRequest
}
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
// Vertex AI 支持的 aspectRatio: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9
// Vertex AI 支持的 imageSize: 1k, 2k, 4k
func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) {
// 默认值
aspectRatio = "1:1"
imageSize = "1k"
if size == "" {
return
}
// 预定义的尺寸映射OpenAI 标准尺寸)
sizeMapping := map[string]struct {
aspectRatio string
imageSize string
}{
// OpenAI DALL-E 标准尺寸
"256x256": {"1:1", "1k"},
"512x512": {"1:1", "1k"},
"1024x1024": {"1:1", "1k"},
"1792x1024": {"16:9", "2k"},
"1024x1792": {"9:16", "2k"},
// 扩展尺寸支持
"2048x2048": {"1:1", "2k"},
"4096x4096": {"1:1", "4k"},
// 3:2 和 2:3 比例
"1536x1024": {"3:2", "2k"},
"1024x1536": {"2:3", "2k"},
// 4:3 和 3:4 比例
"1024x768": {"4:3", "1k"},
"768x1024": {"3:4", "1k"},
"1365x1024": {"4:3", "1k"},
"1024x1365": {"3:4", "1k"},
// 5:4 和 4:5 比例
"1280x1024": {"5:4", "1k"},
"1024x1280": {"4:5", "1k"},
// 21:9 超宽比例
"2560x1080": {"21:9", "2k"},
}
if mapping, ok := sizeMapping[size]; ok {
return mapping.aspectRatio, mapping.imageSize
}
return
}
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
@@ -394,10 +513,16 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.TransformResponseBody(ctx, apiName, body)
}
if apiName == ApiNameChatCompletion {
switch apiName {
case ApiNameChatCompletion:
return v.onChatCompletionResponseBody(ctx, body)
} else {
case ApiNameEmbeddings:
return v.onEmbeddingsResponseBody(ctx, body)
case ApiNameImageGeneration:
return v.onImageGenerationResponseBody(ctx, body)
default:
return body, nil
}
}
@@ -490,6 +615,54 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
return &response
}
func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
// 使用 gjson 直接提取字段,避免完整反序列化大型 base64 数据
// 这样可以显著减少内存分配和复制次数
response := v.buildImageGenerationResponseFromJSON(body)
return json.Marshal(response)
}
// buildImageGenerationResponseFromJSON 使用 gjson 从原始 JSON 中提取图片生成响应
// 相比 json.Unmarshal 完整反序列化,这种方式内存效率更高
func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse {
result := gjson.ParseBytes(body)
data := make([]imageGenerationData, 0)
// 遍历所有 candidates提取图片数据
candidates := result.Get("candidates")
candidates.ForEach(func(_, candidate gjson.Result) bool {
parts := candidate.Get("content.parts")
parts.ForEach(func(_, part gjson.Result) bool {
// 跳过思考过程 (thought: true)
if part.Get("thought").Bool() {
return true
}
// 提取图片数据
inlineData := part.Get("inlineData.data")
if inlineData.Exists() && inlineData.String() != "" {
data = append(data, imageGenerationData{
B64: inlineData.String(),
})
}
return true
})
return true
})
// 提取 usage 信息
usage := result.Get("usageMetadata")
return &imageGenerationResponse{
Created: time.Now().UnixMilli() / 1000,
Data: data,
Usage: &imageGenerationUsage{
TotalTokens: int(usage.Get("totalTokenCount").Int()),
InputTokens: int(usage.Get("promptTokenCount").Int()),
OutputTokens: int(usage.Get("candidatesTokenCount").Int()),
},
}
}
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
var choice chatCompletionChoice
choice.Delta = &chatMessage{}
@@ -574,12 +747,18 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
switch apiName {
case ApiNameEmbeddings:
action = vertexEmbeddingAction
} else if stream {
action = vertexChatCompletionStreamAction
} else {
case ApiNameImageGeneration:
// 图片生成使用非流式端点,需要完整响应
action = vertexChatCompletionAction
default:
if stream {
action = vertexChatCompletionStreamAction
} else {
action = vertexChatCompletionAction
}
}
if v.isExpressMode() {
@@ -689,7 +868,7 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
})
}
case contentTypeImageUrl:
vpart, err := convertImageContent(part.ImageUrl.Url)
vpart, err := convertMediaContent(part.ImageUrl.Url)
if err != nil {
log.Errorf("unable to convert image content: %v", err)
} else {
@@ -804,12 +983,25 @@ type vertexChatSafetySetting struct {
}
type vertexChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ImageConfig *vertexImageConfig `json:"imageConfig,omitempty"`
}
type vertexImageConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"`
ImageSize string `json:"imageSize,omitempty"`
ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type vertexImageOutputOptions struct {
MimeType string `json:"mimeType,omitempty"`
}
type vertexThinkingConfig struct {
@@ -1020,32 +1212,106 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
return proxywasm.SetSharedData(key, data, cas)
}
func convertImageContent(imageUrl string) (vertexPart, error) {
// convertMediaContent 将 OpenAI 格式的媒体 URL 转换为 Vertex AI 格式
// 支持图片、视频、音频等多种媒体类型
func convertMediaContent(mediaUrl string) (vertexPart, error) {
part := vertexPart{}
if strings.HasPrefix(imageUrl, "http") {
arr := strings.Split(imageUrl, ".")
mimeType := "image/" + arr[len(arr)-1]
if strings.HasPrefix(mediaUrl, "http") {
mimeType := detectMimeTypeFromURL(mediaUrl)
part.FileData = &fileData{
MimeType: mimeType,
FileUri: imageUrl,
FileUri: mediaUrl,
}
return part, nil
} else {
// Base64 data URL 格式: data:<mimeType>;base64,<data>
re := regexp.MustCompile(`^data:([^;]+);base64,`)
matches := re.FindStringSubmatch(imageUrl)
matches := re.FindStringSubmatch(mediaUrl)
if len(matches) < 2 {
return part, fmt.Errorf("invalid base64 format")
return part, fmt.Errorf("invalid base64 format, expected data:<mimeType>;base64,<data>")
}
mimeType := matches[1] // e.g. image/png
mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3
parts := strings.Split(mimeType, "/")
if len(parts) < 2 {
return part, fmt.Errorf("invalid mimeType")
return part, fmt.Errorf("invalid mimeType: %s", mimeType)
}
part.InlineData = &blob{
MimeType: mimeType,
Data: strings.TrimPrefix(imageUrl, matches[0]),
Data: strings.TrimPrefix(mediaUrl, matches[0]),
}
return part, nil
}
}
// detectMimeTypeFromURL 根据 URL 的文件扩展名检测 MIME 类型
// 支持图片、视频、音频和文档类型
func detectMimeTypeFromURL(url string) string {
// 移除查询参数和片段标识符
if idx := strings.Index(url, "?"); idx != -1 {
url = url[:idx]
}
if idx := strings.Index(url, "#"); idx != -1 {
url = url[:idx]
}
// 获取最后一个路径段
lastSlash := strings.LastIndex(url, "/")
if lastSlash != -1 {
url = url[lastSlash+1:]
}
// 获取扩展名
lastDot := strings.LastIndex(url, ".")
if lastDot == -1 || lastDot == len(url)-1 {
return "application/octet-stream"
}
ext := strings.ToLower(url[lastDot+1:])
// 扩展名到 MIME 类型的映射
mimeTypes := map[string]string{
// 图片格式
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
"svg": "image/svg+xml",
"ico": "image/x-icon",
"heic": "image/heic",
"heif": "image/heif",
"tiff": "image/tiff",
"tif": "image/tiff",
// 视频格式
"mp4": "video/mp4",
"mpeg": "video/mpeg",
"mpg": "video/mpeg",
"mov": "video/quicktime",
"avi": "video/x-msvideo",
"wmv": "video/x-ms-wmv",
"webm": "video/webm",
"mkv": "video/x-matroska",
"flv": "video/x-flv",
"3gp": "video/3gpp",
"3g2": "video/3gpp2",
"m4v": "video/x-m4v",
// 音频格式
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"m4a": "audio/mp4",
"wma": "audio/x-ms-wma",
"opus": "audio/opus",
// 文档格式
"pdf": "application/pdf",
}
if mimeType, ok := mimeTypes[ext]; ok {
return mimeType
}
return "application/octet-stream"
}