diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 79c2b321b..6db9b5d48 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -203,19 +203,20 @@ type claudeThinkingConfig struct { } type claudeTextGenRequest struct { - Model string `json:"model"` - Messages []claudeChatMessage `json:"messages"` - System *claudeSystemPrompt `json:"system,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"` - Tools []claudeTool `json:"tools,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` - Thinking *claudeThinkingConfig `json:"thinking,omitempty"` + Model string `json:"model,omitempty"` + Messages []claudeChatMessage `json:"messages"` + System *claudeSystemPrompt `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"` + Tools []claudeTool `json:"tools,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Thinking *claudeThinkingConfig `json:"thinking,omitempty"` + AnthropicVersion string `json:"anthropic_version,omitempty"` } type claudeTextGenResponse struct { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index fe15cc3ad..c66f5c44e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -27,11 +27,16 @@ const ( vertexAuthDomain = "oauth2.googleapis.com" vertexDomain = "aiplatform.googleapis.com" // /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION} - vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s" - vertexChatCompletionAction = "generateContent" - vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse" - vertexEmbeddingAction = "predict" - vertexGlobalRegion = "global" + vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s" + vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s" + vertexChatCompletionAction = "generateContent" + vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse" + vertexAnthropicMessageAction = "rawPredict" + vertexAnthropicMessageStreamAction = "streamRawPredict" + vertexEmbeddingAction = "predict" + vertexGlobalRegion = "global" + contextClaudeMarker = "isClaudeRequest" + vertexAnthropicVersion = "vertex-2023-10-16" ) type vertexProviderInitializer struct{} @@ -66,6 +71,10 @@ func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provi Port: 443, }), contextCache: createContextCache(&config), + claude: &claudeProvider{ + config: config, + contextCache: createContextCache(&config), + }, }, nil } @@ -73,6 +82,7 @@ type vertexProvider struct { client wrapper.HttpClient config ProviderConfig contextCache *contextCache + claude *claudeProvider } func (v *vertexProvider) GetProviderType() string { @@ -145,6 +155,7 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } headers := util.GetRequestHeaders() body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers) + headers.Set("Content-Length", fmt.Sprint(len(body))) util.ReplaceRequestHeaders(headers) _ = proxywasm.ReplaceHttpRequestBody(body) if err != nil { @@ -174,11 +185,26 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo if err != nil { return nil, err } - path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) - util.OverwriteRequestPathHeader(headers, path) + if strings.HasPrefix(request.Model, "claude") { + ctx.SetContext(contextClaudeMarker, true) + path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + util.OverwriteRequestPathHeader(headers, path) - vertexRequest := v.buildVertexChatRequest(request) - return json.Marshal(vertexRequest) + claudeRequest := v.claude.buildClaudeTextGenRequest(request) + claudeRequest.Model = "" + claudeRequest.AnthropicVersion = vertexAnthropicVersion + claudeBody, err := json.Marshal(claudeRequest) + if err != nil { + return nil, err + } + return claudeBody, nil + } else { + path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + util.OverwriteRequestPathHeader(headers, path) + + vertexRequest := v.buildVertexChatRequest(request) + return json.Marshal(vertexRequest) + } } func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) { @@ -194,6 +220,9 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [ } func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { + if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { + return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk) + } log.Infof("[vertexProvider] receive chunk body: %s", string(chunk)) if isLastChunk { return []byte(ssePrefix + "[DONE]\n\n"), nil @@ -231,6 +260,9 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A } func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { + return v.claude.TransformResponseBody(ctx, apiName, body) + } if apiName == ApiNameChatCompletion { return v.onChatCompletionResponseBody(ctx, body) } else { @@ -326,6 +358,7 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse { var choice chatCompletionChoice + choice.Delta = &chatMessage{} if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 { part := vertexResp.Candidates[0].Content.Parts[0] if part.FunctionCall != nil { @@ -376,6 +409,16 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } +func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string { + action := "" + if stream { + action = vertexAnthropicMessageStreamAction + } else { + action = vertexAnthropicMessageAction + } + return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action) +} + func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string { action := "" if apiName == ApiNameEmbeddings {