mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
feat(ai-proxy): support Google Cloud Vertex (#2119)
Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
@@ -262,6 +262,19 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下:
|
||||
| `inputVariable` | string | 非必填 | - | dify 中应用类型为 workflow 时需要设置输入变量,当 botType 为 workflow 时一起使用 |
|
||||
| `outputVariable` | string | 非必填 | - | dify 中应用类型为 workflow 时需要设置输出变量,当 botType 为 workflow 时一起使用 |
|
||||
|
||||
#### Google Vertex AI
|
||||
|
||||
Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key,格式为 PEM 编码的 PKCS#8 私钥和 client_email 等信息 |
|
||||
| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等),用于构建 Vertex API 地址 |
|
||||
| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID,用于标识目标 GCP 项目 |
|
||||
| `vertexAuthServiceName`     | string        | 必填     | -      | 用于 OAuth2 认证的服务名称,该服务用于访问 oauth2.googleapis.com |
|
||||
| `vertexGeminiSafetySetting` | map of string | 非必填 | - | Gemini 模型的内容安全过滤设置。 |
|
||||
| `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) |
|
||||
|
||||
## 用法示例
|
||||
|
||||
### 使用 OpenAI 协议代理 Azure OpenAI 服务
|
||||
@@ -1629,6 +1642,69 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是 Vertex AI 提供的 Gemini 模型,由 Google 开发的人工智能助手。我可以回答问题、提供信息和帮助完成各种任务。有什么我可以帮您的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 43,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 完整配置示例
|
||||
|
||||
### Kubernetes 示例
|
||||
|
||||
@@ -208,6 +208,18 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i
|
||||
| ------------ | --------- | ----------- | ------- | ------------------------------------ |
|
||||
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |
|
||||
|
||||
#### Google Vertex AI
|
||||
For Vertex, the corresponding `type` is `vertex`. Its unique configuration fields are as follows:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key used for authentication. The format should be PEM encoded PKCS#8 private key along with client_email and other information |
|
||||
| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) used to build the Vertex API address |
|
||||
| `vertexProjectId` | string | Required | - | Google Cloud Project ID, used to identify the target GCP project |
|
||||
| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication, used to access oauth2.googleapis.com |
|
||||
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
|
||||
| `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Using OpenAI Protocol Proxy for Azure OpenAI Service
|
||||
@@ -1411,6 +1423,64 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am the Gemini model provided by Vertex AI, developed by Google. I can answer questions, provide information, and assist in completing various tasks. How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 43,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
### Kubernetes Example
|
||||
|
||||
@@ -90,6 +90,7 @@ const (
|
||||
providerTypeTogetherAI = "together-ai"
|
||||
providerTypeDify = "dify"
|
||||
providerTypeBedrock = "bedrock"
|
||||
providerTypeVertex = "vertex"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
@@ -161,6 +162,7 @@ var (
|
||||
providerTypeTogetherAI: &togetherAIProviderInitializer{},
|
||||
providerTypeDify: &difyProviderInitializer{},
|
||||
providerTypeBedrock: &bedrockProviderInitializer{},
|
||||
providerTypeVertex: &vertexProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -298,6 +300,21 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Gemini AI内容过滤和安全级别设定
|
||||
// @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings
|
||||
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
|
||||
// @Title zh-CN Vertex AI访问区域
|
||||
// @Description zh-CN 仅适用于Vertex AI服务。如需查看支持的区域的完整列表,请参阅https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
|
||||
vertexRegion string `required:"false" yaml:"vertexRegion" json:"vertexRegion"`
|
||||
// @Title zh-CN Vertex AI项目Id
|
||||
// @Description zh-CN 仅适用于Vertex AI服务。创建和管理项目请参阅https://cloud.google.com/resource-manager/docs/creating-managing-projects?hl=zh-cn#identifiers
|
||||
vertexProjectId string `required:"false" yaml:"vertexProjectId" json:"vertexProjectId"`
|
||||
// @Title zh-CN Vertex 认证秘钥
|
||||
// @Description zh-CN 用于Google服务账号认证的完整JSON密钥文件内容,获取可参考https://cloud.google.com/iam/docs/keys-create-delete?hl=zh-cn#iam-service-account-keys-create-console
|
||||
vertexAuthKey string `required:"false" yaml:"vertexAuthKey" json:"vertexAuthKey"`
|
||||
// @Title zh-CN Vertex 认证服务名
|
||||
// @Description zh-CN 用于Google服务账号认证的服务,DNS类型的服务名
|
||||
vertexAuthServiceName string `required:"false" yaml:"vertexAuthServiceName" json:"vertexAuthServiceName"`
|
||||
// @Title zh-CN Vertex token刷新提前时间
|
||||
// @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒
|
||||
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
|
||||
// @Title zh-CN 翻译服务需指定的目标语种
|
||||
// @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。
|
||||
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
|
||||
@@ -390,12 +407,20 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.minimaxApiType = json.Get("minimaxApiType").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
|
||||
if c.typ == providerTypeGemini {
|
||||
if c.typ == providerTypeGemini || c.typ == providerTypeVertex {
|
||||
c.geminiSafetySetting = make(map[string]string)
|
||||
for k, v := range json.Get("geminiSafetySetting").Map() {
|
||||
c.geminiSafetySetting[k] = v.String()
|
||||
}
|
||||
}
|
||||
c.vertexRegion = json.Get("vertexRegion").String()
|
||||
c.vertexProjectId = json.Get("vertexProjectId").String()
|
||||
c.vertexAuthKey = json.Get("vertexAuthKey").String()
|
||||
c.vertexAuthServiceName = json.Get("vertexAuthServiceName").String()
|
||||
c.vertexTokenRefreshAhead = json.Get("vertexTokenRefreshAhead").Int()
|
||||
if c.vertexTokenRefreshAhead == 0 {
|
||||
c.vertexTokenRefreshAhead = 60
|
||||
}
|
||||
c.targetLang = json.Get("targetLang").String()
|
||||
|
||||
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
|
||||
|
||||
668
plugins/wasm-go/extensions/ai-proxy/provider/vertex.go
Normal file
668
plugins/wasm-go/extensions/ai-proxy/provider/vertex.go
Normal file
@@ -0,0 +1,668 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
vertexAuthDomain = "oauth2.googleapis.com"
|
||||
vertexDomain = "{REGION}-aiplatform.googleapis.com"
|
||||
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexEmbeddingAction = "predict"
|
||||
)
|
||||
|
||||
type vertexProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config")
|
||||
}
|
||||
if config.vertexRegion == "" || config.vertexProjectId == "" {
|
||||
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config")
|
||||
}
|
||||
if config.vertexAuthServiceName == "" {
|
||||
return errors.New("missing vertexAuthServiceName in vertex provider config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): vertexPathTemplate,
|
||||
string(ApiNameEmbeddings): vertexPathTemplate,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(v.DefaultCapabilities())
|
||||
return &vertexProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
Domain: vertexAuthDomain,
|
||||
ServiceName: config.vertexAuthServiceName,
|
||||
Port: 443,
|
||||
}),
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type vertexProvider struct {
|
||||
client wrapper.HttpClient
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (v *vertexProvider) GetProviderType() string {
|
||||
return providerTypeVertex
|
||||
}
|
||||
|
||||
func (v *vertexProvider) GetApiName(path string) ApiName {
|
||||
if strings.HasSuffix(path, vertexChatCompletionAction) || strings.HasSuffix(path, vertexChatCompletionStreamAction) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, vertexEmbeddingAction) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
v.config.handleRequestHeaders(v, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
vertexRegionDomain := strings.Replace(vertexDomain, "{REGION}", v.config.vertexRegion, 1)
|
||||
util.OverwriteRequestHostHeader(headers, vertexRegionDomain)
|
||||
}
|
||||
|
||||
// getToken ensures the outgoing request carries a valid Vertex access token.
// It first consults the token cached in shared data; on a hit it sets the
// Authorization header and reports cached=true. On a miss it parses the
// configured service-account key, builds a signed JWT, and starts an async
// exchange for an access token (see getAccessToken), reporting cached=false
// so the caller can pause the request until the HTTP callback resumes it.
func (v *vertexProvider) getToken() (cached bool, err error) {
	cacheKeyName := v.buildTokenKey()
	cachedAccessToken, err := v.getCachedAccessToken(cacheKeyName)
	if err == nil && cachedAccessToken != "" {
		// Cache hit: reuse the stored token directly.
		_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+cachedAccessToken)
		return true, nil
	}

	var key ServiceAccountKey
	if err := json.Unmarshal([]byte(v.config.vertexAuthKey), &key); err != nil {
		return false, fmt.Errorf("[vertex]: unable to unmarshal auth key json: %v", err)
	}

	// These three fields are all required to build and exchange the JWT.
	if key.ClientEmail == "" || key.PrivateKey == "" || key.TokenURI == "" {
		return false, fmt.Errorf("[vertex]: missing auth params")
	}

	jwtToken, err := createJWT(&key)
	if err != nil {
		log.Errorf("[vertex]: unable to create JWT token: %v", err)
		return false, err
	}

	// Kick off the asynchronous token exchange; its callback resumes the
	// paused request once the access token has been obtained.
	err = v.getAccessToken(jwtToken)
	if err != nil {
		log.Errorf("[vertex]: unable to get access token: %v", err)
		return false, err
	}

	return false, err
}
|
||||
|
||||
func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !v.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if v.config.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
cached, err := v.getToken()
|
||||
if cached {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return v.onChatCompletionRequestBody(ctx, body, headers)
|
||||
} else {
|
||||
return v.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := v.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexChatRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := v.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildEmbeddingRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
if len(data) < 6 {
|
||||
// ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
var vertexResp vertexChatResponse
|
||||
if err := json.Unmarshal([]byte(data), &vertexResp); err != nil {
|
||||
log.Errorf("unable to unmarshal vertex response: %v", err)
|
||||
continue
|
||||
}
|
||||
response := v.buildChatCompletionStreamResponse(ctx, &vertexResp)
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
log.Errorf("unable to marshal response: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
v.appendResponse(responseBuilder, string(responseBody))
|
||||
}
|
||||
modifiedResponseChunk := responseBuilder.String()
|
||||
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
return v.onEmbeddingsResponseBody(ctx, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
vertexResponse := &vertexChatResponse{}
|
||||
if err := json.Unmarshal(body, vertexResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal vertex chat response: %v", err)
|
||||
}
|
||||
response := v.buildChatCompletionResponse(ctx, vertexResponse)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, response *vertexChatResponse) *chatCompletionResponse {
|
||||
fullTextResponse := chatCompletionResponse{
|
||||
Id: response.ResponseId,
|
||||
Object: objectChatCompletion,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
|
||||
Usage: usage{
|
||||
PromptTokens: response.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: response.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
for _, candidate := range response.Candidates {
|
||||
choice := chatCompletionChoice{
|
||||
Index: candidate.Index,
|
||||
Message: &chatMessage{
|
||||
Role: roleAssistant,
|
||||
},
|
||||
FinishReason: candidate.FinishReason,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
} else {
|
||||
choice.Message.Content = ""
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
vertexResponse := &vertexEmbeddingResponse{}
|
||||
if err := json.Unmarshal(body, vertexResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal vertex embeddings response: %v", err)
|
||||
}
|
||||
response := v.buildEmbeddingsResponse(ctx, vertexResponse)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertexResp *vertexEmbeddingResponse) *embeddingsResponse {
|
||||
response := embeddingsResponse{
|
||||
Object: "list",
|
||||
Data: make([]embedding, 0, len(vertexResp.Predictions)),
|
||||
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
|
||||
}
|
||||
totalTokens := 0
|
||||
for _, item := range vertexResp.Predictions {
|
||||
response.Data = append(response.Data, embedding{
|
||||
Object: `embedding`,
|
||||
Index: 0,
|
||||
Embedding: item.Embeddings.Values,
|
||||
})
|
||||
if item.Embeddings.Statistics != nil {
|
||||
totalTokens += item.Embeddings.Statistics.TokenCount
|
||||
}
|
||||
}
|
||||
response.Usage.TotalTokens = totalTokens
|
||||
return &response
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||
var choice chatCompletionChoice
|
||||
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
||||
choice.Delta = &chatMessage{Content: vertexResp.Candidates[0].Content.Parts[0].Text}
|
||||
}
|
||||
streamResponse := chatCompletionResponse{
|
||||
Id: vertexResp.ResponseId,
|
||||
Object: objectChatCompletionChunk,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: usage{
|
||||
PromptTokens: vertexResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: vertexResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: vertexResp.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
return &streamResponse
|
||||
}
|
||||
|
||||
func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
action = vertexEmbeddingAction
|
||||
} else if stream {
|
||||
action = vertexChatCompletionStreamAction
|
||||
} else {
|
||||
action = vertexChatCompletionAction
|
||||
}
|
||||
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
|
||||
safetySettings := make([]vertexChatSafetySetting, 0)
|
||||
for category, threshold := range v.config.geminiSafetySetting {
|
||||
safetySettings = append(safetySettings, vertexChatSafetySetting{
|
||||
Category: category,
|
||||
Threshold: threshold,
|
||||
})
|
||||
}
|
||||
vertexRequest := vertexChatRequest{
|
||||
Contents: make([]vertexChatContent, 0),
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: vertexChatGenerationConfig{
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxOutputTokens: request.MaxTokens,
|
||||
},
|
||||
}
|
||||
if request.Tools != nil {
|
||||
functions := make([]function, 0, len(request.Tools))
|
||||
for _, tool := range request.Tools {
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
vertexRequest.Tools = []vertexTool{
|
||||
{
|
||||
FunctionDeclarations: functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range request.Messages {
|
||||
content := vertexChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []vertexPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// there's no assistant role in vertex and API shall vomit if role is not user or model
|
||||
if content.Role == roleAssistant {
|
||||
content.Role = "model"
|
||||
} else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason
|
||||
content.Role = roleUser
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
vertexRequest.Contents = append(vertexRequest.Contents, content)
|
||||
|
||||
// if a system message is the last message, we need to add a dummy model message to make vertex happy
|
||||
if shouldAddDummyModelMessage {
|
||||
vertexRequest.Contents = append(vertexRequest.Contents, vertexChatContent{
|
||||
Role: "model",
|
||||
Parts: []vertexPart{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &vertexRequest
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildEmbeddingRequest(request *embeddingsRequest) *vertexEmbeddingRequest {
|
||||
inputs := request.ParseInput()
|
||||
instances := make([]vertexEmbeddingInstance, len(inputs))
|
||||
for i, input := range inputs {
|
||||
instances[i] = vertexEmbeddingInstance{
|
||||
Content: input,
|
||||
}
|
||||
}
|
||||
return &vertexEmbeddingRequest{Instances: instances}
|
||||
}
|
||||
|
||||
type vertexChatRequest struct {
|
||||
CachedContent string `json:"cachedContent,omitempty"`
|
||||
Contents []vertexChatContent `json:"contents"`
|
||||
SystemInstruction *vertexSystemInstruction `json:"systemInstruction,omitempty"`
|
||||
Tools []vertexTool `json:"tools,omitempty"`
|
||||
SafetySettings []vertexChatSafetySetting `json:"safetySettings,omitempty"`
|
||||
GenerationConfig vertexChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
type vertexChatContent struct {
|
||||
// The producer of the content. Must be either 'user' or 'model'.
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []vertexPart `json:"parts"`
|
||||
}
|
||||
|
||||
type vertexPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *blob `json:"inlineData,omitempty"`
|
||||
FileData *fileData `json:"fileData,omitempty"`
|
||||
}
|
||||
|
||||
type blob struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type fileData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
FileUri string `json:"fileUri"`
|
||||
}
|
||||
|
||||
type vertexSystemInstruction struct {
|
||||
Role string `json:"role"`
|
||||
Parts []vertexPart `json:"parts"`
|
||||
}
|
||||
|
||||
type vertexTool struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type vertexChatSafetySetting struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type vertexChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
}
|
||||
|
||||
type vertexEmbeddingRequest struct {
|
||||
Instances []vertexEmbeddingInstance `json:"instances"`
|
||||
Parameters *vertexEmbeddingParams `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type vertexEmbeddingInstance struct {
|
||||
TaskType string `json:"task_type"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type vertexEmbeddingParams struct {
|
||||
AutoTruncate bool `json:"autoTruncate,omitempty"`
|
||||
}
|
||||
|
||||
type vertexChatResponse struct {
|
||||
Candidates []vertexChatCandidate `json:"candidates"`
|
||||
ResponseId string `json:"responseId,omitempty"`
|
||||
PromptFeedback vertexChatPromptFeedback `json:"promptFeedback"`
|
||||
UsageMetadata vertexUsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type vertexChatCandidate struct {
|
||||
Content vertexChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type vertexChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type vertexChatPromptFeedback struct {
|
||||
SafetyRatings []vertexChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type vertexUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
type vertexEmbeddingResponse struct {
|
||||
Predictions []vertexPredictions `json:"predictions"`
|
||||
}
|
||||
|
||||
type vertexPredictions struct {
|
||||
Embeddings struct {
|
||||
Values []float64 `json:"values"`
|
||||
Statistics *vertexStatistics `json:"statistics,omitempty"`
|
||||
} `json:"embeddings"`
|
||||
}
|
||||
|
||||
type vertexStatistics struct {
|
||||
TokenCount int `json:"token_count"`
|
||||
Truncated bool `json:"truncated"`
|
||||
}
|
||||
|
||||
type ServiceAccountKey struct {
|
||||
ClientEmail string `json:"client_email"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
TokenURI string `json:"token_uri"`
|
||||
}
|
||||
|
||||
func createJWT(key *ServiceAccountKey) (string, error) {
|
||||
// 解析 PEM 格式的 RSA 私钥
|
||||
block, _ := pem.Decode([]byte(key.PrivateKey))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("invalid PEM block")
|
||||
}
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rsaKey := parsedKey.(*rsa.PrivateKey)
|
||||
|
||||
// 构造 JWT Header
|
||||
jwtHeader := map[string]string{
|
||||
"alg": "RS256",
|
||||
"typ": "JWT",
|
||||
"kid": key.PrivateKeyID,
|
||||
}
|
||||
headerJSON, _ := json.Marshal(jwtHeader)
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
|
||||
// 构造 JWT Claims
|
||||
now := time.Now().Unix()
|
||||
claims := map[string]interface{}{
|
||||
"iss": key.ClientEmail,
|
||||
"scope": "https://www.googleapis.com/auth/cloud-platform",
|
||||
"aud": key.TokenURI,
|
||||
"iat": now,
|
||||
"exp": now + 3600, // 1 小时有效期
|
||||
}
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
signingInput := fmt.Sprintf("%s.%s", headerB64, claimsB64)
|
||||
hashed := sha256.Sum256([]byte(signingInput))
|
||||
signature, err := rsaKey.Sign(nil, hashed[:], crypto.SHA256)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sigB64 := base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, sigB64), nil
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getAccessToken(jwtToken string) error {
|
||||
headers := [][2]string{
|
||||
{"Content-Type", "application/x-www-form-urlencoded"},
|
||||
}
|
||||
reqBody := "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=" + jwtToken
|
||||
err := v.client.Post("/token", headers, []byte(reqBody), func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
responseString := string(responseBody)
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
if statusCode != http.StatusOK {
|
||||
log.Errorf("failed to create vertex access key, status: %d body: %s", statusCode, responseString)
|
||||
_ = util.ErrorHandler("ai-proxy.vertex.load_ak_failed", fmt.Errorf("failed to load vertex ak"))
|
||||
return
|
||||
}
|
||||
responseJson := gjson.Parse(responseString)
|
||||
accessToken := responseJson.Get("access_token").String()
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+accessToken)
|
||||
|
||||
expiresIn := int64(3600)
|
||||
if expiresInVal := responseJson.Get("expires_in"); expiresInVal.Exists() {
|
||||
expiresIn = expiresInVal.Int()
|
||||
}
|
||||
expireTime := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
|
||||
keyName := v.buildTokenKey()
|
||||
err := setCachedAccessToken(keyName, accessToken, expireTime)
|
||||
if err != nil {
|
||||
log.Errorf("[vertex]: unable to cache access token: %v", err)
|
||||
}
|
||||
}, v.config.timeout)
|
||||
return err
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildTokenKey() string {
|
||||
region := v.config.vertexRegion
|
||||
projectID := v.config.vertexProjectId
|
||||
|
||||
return fmt.Sprintf("vertex-%s-%s-access-token", region, projectID)
|
||||
}
|
||||
|
||||
type cachedAccessToken struct {
|
||||
Token string `json:"token"`
|
||||
ExpireAt int64 `json:"expireAt"`
|
||||
}
|
||||
|
||||
// getCachedAccessToken looks up an access token previously stored in the
// proxy's shared data under key. It returns ("", nil) when no token is
// cached, or when the cached token expires within vertexTokenRefreshAhead
// seconds, signalling the caller to fetch a fresh one.
func (v *vertexProvider) getCachedAccessToken(key string) (string, error) {
	data, _, err := proxywasm.GetSharedData(key)
	if err != nil {
		if errors.Is(err, types.ErrorStatusNotFound) {
			// Nothing cached yet; not an error.
			return "", nil
		}
		return "", err
	}
	if data == nil {
		return "", nil
	}

	var tokenInfo cachedAccessToken
	if err = json.Unmarshal(data, &tokenInfo); err != nil {
		return "", err
	}

	now := time.Now().Unix()
	refreshAhead := v.config.vertexTokenRefreshAhead

	// Accept the token only if it will not expire within the configured
	// refresh-ahead window.
	if tokenInfo.ExpireAt > now+refreshAhead {
		return tokenInfo.Token, nil
	}

	return "", nil
}
|
||||
|
||||
// setCachedAccessToken stores accessToken with its absolute expiry time
// (unix seconds) in the proxy's shared data, using the CAS value from a
// prior read so a concurrent writer's update is not clobbered.
func setCachedAccessToken(key string, accessToken string, expireTime int64) error {
	tokenInfo := cachedAccessToken{
		Token:    accessToken,
		ExpireAt: expireTime,
	}

	// Fetch the current CAS token; "not found" simply means first write.
	_, cas, err := proxywasm.GetSharedData(key)
	if err != nil && !errors.Is(err, types.ErrorStatusNotFound) {
		return err
	}

	data, err := json.Marshal(tokenInfo)
	if err != nil {
		return err
	}

	return proxywasm.SetSharedData(key, data, cas)
}
|
||||
Reference in New Issue
Block a user