refactor code

This commit is contained in:
Yoan.liu
2025-05-22 17:09:14 +08:00
parent 75326b1ddd
commit 9cdc59b272
15 changed files with 312 additions and 176 deletions

View File

@@ -3,6 +3,7 @@ package domain
import (
"encoding/json"
"fmt"
"strconv"
)
type Value any
@@ -11,6 +12,7 @@ type (
ComparisonOperator string
LogicalOperator string
ValueType string
ExprType string
)
const (
@@ -29,6 +31,12 @@ const (
Number ValueType = "number"
String ValueType = "string"
Boolean ValueType = "boolean"
ConstExprType ExprType = "const"
VarExprType ExprType = "var"
CompareExprType ExprType = "compare"
LogicalExprType ExprType = "logical"
NotExprType ExprType = "not"
)
type EvalResult struct {
@@ -40,14 +48,40 @@ func (e *EvalResult) GetFloat64() (float64, error) {
if e.Type != Number {
return 0, fmt.Errorf("type mismatch: %s", e.Type)
}
switch v := e.Value.(type) {
case int:
return float64(v), nil
case float64:
return v, nil
default:
return 0, fmt.Errorf("unsupported type: %T", v)
stringValue, ok := e.Value.(string)
if !ok {
return 0, fmt.Errorf("value is not a string: %v", e.Value)
}
floatValue, err := strconv.ParseFloat(stringValue, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse float64: %v", err)
}
return floatValue, nil
}
func (e *EvalResult) GetBool() (bool, error) {
if e.Type != Boolean {
return false, fmt.Errorf("type mismatch: %s", e.Type)
}
strValue, ok := e.Value.(string)
if ok {
if strValue == "true" {
return true, nil
} else if strValue == "false" {
return false, nil
}
return false, fmt.Errorf("value is not a boolean: %v", e.Value)
}
boolValue, ok := e.Value.(bool)
if !ok {
return false, fmt.Errorf("value is not a boolean: %v", e.Value)
}
return boolValue, nil
}
func (e *EvalResult) GreaterThan(other *EvalResult) (*EvalResult, error) {
@@ -232,9 +266,17 @@ func (e *EvalResult) And(other *EvalResult) (*EvalResult, error) {
}
switch e.Type {
case Boolean:
left, err := e.GetBool()
if err != nil {
return nil, err
}
right, err := other.GetBool()
if err != nil {
return nil, err
}
return &EvalResult{
Type: Boolean,
Value: e.Value.(bool) && other.Value.(bool),
Value: left && right,
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
@@ -247,9 +289,17 @@ func (e *EvalResult) Or(other *EvalResult) (*EvalResult, error) {
}
switch e.Type {
case Boolean:
left, err := e.GetBool()
if err != nil {
return nil, err
}
right, err := other.GetBool()
if err != nil {
return nil, err
}
return &EvalResult{
Type: Boolean,
Value: e.Value.(bool) || other.Value.(bool),
Value: left || right,
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
@@ -260,9 +310,13 @@ func (e *EvalResult) Not() (*EvalResult, error) {
if e.Type != Boolean {
return nil, fmt.Errorf("type mismatch: %s", e.Type)
}
boolValue, err := e.GetBool()
if err != nil {
return nil, err
}
return &EvalResult{
Type: Boolean,
Value: !e.Value.(bool),
Value: !boolValue,
}, nil
}
@@ -272,9 +326,17 @@ func (e *EvalResult) Is(other *EvalResult) (*EvalResult, error) {
}
switch e.Type {
case Boolean:
left, err := e.GetBool()
if err != nil {
return nil, err
}
right, err := other.GetBool()
if err != nil {
return nil, err
}
return &EvalResult{
Type: Boolean,
Value: e.Value.(bool) == other.Value.(bool),
Value: left == right,
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
@@ -282,17 +344,17 @@ func (e *EvalResult) Is(other *EvalResult) (*EvalResult, error) {
}
type Expr interface {
GetType() string
GetType() ExprType
Eval(variables map[string]map[string]any) (*EvalResult, error)
}
type ConstExpr struct {
Type string `json:"type"`
Type ExprType `json:"type"`
Value Value `json:"value"`
ValueType ValueType `json:"valueType"`
}
func (c ConstExpr) GetType() string { return c.Type }
func (c ConstExpr) GetType() ExprType { return c.Type }
func (c ConstExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
return &EvalResult{
@@ -302,11 +364,11 @@ func (c ConstExpr) Eval(variables map[string]map[string]any) (*EvalResult, error
}
type VarExpr struct {
Type string `json:"type"`
Type ExprType `json:"type"`
Selector WorkflowNodeIOValueSelector `json:"selector"`
}
func (v VarExpr) GetType() string { return v.Type }
func (v VarExpr) GetType() ExprType { return v.Type }
func (v VarExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
if v.Selector.Id == "" {
@@ -330,13 +392,13 @@ func (v VarExpr) Eval(variables map[string]map[string]any) (*EvalResult, error)
}
type CompareExpr struct {
Type string `json:"type"` // compare
Type ExprType `json:"type"` // compare
Op ComparisonOperator `json:"op"`
Left Expr `json:"left"`
Right Expr `json:"right"`
}
func (c CompareExpr) GetType() string { return c.Type }
func (c CompareExpr) GetType() ExprType { return c.Type }
func (c CompareExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
left, err := c.Left.Eval(variables)
@@ -369,13 +431,13 @@ func (c CompareExpr) Eval(variables map[string]map[string]any) (*EvalResult, err
}
type LogicalExpr struct {
Type string `json:"type"` // logical
Type ExprType `json:"type"` // logical
Op LogicalOperator `json:"op"`
Left Expr `json:"left"`
Right Expr `json:"right"`
}
func (l LogicalExpr) GetType() string { return l.Type }
func (l LogicalExpr) GetType() ExprType { return l.Type }
func (l LogicalExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
left, err := l.Left.Eval(variables)
@@ -398,11 +460,11 @@ func (l LogicalExpr) Eval(variables map[string]map[string]any) (*EvalResult, err
}
type NotExpr struct {
Type string `json:"type"` // not
Expr Expr `json:"expr"`
Type ExprType `json:"type"` // not
Expr Expr `json:"expr"`
}
func (n NotExpr) GetType() string { return n.Type }
func (n NotExpr) GetType() ExprType { return n.Type }
func (n NotExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
inner, err := n.Expr.Eval(variables)
@@ -413,7 +475,7 @@ func (n NotExpr) Eval(variables map[string]map[string]any) (*EvalResult, error)
}
type rawExpr struct {
Type string `json:"type"`
Type ExprType `json:"type"`
}
func MarshalExpr(e Expr) ([]byte, error) {
@@ -427,31 +489,31 @@ func UnmarshalExpr(data []byte) (Expr, error) {
}
switch typ.Type {
case "const":
case ConstExprType:
var e ConstExpr
if err := json.Unmarshal(data, &e); err != nil {
return nil, err
}
return e, nil
case "var":
case VarExprType:
var e VarExpr
if err := json.Unmarshal(data, &e); err != nil {
return nil, err
}
return e, nil
case "compare":
case CompareExprType:
var e CompareExprRaw
if err := json.Unmarshal(data, &e); err != nil {
return nil, err
}
return e.ToCompareExpr()
case "logical":
case LogicalExprType:
var e LogicalExprRaw
if err := json.Unmarshal(data, &e); err != nil {
return nil, err
}
return e.ToLogicalExpr()
case "not":
case NotExprType:
var e NotExprRaw
if err := json.Unmarshal(data, &e); err != nil {
return nil, err
@@ -463,7 +525,7 @@ func UnmarshalExpr(data []byte) (Expr, error) {
}
type CompareExprRaw struct {
Type string `json:"type"`
Type ExprType `json:"type"`
Op ComparisonOperator `json:"op"`
Left json.RawMessage `json:"left"`
Right json.RawMessage `json:"right"`
@@ -487,7 +549,7 @@ func (r CompareExprRaw) ToCompareExpr() (CompareExpr, error) {
}
type LogicalExprRaw struct {
Type string `json:"type"`
Type ExprType `json:"type"`
Op LogicalOperator `json:"op"`
Left json.RawMessage `json:"left"`
Right json.RawMessage `json:"right"`
@@ -511,7 +573,7 @@ func (r LogicalExprRaw) ToLogicalExpr() (LogicalExpr, error) {
}
type NotExprRaw struct {
Type string `json:"type"`
Type ExprType `json:"type"`
Expr json.RawMessage `json:"expr"`
}

View File

@@ -88,8 +88,10 @@ type WorkflowNodeConfigForCondition struct {
}
type WorkflowNodeConfigForInspect struct {
Host string `json:"host"` // 主机
Domain string `json:"domain"` // 域名
Port string `json:"port"` // 端口
Path string `json:"path"` // 路径
}
type WorkflowNodeConfigForUpload struct {
@@ -134,9 +136,14 @@ func (n *WorkflowNode) GetConfigForCondition() WorkflowNodeConfigForCondition {
}
func (n *WorkflowNode) GetConfigForInspect() WorkflowNodeConfigForInspect {
host := maputil.GetString(n.Config, "host")
if host == "" {
return WorkflowNodeConfigForInspect{}
}
domain := maputil.GetString(n.Config, "domain")
if domain == "" {
return WorkflowNodeConfigForInspect{}
domain = host
}
port := maputil.GetString(n.Config, "port")
@@ -144,9 +151,13 @@ func (n *WorkflowNode) GetConfigForInspect() WorkflowNodeConfigForInspect {
port = "443"
}
path := maputil.GetString(n.Config, "path")
return WorkflowNodeConfigForInspect{
Domain: domain,
Port: port,
Host: host,
Path: path,
}
}

View File

@@ -100,8 +100,8 @@ func (n *applyNode) Process(ctx context.Context) error {
}
// 添加中间结果
n.outputs["certificate.validated"] = true
n.outputs["certificate.daysLeft"] = int(time.Until(certificate.ExpireAt).Hours() / 24)
n.outputs[outputCertificateValidatedKey] = "true"
n.outputs[outputCertificateDaysLeftKey] = fmt.Sprintf("%d", int(time.Until(certificate.ExpireAt).Hours()/24))
n.logger.Info("apply completed")
@@ -146,9 +146,9 @@ func (n *applyNode) checkCanSkip(ctx context.Context, lastOutput *domain.Workflo
renewalInterval := time.Duration(currentNodeConfig.SkipBeforeExpiryDays) * time.Hour * 24
expirationTime := time.Until(lastCertificate.ExpireAt)
if expirationTime > renewalInterval {
n.outputs["certificate.validated"] = true
n.outputs["certificate.daysLeft"] = int(expirationTime.Hours() / 24)
n.outputs[outputCertificateValidatedKey] = "true"
n.outputs[outputCertificateDaysLeftKey] = fmt.Sprintf("%d", int(expirationTime.Hours()/24))
return true, fmt.Sprintf("the certificate has already been issued (expires in %dd, next renewal in %dd)", int(expirationTime.Hours()/24), currentNodeConfig.SkipBeforeExpiryDays)
}

View File

@@ -0,0 +1,6 @@
package nodeprocessor
const (
outputCertificateValidatedKey = "certificate.validated"
outputCertificateDaysLeftKey = "certificate.daysLeft"
)

View File

@@ -3,9 +3,12 @@ package nodeprocessor
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"math"
"net"
"net/http"
"strings"
"time"
"github.com/usual2970/certimate/internal/domain"
@@ -26,13 +29,13 @@ func NewInspectNode(node *domain.WorkflowNode) *inspectNode {
}
func (n *inspectNode) Process(ctx context.Context) error {
n.logger.Info("enter inspect website certificate node ...")
n.logger.Info("entering inspect certificate node...")
nodeConfig := n.node.GetConfigForInspect()
err := n.inspect(ctx, nodeConfig)
if err != nil {
n.logger.Warn("inspect website certificate failed: " + err.Error())
n.logger.Warn("inspect certificate failed: " + err.Error())
return err
}
@@ -40,18 +43,35 @@ func (n *inspectNode) Process(ctx context.Context) error {
}
func (n *inspectNode) inspect(ctx context.Context, nodeConfig domain.WorkflowNodeConfigForInspect) error {
// 定义重试参数
maxRetries := 3
retryInterval := 2 * time.Second
var cert *tls.Certificate
var lastError error
var certInfo *x509.Certificate
domainWithPort := nodeConfig.Domain + ":" + nodeConfig.Port
host := nodeConfig.Host
port := nodeConfig.Port
if port == "" {
port = "443"
}
domain := nodeConfig.Domain
if domain == "" {
domain = host
}
path := nodeConfig.Path
if path != "" && !strings.HasPrefix(path, "/") {
path = "/" + path
}
targetAddr := fmt.Sprintf("%s:%s", host, port)
n.logger.Info(fmt.Sprintf("Inspecting certificate at %s (validating domain: %s)", targetAddr, domain))
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
n.logger.Info(fmt.Sprintf("Retry #%d connecting to %s", attempt, domainWithPort))
n.logger.Info(fmt.Sprintf("Retry #%d connecting to %s", attempt, targetAddr))
select {
case <-ctx.Done():
return ctx.Err()
@@ -60,30 +80,65 @@ func (n *inspectNode) inspect(ctx context.Context, nodeConfig domain.WorkflowNod
}
}
dialer := &net.Dialer{
Timeout: 10 * time.Second,
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
}).DialContext,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: domain, // Set SNI to domain for proper certificate selection
},
ForceAttemptHTTP2: false,
DisableKeepAlives: true,
}
conn, err := tls.DialWithDialer(dialer, "tcp", domainWithPort, &tls.Config{
InsecureSkipVerify: true, // Allow self-signed certificates
})
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
scheme := "https"
urlStr := fmt.Sprintf("%s://%s", scheme, targetAddr)
if path != "" {
urlStr = urlStr + path
}
req, err := http.NewRequestWithContext(ctx, "HEAD", urlStr, nil)
if err != nil {
lastError = fmt.Errorf("failed to connect to %s: %w", domainWithPort, err)
lastError = fmt.Errorf("failed to create HTTP request: %w", err)
n.logger.Warn(fmt.Sprintf("Request creation attempt #%d failed: %s", attempt+1, lastError.Error()))
continue
}
if domain != host {
req.Host = domain
}
req.Header.Set("User-Agent", "CertificateValidator/1.0")
req.Header.Set("Accept", "*/*")
resp, err := client.Do(req)
if err != nil {
lastError = fmt.Errorf("HTTP request failed: %w", err)
n.logger.Warn(fmt.Sprintf("Connection attempt #%d failed: %s", attempt+1, lastError.Error()))
continue
}
// Get certificate information
certInfo := conn.ConnectionState().PeerCertificates[0]
conn.Close()
// Certificate information retrieved successfully
cert = &tls.Certificate{
Certificate: [][]byte{certInfo.Raw},
Leaf: certInfo,
if resp.TLS == nil || len(resp.TLS.PeerCertificates) == 0 {
resp.Body.Close()
lastError = fmt.Errorf("no TLS certificates received in HTTP response")
n.logger.Warn(fmt.Sprintf("Certificate retrieval attempt #%d failed: %s", attempt+1, lastError.Error()))
continue
}
certInfo = resp.TLS.PeerCertificates[0]
resp.Body.Close()
lastError = nil
n.logger.Info(fmt.Sprintf("Successfully retrieved certificate information for %s", domainWithPort))
n.logger.Info(fmt.Sprintf("Successfully retrieved certificate from %s", targetAddr))
break
}
@@ -91,69 +146,46 @@ func (n *inspectNode) inspect(ctx context.Context, nodeConfig domain.WorkflowNod
return fmt.Errorf("failed to retrieve certificate after %d attempts: %w", maxRetries, lastError)
}
certInfo := cert.Leaf
now := time.Now()
isValid := now.Before(certInfo.NotAfter) && now.After(certInfo.NotBefore)
// Check domain matching
domainMatch := false
if len(certInfo.DNSNames) > 0 {
for _, dnsName := range certInfo.DNSNames {
if matchDomain(nodeConfig.Domain, dnsName) {
domainMatch = true
break
}
if certInfo == nil {
outputs := map[string]any{
outputCertificateValidatedKey: "false",
outputCertificateDaysLeftKey: "0",
}
} else if matchDomain(nodeConfig.Domain, certInfo.Subject.CommonName) {
domainMatch = true
n.setOutputs(outputs)
return nil
}
isValid = isValid && domainMatch
now := time.Now()
isValidTime := now.Before(certInfo.NotAfter) && now.After(certInfo.NotBefore)
domainMatch := true
if err := certInfo.VerifyHostname(domain); err != nil {
domainMatch = false
}
isValid := isValidTime && domainMatch
daysRemaining := math.Floor(certInfo.NotAfter.Sub(now).Hours() / 24)
// Set node outputs
outputs := map[string]any{
"certificate.validated": isValid,
"certificate.daysLeft": daysRemaining,
isValidStr := "false"
if isValid {
isValidStr = "true"
}
outputs := map[string]any{
outputCertificateValidatedKey: isValidStr,
outputCertificateDaysLeftKey: fmt.Sprintf("%d", int(daysRemaining)),
}
n.setOutputs(outputs)
n.logger.Info(fmt.Sprintf("Certificate inspection completed - Target: %s, Domain: %s, Valid: %s, Days Remaining: %d",
targetAddr, domain, isValidStr, int(daysRemaining)))
return nil
}
func (n *inspectNode) setOutputs(outputs map[string]any) {
n.outputs = outputs
}
func matchDomain(requestDomain, certDomain string) bool {
if requestDomain == certDomain {
return true
}
if len(certDomain) > 2 && certDomain[0] == '*' && certDomain[1] == '.' {
wildcardSuffix := certDomain[1:]
requestDomainLen := len(requestDomain)
suffixLen := len(wildcardSuffix)
if requestDomainLen > suffixLen && requestDomain[requestDomainLen-suffixLen:] == wildcardSuffix {
remainingPart := requestDomain[:requestDomainLen-suffixLen]
if len(remainingPart) > 0 && !contains(remainingPart, '.') {
return true
}
}
}
return false
}
func contains(s string, c byte) bool {
for i := 0; i < len(s); i++ {
if s[i] == c {
return true
}
}
return false
}

View File

@@ -69,8 +69,8 @@ func (n *uploadNode) Process(ctx context.Context) error {
return err
}
n.outputs["certificate.validated"] = true
n.outputs["certificate.daysLeft"] = int(time.Until(certificate.ExpireAt).Hours() / 24)
n.outputs[outputCertificateValidatedKey] = "true"
n.outputs[outputCertificateDaysLeftKey] = fmt.Sprintf("%d", int(time.Until(certificate.ExpireAt).Hours()/24))
n.logger.Info("upload completed")
@@ -91,8 +91,8 @@ func (n *uploadNode) checkCanSkip(ctx context.Context, lastOutput *domain.Workfl
lastCertificate, _ := n.certRepo.GetByWorkflowNodeId(ctx, n.node.Id)
if lastCertificate != nil {
n.outputs["certificate.validated"] = true
n.outputs["certificate.daysLeft"] = int(time.Until(lastCertificate.ExpireAt).Hours() / 24)
n.outputs[outputCertificateValidatedKey] = "true"
n.outputs[outputCertificateDaysLeftKey] = fmt.Sprintf("%d", int(time.Until(lastCertificate.ExpireAt).Hours()/24))
return true, "the certificate has already been uploaded"
}
}