mirror of
https://github.com/alibaba/higress.git
synced 2026-04-21 20:17:29 +08:00
support vertex's claude (#3236)
This commit is contained in:
@@ -203,19 +203,20 @@ type claudeThinkingConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type claudeTextGenRequest struct {
|
type claudeTextGenRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model,omitempty"`
|
||||||
Messages []claudeChatMessage `json:"messages"`
|
Messages []claudeChatMessage `json:"messages"`
|
||||||
System *claudeSystemPrompt `json:"system,omitempty"`
|
System *claudeSystemPrompt `json:"system,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||||
Tools []claudeTool `json:"tools,omitempty"`
|
Tools []claudeTool `json:"tools,omitempty"`
|
||||||
ServiceTier string `json:"service_tier,omitempty"`
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||||
|
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeTextGenResponse struct {
|
type claudeTextGenResponse struct {
|
||||||
|
|||||||
@@ -27,11 +27,16 @@ const (
|
|||||||
vertexAuthDomain = "oauth2.googleapis.com"
|
vertexAuthDomain = "oauth2.googleapis.com"
|
||||||
vertexDomain = "aiplatform.googleapis.com"
|
vertexDomain = "aiplatform.googleapis.com"
|
||||||
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
||||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||||
vertexChatCompletionAction = "generateContent"
|
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
vertexChatCompletionAction = "generateContent"
|
||||||
vertexEmbeddingAction = "predict"
|
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||||
vertexGlobalRegion = "global"
|
vertexAnthropicMessageAction = "rawPredict"
|
||||||
|
vertexAnthropicMessageStreamAction = "streamRawPredict"
|
||||||
|
vertexEmbeddingAction = "predict"
|
||||||
|
vertexGlobalRegion = "global"
|
||||||
|
contextClaudeMarker = "isClaudeRequest"
|
||||||
|
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||||
)
|
)
|
||||||
|
|
||||||
type vertexProviderInitializer struct{}
|
type vertexProviderInitializer struct{}
|
||||||
@@ -66,6 +71,10 @@ func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
|||||||
Port: 443,
|
Port: 443,
|
||||||
}),
|
}),
|
||||||
contextCache: createContextCache(&config),
|
contextCache: createContextCache(&config),
|
||||||
|
claude: &claudeProvider{
|
||||||
|
config: config,
|
||||||
|
contextCache: createContextCache(&config),
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,6 +82,7 @@ type vertexProvider struct {
|
|||||||
client wrapper.HttpClient
|
client wrapper.HttpClient
|
||||||
config ProviderConfig
|
config ProviderConfig
|
||||||
contextCache *contextCache
|
contextCache *contextCache
|
||||||
|
claude *claudeProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) GetProviderType() string {
|
func (v *vertexProvider) GetProviderType() string {
|
||||||
@@ -145,6 +155,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
|||||||
}
|
}
|
||||||
headers := util.GetRequestHeaders()
|
headers := util.GetRequestHeaders()
|
||||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||||
|
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||||
util.ReplaceRequestHeaders(headers)
|
util.ReplaceRequestHeaders(headers)
|
||||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -174,11 +185,26 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
if strings.HasPrefix(request.Model, "claude") {
|
||||||
util.OverwriteRequestPathHeader(headers, path)
|
ctx.SetContext(contextClaudeMarker, true)
|
||||||
|
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||||
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
|
||||||
vertexRequest := v.buildVertexChatRequest(request)
|
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
|
||||||
return json.Marshal(vertexRequest)
|
claudeRequest.Model = ""
|
||||||
|
claudeRequest.AnthropicVersion = vertexAnthropicVersion
|
||||||
|
claudeBody, err := json.Marshal(claudeRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return claudeBody, nil
|
||||||
|
} else {
|
||||||
|
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||||
|
util.OverwriteRequestPathHeader(headers, path)
|
||||||
|
|
||||||
|
vertexRequest := v.buildVertexChatRequest(request)
|
||||||
|
return json.Marshal(vertexRequest)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||||
@@ -194,6 +220,9 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||||
|
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||||
|
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
|
||||||
|
}
|
||||||
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
||||||
if isLastChunk {
|
if isLastChunk {
|
||||||
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
||||||
@@ -231,6 +260,9 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||||
|
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||||
|
return v.claude.TransformResponseBody(ctx, apiName, body)
|
||||||
|
}
|
||||||
if apiName == ApiNameChatCompletion {
|
if apiName == ApiNameChatCompletion {
|
||||||
return v.onChatCompletionResponseBody(ctx, body)
|
return v.onChatCompletionResponseBody(ctx, body)
|
||||||
} else {
|
} else {
|
||||||
@@ -326,6 +358,7 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
|
|||||||
|
|
||||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||||
var choice chatCompletionChoice
|
var choice chatCompletionChoice
|
||||||
|
choice.Delta = &chatMessage{}
|
||||||
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
||||||
part := vertexResp.Candidates[0].Content.Parts[0]
|
part := vertexResp.Candidates[0].Content.Parts[0]
|
||||||
if part.FunctionCall != nil {
|
if part.FunctionCall != nil {
|
||||||
@@ -376,6 +409,16 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
|
|||||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||||
|
action := ""
|
||||||
|
if stream {
|
||||||
|
action = vertexAnthropicMessageStreamAction
|
||||||
|
} else {
|
||||||
|
action = vertexAnthropicMessageAction
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||||
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||||
action := ""
|
action := ""
|
||||||
if apiName == ApiNameEmbeddings {
|
if apiName == ApiNameEmbeddings {
|
||||||
|
|||||||
Reference in New Issue
Block a user