mirror of
https://github.com/alibaba/higress.git
synced 2026-06-07 11:47:30 +08:00
add support for image generation in Vertex AI provider (#3335)
This commit is contained in:
@@ -89,8 +89,9 @@ func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
string(ApiNameImageGeneration): vertexPathTemplate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,10 +266,15 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return v.onChatCompletionRequestBody(ctx, body, headers)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationRequestBody(ctx, body, headers)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,6 +344,119 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageGenerationRequest{}
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 图片生成不使用流式端点,需要完整响应
|
||||
path := v.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexImageGenerationRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexImageGenerationRequest(request *imageGenerationRequest) *vertexChatRequest {
|
||||
// 构建安全设置
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
safetySettings = append(safetySettings, vertexChatSafetySetting{
|
||||
Category: category,
|
||||
Threshold: threshold,
|
||||
})
|
||||
}
|
||||
|
||||
// 解析尺寸参数
|
||||
aspectRatio, imageSize := v.parseImageSize(request.Size)
|
||||
|
||||
// 确定输出 MIME 类型
|
||||
mimeType := "image/png"
|
||||
if request.OutputFormat != "" {
|
||||
switch request.OutputFormat {
|
||||
case "jpeg", "jpg":
|
||||
mimeType = "image/jpeg"
|
||||
case "webp":
|
||||
mimeType = "image/webp"
|
||||
default:
|
||||
mimeType = "image/png"
|
||||
}
|
||||
}
|
||||
|
||||
vertexRequest := &vertexChatRequest{
|
||||
Contents: []vertexChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []vertexPart{{
|
||||
Text: request.Prompt,
|
||||
}},
|
||||
}},
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
Temperature: 1.0,
|
||||
MaxOutputTokens: 32768,
|
||||
ResponseModalities: []string{"TEXT", "IMAGE"},
|
||||
ImageConfig: &vertexImageConfig{
|
||||
AspectRatio: aspectRatio,
|
||||
ImageSize: imageSize,
|
||||
ImageOutputOptions: &vertexImageOutputOptions{
|
||||
MimeType: mimeType,
|
||||
},
|
||||
PersonGeneration: "ALLOW_ALL",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return vertexRequest
|
||||
}
|
||||
|
||||
// parseImageSize 解析 OpenAI 格式的尺寸字符串(如 "1024x1024")为 Vertex AI 的 aspectRatio 和 imageSize
|
||||
// Vertex AI 支持的 aspectRatio: 1:1, 3:2, 2:3, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9
|
||||
// Vertex AI 支持的 imageSize: 1k, 2k, 4k
|
||||
func (v *vertexProvider) parseImageSize(size string) (aspectRatio, imageSize string) {
|
||||
// 默认值
|
||||
aspectRatio = "1:1"
|
||||
imageSize = "1k"
|
||||
|
||||
if size == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 预定义的尺寸映射(OpenAI 标准尺寸)
|
||||
sizeMapping := map[string]struct {
|
||||
aspectRatio string
|
||||
imageSize string
|
||||
}{
|
||||
// OpenAI DALL-E 标准尺寸
|
||||
"256x256": {"1:1", "1k"},
|
||||
"512x512": {"1:1", "1k"},
|
||||
"1024x1024": {"1:1", "1k"},
|
||||
"1792x1024": {"16:9", "2k"},
|
||||
"1024x1792": {"9:16", "2k"},
|
||||
// 扩展尺寸支持
|
||||
"2048x2048": {"1:1", "2k"},
|
||||
"4096x4096": {"1:1", "4k"},
|
||||
// 3:2 和 2:3 比例
|
||||
"1536x1024": {"3:2", "2k"},
|
||||
"1024x1536": {"2:3", "2k"},
|
||||
// 4:3 和 3:4 比例
|
||||
"1024x768": {"4:3", "1k"},
|
||||
"768x1024": {"3:4", "1k"},
|
||||
"1365x1024": {"4:3", "1k"},
|
||||
"1024x1365": {"3:4", "1k"},
|
||||
// 5:4 和 4:5 比例
|
||||
"1280x1024": {"5:4", "1k"},
|
||||
"1024x1280": {"4:5", "1k"},
|
||||
// 21:9 超宽比例
|
||||
"2560x1080": {"21:9", "2k"},
|
||||
}
|
||||
|
||||
if mapping, ok := sizeMapping[size]; ok {
|
||||
return mapping.aspectRatio, mapping.imageSize
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
@@ -394,10 +513,16 @@ func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.TransformResponseBody(ctx, apiName, body)
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
return v.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -490,6 +615,54 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
|
||||
return &response
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
// 使用 gjson 直接提取字段,避免完整反序列化大型 base64 数据
|
||||
// 这样可以显著减少内存分配和复制次数
|
||||
response := v.buildImageGenerationResponseFromJSON(body)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
// buildImageGenerationResponseFromJSON 使用 gjson 从原始 JSON 中提取图片生成响应
|
||||
// 相比 json.Unmarshal 完整反序列化,这种方式内存效率更高
|
||||
func (v *vertexProvider) buildImageGenerationResponseFromJSON(body []byte) *imageGenerationResponse {
|
||||
result := gjson.ParseBytes(body)
|
||||
data := make([]imageGenerationData, 0)
|
||||
|
||||
// 遍历所有 candidates,提取图片数据
|
||||
candidates := result.Get("candidates")
|
||||
candidates.ForEach(func(_, candidate gjson.Result) bool {
|
||||
parts := candidate.Get("content.parts")
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// 跳过思考过程 (thought: true)
|
||||
if part.Get("thought").Bool() {
|
||||
return true
|
||||
}
|
||||
// 提取图片数据
|
||||
inlineData := part.Get("inlineData.data")
|
||||
if inlineData.Exists() && inlineData.String() != "" {
|
||||
data = append(data, imageGenerationData{
|
||||
B64: inlineData.String(),
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
|
||||
// 提取 usage 信息
|
||||
usage := result.Get("usageMetadata")
|
||||
|
||||
return &imageGenerationResponse{
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Data: data,
|
||||
Usage: &imageGenerationUsage{
|
||||
TotalTokens: int(usage.Get("totalTokenCount").Int()),
|
||||
InputTokens: int(usage.Get("promptTokenCount").Int()),
|
||||
OutputTokens: int(usage.Get("candidatesTokenCount").Int()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||
var choice chatCompletionChoice
|
||||
choice.Delta = &chatMessage{}
|
||||
@@ -574,12 +747,18 @@ func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
switch apiName {
|
||||
case ApiNameEmbeddings:
|
||||
action = vertexEmbeddingAction
|
||||
} else if stream {
|
||||
action = vertexChatCompletionStreamAction
|
||||
} else {
|
||||
case ApiNameImageGeneration:
|
||||
// 图片生成使用非流式端点,需要完整响应
|
||||
action = vertexChatCompletionAction
|
||||
default:
|
||||
if stream {
|
||||
action = vertexChatCompletionStreamAction
|
||||
} else {
|
||||
action = vertexChatCompletionAction
|
||||
}
|
||||
}
|
||||
|
||||
if v.isExpressMode() {
|
||||
@@ -689,7 +868,7 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
|
||||
})
|
||||
}
|
||||
case contentTypeImageUrl:
|
||||
vpart, err := convertImageContent(part.ImageUrl.Url)
|
||||
vpart, err := convertMediaContent(part.ImageUrl.Url)
|
||||
if err != nil {
|
||||
log.Errorf("unable to convert image content: %v", err)
|
||||
} else {
|
||||
@@ -804,12 +983,25 @@ type vertexChatSafetySetting struct {
|
||||
}
|
||||
|
||||
type vertexChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
ThinkingConfig vertexThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ImageConfig *vertexImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
type vertexImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
ImageSize string `json:"imageSize,omitempty"`
|
||||
ImageOutputOptions *vertexImageOutputOptions `json:"imageOutputOptions,omitempty"`
|
||||
PersonGeneration string `json:"personGeneration,omitempty"`
|
||||
}
|
||||
|
||||
type vertexImageOutputOptions struct {
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
}
|
||||
|
||||
type vertexThinkingConfig struct {
|
||||
@@ -1020,32 +1212,106 @@ func setCachedAccessToken(key string, accessToken string, expireTime int64) erro
|
||||
return proxywasm.SetSharedData(key, data, cas)
|
||||
}
|
||||
|
||||
func convertImageContent(imageUrl string) (vertexPart, error) {
|
||||
// convertMediaContent 将 OpenAI 格式的媒体 URL 转换为 Vertex AI 格式
|
||||
// 支持图片、视频、音频等多种媒体类型
|
||||
func convertMediaContent(mediaUrl string) (vertexPart, error) {
|
||||
part := vertexPart{}
|
||||
if strings.HasPrefix(imageUrl, "http") {
|
||||
arr := strings.Split(imageUrl, ".")
|
||||
mimeType := "image/" + arr[len(arr)-1]
|
||||
if strings.HasPrefix(mediaUrl, "http") {
|
||||
mimeType := detectMimeTypeFromURL(mediaUrl)
|
||||
part.FileData = &fileData{
|
||||
MimeType: mimeType,
|
||||
FileUri: imageUrl,
|
||||
FileUri: mediaUrl,
|
||||
}
|
||||
return part, nil
|
||||
} else {
|
||||
// Base64 data URL 格式: data:<mimeType>;base64,<data>
|
||||
re := regexp.MustCompile(`^data:([^;]+);base64,`)
|
||||
matches := re.FindStringSubmatch(imageUrl)
|
||||
matches := re.FindStringSubmatch(mediaUrl)
|
||||
if len(matches) < 2 {
|
||||
return part, fmt.Errorf("invalid base64 format")
|
||||
return part, fmt.Errorf("invalid base64 format, expected data:<mimeType>;base64,<data>")
|
||||
}
|
||||
|
||||
mimeType := matches[1] // e.g. image/png
|
||||
mimeType := matches[1] // e.g. image/png, video/mp4, audio/mp3
|
||||
parts := strings.Split(mimeType, "/")
|
||||
if len(parts) < 2 {
|
||||
return part, fmt.Errorf("invalid mimeType")
|
||||
return part, fmt.Errorf("invalid mimeType: %s", mimeType)
|
||||
}
|
||||
part.InlineData = &blob{
|
||||
MimeType: mimeType,
|
||||
Data: strings.TrimPrefix(imageUrl, matches[0]),
|
||||
Data: strings.TrimPrefix(mediaUrl, matches[0]),
|
||||
}
|
||||
return part, nil
|
||||
}
|
||||
}
|
||||
|
||||
// detectMimeTypeFromURL 根据 URL 的文件扩展名检测 MIME 类型
|
||||
// 支持图片、视频、音频和文档类型
|
||||
func detectMimeTypeFromURL(url string) string {
|
||||
// 移除查询参数和片段标识符
|
||||
if idx := strings.Index(url, "?"); idx != -1 {
|
||||
url = url[:idx]
|
||||
}
|
||||
if idx := strings.Index(url, "#"); idx != -1 {
|
||||
url = url[:idx]
|
||||
}
|
||||
|
||||
// 获取最后一个路径段
|
||||
lastSlash := strings.LastIndex(url, "/")
|
||||
if lastSlash != -1 {
|
||||
url = url[lastSlash+1:]
|
||||
}
|
||||
|
||||
// 获取扩展名
|
||||
lastDot := strings.LastIndex(url, ".")
|
||||
if lastDot == -1 || lastDot == len(url)-1 {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
ext := strings.ToLower(url[lastDot+1:])
|
||||
|
||||
// 扩展名到 MIME 类型的映射
|
||||
mimeTypes := map[string]string{
|
||||
// 图片格式
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
"svg": "image/svg+xml",
|
||||
"ico": "image/x-icon",
|
||||
"heic": "image/heic",
|
||||
"heif": "image/heif",
|
||||
"tiff": "image/tiff",
|
||||
"tif": "image/tiff",
|
||||
// 视频格式
|
||||
"mp4": "video/mp4",
|
||||
"mpeg": "video/mpeg",
|
||||
"mpg": "video/mpeg",
|
||||
"mov": "video/quicktime",
|
||||
"avi": "video/x-msvideo",
|
||||
"wmv": "video/x-ms-wmv",
|
||||
"webm": "video/webm",
|
||||
"mkv": "video/x-matroska",
|
||||
"flv": "video/x-flv",
|
||||
"3gp": "video/3gpp",
|
||||
"3g2": "video/3gpp2",
|
||||
"m4v": "video/x-m4v",
|
||||
// 音频格式
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
"m4a": "audio/mp4",
|
||||
"wma": "audio/x-ms-wma",
|
||||
"opus": "audio/opus",
|
||||
// 文档格式
|
||||
"pdf": "application/pdf",
|
||||
}
|
||||
|
||||
if mimeType, ok := mimeTypes[ext]; ok {
|
||||
return mimeType
|
||||
}
|
||||
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user