mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 15:10:54 +08:00
support vertex's claude (#3236)
This commit is contained in:
@@ -203,19 +203,20 @@ type claudeThinkingConfig struct {
|
||||
}
|
||||
|
||||
// claudeTextGenRequest is the JSON payload produced by
// buildClaudeTextGenRequest and serialized with json.Marshal for a Claude
// text-generation call.
//
// NOTE(review): this listing looks like a rendered diff hunk (-203,19 +203,20)
// without +/- markers: the first 13 fields appear to be the pre-change
// declarations and the following 14 the post-change declarations — confirm
// against the upstream commit before treating it as compilable Go.
type claudeTextGenRequest struct {
|
||||
// Presumably the pre-change field list; here Model is always emitted
// because its tag has no omitempty.
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System *claudeSystemPrompt `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||
// Presumably the post-change field list. Model gains omitempty, so an
// empty Model (the Vertex path assigns "" to it) is left out of the JSON.
Model string `json:"model,omitempty"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System *claudeSystemPrompt `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||
// AnthropicVersion is serialized as "anthropic_version" and omitted when
// empty; the Vertex path sets it to vertexAnthropicVersion.
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenResponse struct {
|
||||
|
||||
@@ -27,11 +27,16 @@ const (
|
||||
vertexAuthDomain = "oauth2.googleapis.com"
|
||||
vertexDomain = "aiplatform.googleapis.com"
|
||||
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexEmbeddingAction = "predict"
|
||||
vertexGlobalRegion = "global"
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexAnthropicMessageAction = "rawPredict"
|
||||
vertexAnthropicMessageStreamAction = "streamRawPredict"
|
||||
vertexEmbeddingAction = "predict"
|
||||
vertexGlobalRegion = "global"
|
||||
contextClaudeMarker = "isClaudeRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
// vertexProviderInitializer is a field-less type whose CreateProvider method
// takes a ProviderConfig and returns the provider built from it.
type vertexProviderInitializer struct{}
|
||||
@@ -66,6 +71,10 @@ func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
Port: 443,
|
||||
}),
|
||||
contextCache: createContextCache(&config),
|
||||
claude: &claudeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -73,6 +82,7 @@ type vertexProvider struct {
|
||||
client wrapper.HttpClient
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
claude *claudeProvider
|
||||
}
|
||||
|
||||
func (v *vertexProvider) GetProviderType() string {
|
||||
@@ -145,6 +155,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
}
|
||||
headers := util.GetRequestHeaders()
|
||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
@@ -174,11 +185,26 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
if strings.HasPrefix(request.Model, "claude") {
|
||||
ctx.SetContext(contextClaudeMarker, true)
|
||||
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexChatRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
|
||||
claudeRequest.Model = ""
|
||||
claudeRequest.AnthropicVersion = vertexAnthropicVersion
|
||||
claudeBody, err := json.Marshal(claudeRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claudeBody, nil
|
||||
} else {
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexChatRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
@@ -194,6 +220,9 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
|
||||
}
|
||||
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
||||
if isLastChunk {
|
||||
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
||||
@@ -231,6 +260,9 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.TransformResponseBody(ctx, apiName, body)
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
@@ -326,6 +358,7 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
|
||||
|
||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||
var choice chatCompletionChoice
|
||||
choice.Delta = &chatMessage{}
|
||||
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
||||
part := vertexResp.Candidates[0].Content.Parts[0]
|
||||
if part.FunctionCall != nil {
|
||||
@@ -376,6 +409,16 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if stream {
|
||||
action = vertexAnthropicMessageStreamAction
|
||||
} else {
|
||||
action = vertexAnthropicMessageAction
|
||||
}
|
||||
return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
|
||||
Reference in New Issue
Block a user