support vertex's claude (#3236)

This commit is contained in:
rinfx
2025-12-20 10:33:53 +08:00
committed by GitHub
parent 08d4f556a1
commit e8bcbde5f4
2 changed files with 66 additions and 22 deletions

View File

@@ -203,19 +203,20 @@ type claudeThinkingConfig struct {
}
type claudeTextGenRequest struct {
Model string `json:"model"`
Messages []claudeChatMessage `json:"messages"`
System *claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
Model string `json:"model,omitempty"`
Messages []claudeChatMessage `json:"messages"`
System *claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
AnthropicVersion string `json:"anthropic_version,omitempty"`
}
type claudeTextGenResponse struct {

View File

@@ -27,11 +27,16 @@ const (
vertexAuthDomain = "oauth2.googleapis.com"
vertexDomain = "aiplatform.googleapis.com"
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexEmbeddingAction = "predict"
vertexGlobalRegion = "global"
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexAnthropicMessageAction = "rawPredict"
vertexAnthropicMessageStreamAction = "streamRawPredict"
vertexEmbeddingAction = "predict"
vertexGlobalRegion = "global"
contextClaudeMarker = "isClaudeRequest"
vertexAnthropicVersion = "vertex-2023-10-16"
)
// vertexProviderInitializer is a field-less type; its CreateProvider method
// takes a ProviderConfig.
type vertexProviderInitializer struct{}
@@ -66,6 +71,10 @@ func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provi
Port: 443,
}),
contextCache: createContextCache(&config),
claude: &claudeProvider{
config: config,
contextCache: createContextCache(&config),
},
}, nil
}
@@ -73,6 +82,7 @@ type vertexProvider struct {
client wrapper.HttpClient
config ProviderConfig
contextCache *contextCache
claude *claudeProvider
}
func (v *vertexProvider) GetProviderType() string {
@@ -145,6 +155,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
headers := util.GetRequestHeaders()
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
@@ -174,11 +185,26 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
if err != nil {
return nil, err
}
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
if strings.HasPrefix(request.Model, "claude") {
ctx.SetContext(contextClaudeMarker, true)
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildVertexChatRequest(request)
return json.Marshal(vertexRequest)
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
claudeRequest.Model = ""
claudeRequest.AnthropicVersion = vertexAnthropicVersion
claudeBody, err := json.Marshal(claudeRequest)
if err != nil {
return nil, err
}
return claudeBody, nil
} else {
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
vertexRequest := v.buildVertexChatRequest(request)
return json.Marshal(vertexRequest)
}
}
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
@@ -194,6 +220,9 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
}
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
}
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
if isLastChunk {
return []byte(ssePrefix + "[DONE]\n\n"), nil
@@ -231,6 +260,9 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
}
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.TransformResponseBody(ctx, apiName, body)
}
if apiName == ApiNameChatCompletion {
return v.onChatCompletionResponseBody(ctx, body)
} else {
@@ -326,6 +358,7 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
var choice chatCompletionChoice
choice.Delta = &chatMessage{}
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
part := vertexResp.Candidates[0].Content.Parts[0]
if part.FunctionCall != nil {
@@ -376,6 +409,16 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
// getAhthropicRequestPath returns the request path for a model served under
// the "anthropic" publisher.
//
// The result is vertexPathAnthropicTemplate filled in with the configured
// project ID and region, the given modelId, and an action selected by stream:
// vertexAnthropicMessageStreamAction when stream is true, otherwise
// vertexAnthropicMessageAction. The apiName argument is accepted but not
// consulted here.
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
	// Start from the non-streaming action and switch only when streaming.
	verb := vertexAnthropicMessageAction
	if stream {
		verb = vertexAnthropicMessageStreamAction
	}
	return fmt.Sprintf(
		vertexPathAnthropicTemplate,
		v.config.vertexProjectId,
		v.config.vertexRegion,
		modelId,
		verb,
	)
}
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {