fix
This commit is contained in:
88
pkg/crontab/crontab.go
Normal file
88
pkg/crontab/crontab.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"errors"
|
||||
"jiacrontab/pkg/pqueue"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task 任务
|
||||
type Task = pqueue.Item
|
||||
|
||||
type Crontab struct {
|
||||
pq pqueue.PriorityQueue
|
||||
mux sync.RWMutex
|
||||
ready chan *Task
|
||||
}
|
||||
|
||||
func New() *Crontab {
|
||||
return &Crontab{
|
||||
pq: pqueue.New(10000),
|
||||
ready: make(chan *Task, 10000),
|
||||
}
|
||||
}
|
||||
|
||||
// AddJob 添加未经处理的job
|
||||
func (c *Crontab) AddJob(j *Job) error {
|
||||
nt, err := j.NextExecutionTime(time.Now())
|
||||
if err != nil {
|
||||
return errors.New("Invalid execution time")
|
||||
}
|
||||
c.mux.Lock()
|
||||
heap.Push(&c.pq, &Task{
|
||||
Priority: nt.UnixNano(),
|
||||
Value: j,
|
||||
})
|
||||
c.mux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJob 添加延时任务
|
||||
func (c *Crontab) AddTask(t *Task) {
|
||||
c.mux.Lock()
|
||||
heap.Push(&c.pq, t)
|
||||
c.mux.Unlock()
|
||||
}
|
||||
|
||||
func (c *Crontab) Len() int {
|
||||
c.mux.RLock()
|
||||
len := len(c.pq)
|
||||
c.mux.RUnlock()
|
||||
return len
|
||||
}
|
||||
|
||||
func (c *Crontab) GetAllTask() []*Task {
|
||||
c.mux.Lock()
|
||||
list := c.pq
|
||||
c.mux.Unlock()
|
||||
return list
|
||||
}
|
||||
|
||||
func (c *Crontab) Ready() <-chan *Task {
|
||||
return c.ready
|
||||
}
|
||||
|
||||
func (c *Crontab) QueueScanWorker() {
|
||||
refreshTicker := time.NewTicker(20 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-refreshTicker.C:
|
||||
if len(c.pq) == 0 {
|
||||
continue
|
||||
}
|
||||
start:
|
||||
c.mux.Lock()
|
||||
now := time.Now().UnixNano()
|
||||
job, _ := c.pq.PeekAndShift(now)
|
||||
c.mux.Unlock()
|
||||
if job == nil {
|
||||
continue
|
||||
}
|
||||
c.ready <- job
|
||||
goto start
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
57
pkg/crontab/crontab_test.go
Normal file
57
pkg/crontab/crontab_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_crontab_Ready(t *testing.T) {
|
||||
var timeLayout = "2006-01-02 15:04:05"
|
||||
c := New()
|
||||
now := time.Now().Add(6 * time.Second)
|
||||
c.AddTask(&Task{
|
||||
Value: "test1" + now.Format(timeLayout),
|
||||
Priority: now.UnixNano(),
|
||||
})
|
||||
|
||||
now = time.Now().Add(1 * time.Second)
|
||||
|
||||
c.AddTask(&Task{
|
||||
Value: "test2" + now.Format(timeLayout),
|
||||
Priority: now.UnixNano(),
|
||||
})
|
||||
now = time.Now().Add(3 * time.Second)
|
||||
|
||||
c.AddTask(&Task{
|
||||
Value: "test3" + now.Format(timeLayout),
|
||||
Priority: now.UnixNano(),
|
||||
})
|
||||
|
||||
now = time.Now().Add(4 * time.Second)
|
||||
c.AddTask(&Task{
|
||||
Value: "test4" + now.Format(timeLayout),
|
||||
Priority: now.UnixNano(),
|
||||
})
|
||||
|
||||
now = time.Now().Add(3 * time.Second)
|
||||
c.AddTask(&Task{
|
||||
Value: "test5" + now.Format(timeLayout),
|
||||
Priority: now.UnixNano(),
|
||||
})
|
||||
|
||||
bts, _ := json.MarshalIndent(c.GetAllTask(), "", "")
|
||||
fmt.Println(string(bts))
|
||||
|
||||
go c.QueueScanWorker()
|
||||
|
||||
go func() {
|
||||
for v := range c.Ready() {
|
||||
bts, _ := json.MarshalIndent(v, "", "")
|
||||
fmt.Println(string(bts))
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
213
pkg/crontab/job.go
Normal file
213
pkg/crontab/job.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// package crontab 实现定时调度
|
||||
// 借鉴https://github.com/robfig/cron
|
||||
// 部分实现添加注释
|
||||
// 向https://github.com/robfig/cron项目致敬
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"jiacrontab/pkg/util"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
starBit = 1 << 63
|
||||
)
|
||||
|
||||
type bounds struct {
|
||||
min, max uint
|
||||
names map[string]uint
|
||||
}
|
||||
|
||||
// The bounds for each field.
|
||||
var (
|
||||
seconds = bounds{0, 59, nil}
|
||||
minutes = bounds{0, 59, nil}
|
||||
hours = bounds{0, 23, nil}
|
||||
dom = bounds{1, 31, nil}
|
||||
months = bounds{1, 12, map[string]uint{
|
||||
"jan": 1,
|
||||
"feb": 2,
|
||||
"mar": 3,
|
||||
"apr": 4,
|
||||
"may": 5,
|
||||
"jun": 6,
|
||||
"jul": 7,
|
||||
"aug": 8,
|
||||
"sep": 9,
|
||||
"oct": 10,
|
||||
"nov": 11,
|
||||
"dec": 12,
|
||||
}}
|
||||
dow = bounds{0, 6, map[string]uint{
|
||||
"sun": 0,
|
||||
"mon": 1,
|
||||
"tue": 2,
|
||||
"wed": 3,
|
||||
"thu": 4,
|
||||
"fri": 5,
|
||||
"sat": 6,
|
||||
}}
|
||||
)
|
||||
|
||||
type Job struct {
|
||||
Second string
|
||||
Minute string
|
||||
Hour string
|
||||
Day string
|
||||
Weekday string
|
||||
Month string
|
||||
|
||||
ID uint
|
||||
now time.Time
|
||||
lastExecutionTime time.Time
|
||||
nextExecutionTime time.Time
|
||||
|
||||
second, minute, hour, dom, month, dow uint64
|
||||
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (j *Job) Format() string {
|
||||
return fmt.Sprintf("second: %s minute: %s hour: %s day: %s weekday: %s month: %s",
|
||||
j.Second, j.Minute, j.Hour, j.Day, j.Weekday, j.Month)
|
||||
}
|
||||
func (j *Job) GetNextExecTime() time.Time {
|
||||
return j.nextExecutionTime
|
||||
}
|
||||
|
||||
func (j *Job) GetLastExecTime() time.Time {
|
||||
return j.lastExecutionTime
|
||||
}
|
||||
|
||||
// parse 解析定时规则
|
||||
// 根据规则生成符和条件的日期
|
||||
// 例如:*/2 如果位于分位,则生成0,2,4,6....58
|
||||
// 生成的日期逐条的被映射到uint64数值中
|
||||
// min |= 1<<2
|
||||
func (j *Job) parse() error {
|
||||
var err error
|
||||
field := func(field string, r bounds) uint64 {
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
var bits uint64
|
||||
bits, err = getField(field, r)
|
||||
return bits
|
||||
}
|
||||
j.second = field(j.Second, seconds)
|
||||
j.minute = field(j.Minute, minutes)
|
||||
j.hour = field(j.Hour, hours)
|
||||
j.dom = field(j.Day, dom)
|
||||
j.month = field(j.Month, months)
|
||||
j.dow = field(j.Weekday, dow)
|
||||
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// NextExecTime 获得下次执行时间
|
||||
func (j *Job) NextExecutionTime(t time.Time) (time.Time, error) {
|
||||
if err := j.parse(); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)
|
||||
added := false
|
||||
defer func() {
|
||||
j.lastExecutionTime, j.nextExecutionTime = j.nextExecutionTime, t
|
||||
}()
|
||||
|
||||
// 设置最大调度周期为5年
|
||||
yearLimit := t.Year() + 5
|
||||
|
||||
WRAP:
|
||||
if t.Year() > yearLimit {
|
||||
return time.Time{}, errors.New("Over 5 years")
|
||||
}
|
||||
|
||||
for 1<<uint(t.Month())&j.month == 0 {
|
||||
// If we have to add a month, reset the other parts to 0.
|
||||
if !added {
|
||||
added = true
|
||||
// Otherwise, set the date at the beginning (since the current time is irrelevant).
|
||||
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
t = t.AddDate(0, 1, 0)
|
||||
|
||||
// Wrapped around.
|
||||
if t.Month() == time.January {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
// Now get a day in that month.
|
||||
for !dayMatches(j, t) {
|
||||
if !added {
|
||||
added = true
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
t = t.AddDate(0, 0, 1)
|
||||
|
||||
if t.Day() == 1 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
for 1<<uint(t.Hour())&j.hour == 0 {
|
||||
if !added {
|
||||
added = true
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, t.Location())
|
||||
}
|
||||
t = t.Add(1 * time.Hour)
|
||||
|
||||
if t.Hour() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
for 1<<uint(t.Minute())&j.minute == 0 {
|
||||
if !added {
|
||||
added = true
|
||||
t = t.Truncate(time.Minute)
|
||||
}
|
||||
t = t.Add(1 * time.Minute)
|
||||
|
||||
if t.Minute() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
for 1<<uint(t.Second())&j.second == 0 {
|
||||
if !added {
|
||||
added = true
|
||||
t = t.Truncate(time.Second)
|
||||
}
|
||||
t = t.Add(1 * time.Second)
|
||||
|
||||
if t.Second() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func dayMatches(j *Job, t time.Time) bool {
|
||||
|
||||
if j.Day == "L" {
|
||||
l := util.CountDaysOfMonth(t.Year(), int(t.Month()))
|
||||
j.dom = getBits(uint(l), uint(l), 1)
|
||||
}
|
||||
|
||||
var (
|
||||
domMatch bool = 1<<uint(t.Day())&j.dom > 0
|
||||
dowMatch bool = 1<<uint(t.Weekday())&j.dow > 0
|
||||
)
|
||||
|
||||
if j.dom&starBit > 0 || j.dow&starBit > 0 {
|
||||
return domMatch && dowMatch
|
||||
}
|
||||
return domMatch || dowMatch
|
||||
}
|
||||
70
pkg/crontab/job_test.go
Normal file
70
pkg/crontab/job_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"jiacrontab/pkg/test"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestJob_NextExecutionTime(t *testing.T) {
|
||||
timeLayout := "2006-01-02 15:04:05"
|
||||
j := &Job{
|
||||
Second: "48",
|
||||
Minute: "3",
|
||||
Hour: "12",
|
||||
Day: "25",
|
||||
Weekday: "*",
|
||||
Month: "1",
|
||||
}
|
||||
|
||||
tt, err := j.NextExecutionTime(time.Now())
|
||||
test.Nil(t, err)
|
||||
test.Equal(t, "2020-01-25 12:03:48", tt.Format(timeLayout))
|
||||
|
||||
tt, err = j.NextExecutionTime(tt)
|
||||
test.Nil(t, err)
|
||||
test.Equal(t, "2021-01-25 12:03:48", tt.Format(timeLayout))
|
||||
|
||||
tt, err = j.NextExecutionTime(tt)
|
||||
test.Equal(t, "2022-01-25 12:03:48", tt.Format(timeLayout))
|
||||
|
||||
j = &Job{
|
||||
Second: "58",
|
||||
Minute: "*/4",
|
||||
Hour: "12",
|
||||
Day: "4",
|
||||
Weekday: "*",
|
||||
Month: "3",
|
||||
}
|
||||
tt, err = j.NextExecutionTime(time.Now())
|
||||
test.Nil(t, err)
|
||||
test.Equal(t, "2020-03-04 12:00:58", tt.Format(timeLayout))
|
||||
|
||||
tt, err = j.NextExecutionTime(tt)
|
||||
test.Nil(t, err)
|
||||
test.Equal(t, "2020-03-04 12:04:58", tt.Format(timeLayout))
|
||||
|
||||
tt, err = j.NextExecutionTime(tt)
|
||||
test.Nil(t, err)
|
||||
test.Equal(t, "2020-03-04 12:08:58", tt.Format(timeLayout))
|
||||
|
||||
j = &Job{
|
||||
Second: "0",
|
||||
Minute: "*",
|
||||
Hour: "*",
|
||||
Day: "*",
|
||||
Weekday: "*",
|
||||
Month: "*",
|
||||
}
|
||||
|
||||
tt, err = j.NextExecutionTime(time.Now())
|
||||
test.Nil(t, err)
|
||||
t.Log(tt, j.GetLastExecTime())
|
||||
for i := 0; i < 1000; i++ {
|
||||
tt, err = j.NextExecutionTime(tt)
|
||||
test.Nil(t, err)
|
||||
t.Log(tt, j.GetLastExecTime())
|
||||
}
|
||||
|
||||
t.Log("end")
|
||||
}
|
||||
128
pkg/crontab/parse.go
Normal file
128
pkg/crontab/parse.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRange(expr string, r bounds) (uint64, error) {
|
||||
var (
|
||||
start, end, step uint
|
||||
rangeAndStep = strings.Split(expr, "/")
|
||||
lowAndHigh = strings.Split(rangeAndStep[0], "-")
|
||||
singleDigit = len(lowAndHigh) == 1
|
||||
err error
|
||||
)
|
||||
|
||||
var extra uint64
|
||||
if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" {
|
||||
start = r.min
|
||||
end = r.max
|
||||
extra = starBit
|
||||
} else {
|
||||
if lowAndHigh[0] == "L" {
|
||||
return 0, nil
|
||||
}
|
||||
start, err = parseIntOrName(lowAndHigh[0], r.names)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
switch len(lowAndHigh) {
|
||||
case 1:
|
||||
end = start
|
||||
case 2:
|
||||
end, err = parseIntOrName(lowAndHigh[1], r.names)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("Too many hyphens: %s", expr)
|
||||
}
|
||||
}
|
||||
|
||||
switch len(rangeAndStep) {
|
||||
case 1:
|
||||
step = 1
|
||||
case 2:
|
||||
step, err = mustParseInt(rangeAndStep[1])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Special handling: "N/step" means "N-max/step".
|
||||
if singleDigit {
|
||||
end = r.max
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("Too many slashes: %s", expr)
|
||||
}
|
||||
|
||||
if start < r.min {
|
||||
return 0, fmt.Errorf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr)
|
||||
}
|
||||
if end > r.max {
|
||||
return 0, fmt.Errorf("End of range (%d) above maximum (%d): %s", end, r.max, expr)
|
||||
}
|
||||
if start > end {
|
||||
return 0, fmt.Errorf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr)
|
||||
}
|
||||
if step == 0 {
|
||||
return 0, fmt.Errorf("Step of range should be a positive number: %s", expr)
|
||||
}
|
||||
|
||||
return getBits(start, end, step) | extra, nil
|
||||
}
|
||||
|
||||
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
|
||||
if names != nil {
|
||||
if namedInt, ok := names[strings.ToLower(expr)]; ok {
|
||||
return namedInt, nil
|
||||
}
|
||||
}
|
||||
return mustParseInt(expr)
|
||||
}
|
||||
|
||||
// mustParseInt parses the given expression as an int or returns an error.
|
||||
func mustParseInt(expr string) (uint, error) {
|
||||
num, err := strconv.Atoi(expr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Failed to parse int from %s: %s", expr, err)
|
||||
}
|
||||
if num < 0 {
|
||||
return 0, fmt.Errorf("Negative number (%d) not allowed: %s", num, expr)
|
||||
}
|
||||
|
||||
return uint(num), nil
|
||||
}
|
||||
|
||||
// getBits sets all bits in the range [min, max], modulo the given step size.
|
||||
func getBits(min, max, step uint) uint64 {
|
||||
var bits uint64
|
||||
|
||||
// If step is 1, use shifts.
|
||||
if step == 1 {
|
||||
return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
|
||||
}
|
||||
|
||||
// Else, use a simple loop.
|
||||
for i := min; i <= max; i += step {
|
||||
bits |= 1 << i
|
||||
}
|
||||
return bits
|
||||
}
|
||||
|
||||
func getField(field string, r bounds) (uint64, error) {
|
||||
var bits uint64
|
||||
ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' })
|
||||
for _, expr := range ranges {
|
||||
bit, err := getRange(expr, r)
|
||||
if err != nil {
|
||||
return bits, err
|
||||
}
|
||||
bits |= bit
|
||||
}
|
||||
return bits, nil
|
||||
}
|
||||
Reference in New Issue
Block a user