mirror of
https://github.com/alibaba/higress.git
synced 2026-06-02 17:17:27 +08:00
feat/ai proxy vertex ai compatible (#3324)
This commit is contained in:
@@ -387,6 +387,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Vertex token刷新提前时间
|
||||
// @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒
|
||||
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
|
||||
// @Title zh-CN Vertex AI OpenAI兼容模式
|
||||
// @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API,请求和响应均使用OpenAI格式,无需协议转换。与Express Mode(apiTokens)互斥。
|
||||
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
|
||||
// @Title zh-CN 翻译服务需指定的目标语种
|
||||
// @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。
|
||||
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
|
||||
@@ -540,6 +543,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
if c.vertexTokenRefreshAhead == 0 {
|
||||
c.vertexTokenRefreshAhead = 60
|
||||
}
|
||||
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
|
||||
c.targetLang = json.Get("targetLang").String()
|
||||
|
||||
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -32,6 +33,9 @@ const (
|
||||
// Express Mode 路径模板 (不含 project/location)
|
||||
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
|
||||
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
|
||||
// OpenAI-compatible endpoint 路径模板
|
||||
// /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions
|
||||
vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexAnthropicMessageAction = "rawPredict"
|
||||
@@ -39,6 +43,7 @@ const (
|
||||
vertexEmbeddingAction = "predict"
|
||||
vertexGlobalRegion = "global"
|
||||
contextClaudeMarker = "isClaudeRequest"
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
@@ -47,10 +52,28 @@ type vertexProviderInitializer struct{}
|
||||
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// Express Mode: 如果配置了 apiTokens,则使用 API Key 认证
|
||||
if len(config.apiTokens) > 0 {
|
||||
// Express Mode 与 OpenAI 兼容模式互斥
|
||||
if config.vertexOpenAICompatible {
|
||||
return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)")
|
||||
}
|
||||
// Express Mode 不需要其他配置
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenAI 兼容模式: 需要 OAuth 认证配置
|
||||
if config.vertexOpenAICompatible {
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
if config.vertexRegion == "" || config.vertexProjectId == "" {
|
||||
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
if config.vertexAuthServiceName == "" {
|
||||
return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标准模式: 保持原有验证逻辑
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config")
|
||||
@@ -101,6 +124,12 @@ func (v *vertexProvider) isExpressMode() bool {
|
||||
return len(v.config.apiTokens) > 0
|
||||
}
|
||||
|
||||
// isOpenAICompatibleMode 检测是否启用 OpenAI 兼容模式
|
||||
// 使用 Vertex AI 的 OpenAI-compatible Chat Completions API
|
||||
func (v *vertexProvider) isOpenAICompatibleMode() bool {
|
||||
return v.config.vertexOpenAICompatible
|
||||
}
|
||||
|
||||
type vertexProvider struct {
|
||||
client wrapper.HttpClient
|
||||
config ProviderConfig
|
||||
@@ -184,7 +213,30 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if v.config.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
headers := util.GetRequestHeaders()
|
||||
|
||||
// OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射
|
||||
if v.isOpenAICompatibleMode() {
|
||||
ctx.SetContext(contextOpenAICompatibleMarker, true)
|
||||
body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
// OpenAI 兼容模式需要 OAuth token
|
||||
cached, err := v.getToken()
|
||||
if cached {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
|
||||
@@ -220,6 +272,32 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
}
|
||||
}
|
||||
|
||||
// onOpenAICompatibleRequestBody 处理 OpenAI 兼容模式的请求
|
||||
// 不转换请求体格式,只进行模型映射和路径设置
|
||||
func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API")
|
||||
}
|
||||
|
||||
// 解析请求进行模型映射
|
||||
request := &chatCompletionRequest{}
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置 OpenAI 兼容端点路径
|
||||
path := v.getOpenAICompatibleRequestPath()
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
// 如果模型被映射,需要更新请求体中的模型字段
|
||||
if request.Model != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", request.Model)
|
||||
}
|
||||
|
||||
// 保持 OpenAI 格式,直接返回(可能更新了模型字段)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := v.config.parseRequestAndMapModel(ctx, request, body)
|
||||
@@ -261,6 +339,12 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
|
||||
return util.DecodeUnicodeEscapesInSSE(chunk), nil
|
||||
}
|
||||
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
|
||||
}
|
||||
@@ -301,6 +385,12 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
|
||||
return util.DecodeUnicodeEscapes(body), nil
|
||||
}
|
||||
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.TransformResponseBody(ctx, apiName, body)
|
||||
}
|
||||
@@ -510,6 +600,11 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
return path
|
||||
}
|
||||
|
||||
// getOpenAICompatibleRequestPath 获取 OpenAI 兼容模式的请求路径
|
||||
func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
|
||||
return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
|
||||
Reference in New Issue
Block a user