refactor: thread safe tasks

This commit is contained in:
Simon Ding
2025-04-08 14:01:32 +08:00
parent 9014f846be
commit 2bc71b0c66
5 changed files with 36 additions and 28 deletions

View File

@@ -20,7 +20,8 @@ func NewEngine(db *db.Client, language string) *Engine {
return &Engine{ return &Engine{
db: db, db: db,
cron: cron.New(), cron: cron.New(),
tasks: make(map[int]*Task, 0), tasks: utils.Map[int, *Task]{},
schedulers: utils.Map[string, scheduler]{},
language: language, language: language,
} }
} }
@@ -32,7 +33,7 @@ type scheduler struct {
type Engine struct { type Engine struct {
db *db.Client db *db.Client
cron *cron.Cron cron *cron.Cron
tasks map[int]*Task tasks utils.Map[int, *Task]
language string language string
schedulers utils.Map[string, scheduler] schedulers utils.Map[string, scheduler]
} }
@@ -59,7 +60,7 @@ func (c *Engine) reloadUsingBuildinDownloader(h *ent.History) error{
if err != nil { if err != nil {
return errors.Wrap(err, "download torrent") return errors.Wrap(err, "download torrent")
} }
c.tasks[h.ID] = &Task{Torrent: t} c.tasks.Store(h.ID, &Task{Torrent: t})
return nil return nil
} }
@@ -93,7 +94,7 @@ func (c *Engine) reloadTasks() {
log.Warnf("get task error: %v", err) log.Warnf("get task error: %v", err)
continue continue
} }
c.tasks[t.ID] = &Task{Torrent: to} c.tasks.Store(t.ID, &Task{Torrent: to})
} else if t.Link != "" { } else if t.Link != "" {
to, err := transmission.NewTorrent(transmission.Config{ to, err := transmission.NewTorrent(transmission.Config{
URL: dl.URL, URL: dl.URL,
@@ -104,7 +105,7 @@ func (c *Engine) reloadTasks() {
log.Warnf("get task error: %v", err) log.Warnf("get task error: %v", err)
continue continue
} }
c.tasks[t.ID] = &Task{Torrent: to} c.tasks.Store(t.ID, &Task{Torrent: to})
} }
} else if dl.Implementation == downloadclients.ImplementationQbittorrent { } else if dl.Implementation == downloadclients.ImplementationQbittorrent {
if t.Hash != "" { if t.Hash != "" {
@@ -117,7 +118,7 @@ func (c *Engine) reloadTasks() {
log.Warnf("get task error: %v", err) log.Warnf("get task error: %v", err)
continue continue
} }
c.tasks[t.ID] = &Task{Torrent: to} c.tasks.Store(t.ID, &Task{Torrent: to})
} else if t.Link != "" { } else if t.Link != "" {
to, err := qbittorrent.NewTorrent(qbittorrent.Info{ to, err := qbittorrent.NewTorrent(qbittorrent.Info{
@@ -129,8 +130,7 @@ func (c *Engine) reloadTasks() {
log.Warnf("get task error: %v", err) log.Warnf("get task error: %v", err)
continue continue
} }
c.tasks[t.ID] = &Task{Torrent: to} c.tasks.Store(t.ID, &Task{Torrent: to})
} }
} }
@@ -200,16 +200,16 @@ func (c *Engine) MustTMDB() *tmdb.Client {
} }
func (c *Engine) RemoveTaskAndTorrent(id int) error { func (c *Engine) RemoveTaskAndTorrent(id int) error {
torrent := c.tasks[id] torrent, ok := c.tasks.Load(id)
if torrent != nil { if ok {
if err := torrent.Remove(); err != nil { if err := torrent.Remove(); err != nil {
return errors.Wrap(err, "remove torrent") return errors.Wrap(err, "remove torrent")
} }
delete(c.tasks, id) c.tasks.Delete(id)
} }
return nil return nil
} }
func (c *Engine) GetTasks() map[int]*Task { func (c *Engine) GetTasks() utils.Map[int, *Task] {
return c.tasks return c.tasks
} }

View File

@@ -253,7 +253,7 @@ func (c *Engine) findEpisodeFilesPreMoving(historyId int) error {
episodeIds := c.GetEpisodeIds(his) episodeIds := c.GetEpisodeIds(his)
task := c.tasks[historyId] task, _ := c.tasks.Load(historyId)
ff, err := c.db.GetAcceptedVideoFormats() ff, err := c.db.GetAcceptedVideoFormats()
if err != nil { if err != nil {

View File

@@ -199,7 +199,7 @@ func (c *Engine) downloadTorrent(m *ent.Media, r1 torznab.Result, seasonNum int,
} }
torrent.Start() torrent.Start()
c.tasks[history.ID] = &Task{Torrent: torrent} c.tasks.Store(history.ID, &Task{Torrent: torrent})
c.sendMsg(fmt.Sprintf(message.BeginDownload, name)) c.sendMsg(fmt.Sprintf(message.BeginDownload, name))

View File

@@ -65,46 +65,50 @@ func (c *Engine) TriggerCronJob(name string) error {
func (c *Engine) checkTasks() error { func (c *Engine) checkTasks() error {
log.Debug("begin check tasks...") log.Debug("begin check tasks...")
for id, t := range c.tasks { c.tasks.Range(func(id int, t *Task) bool {
r := c.db.GetHistory(id) r := c.db.GetHistory(id)
if !t.Exists() { if !t.Exists() {
log.Infof("task no longer exists: %v", id) log.Infof("task no longer exists: %v", id)
delete(c.tasks, id) c.tasks.Delete(id)
continue return true
} }
name, err := t.Name() name, err := t.Name()
if err != nil { if err != nil {
return errors.Wrap(err, "get name") log.Warnf("get task name error: %v", err)
return true
} }
progress, err := t.Progress() progress, err := t.Progress()
if err != nil { if err != nil {
return errors.Wrap(err, "get progress") log.Warnf("get task progress error: %v", err)
return true
} }
log.Infof("task (%s) percentage done: %d%%", name, progress) log.Infof("task (%s) percentage done: %d%%", name, progress)
if progress == 100 { if progress == 100 {
if r.Status == history.StatusSeeding { if r.Status == history.StatusSeeding {
//task already success, check seed ratio //task already success, check seed ratio
torrent := c.tasks[id] torrent, _ := c.tasks.Load(id)
ratio, ok := c.isSeedRatioLimitReached(r.IndexerID, torrent) ratio, ok := c.isSeedRatioLimitReached(r.IndexerID, torrent)
if ok { if ok {
log.Infof("torrent file seed ratio reached, remove: %v, current seed ratio: %v", name, ratio) log.Infof("torrent file seed ratio reached, remove: %v, current seed ratio: %v", name, ratio)
torrent.Remove() torrent.Remove()
delete(c.tasks, id) c.tasks.Delete(id)
c.setHistoryStatus(id, history.StatusSuccess) c.setHistoryStatus(id, history.StatusSuccess)
} else { } else {
log.Infof("torrent file still sedding: %v, current seed ratio: %v", name, ratio) log.Infof("torrent file still sedding: %v, current seed ratio: %v", name, ratio)
} }
continue return true
} else if r.Status == history.StatusRunning { } else if r.Status == history.StatusRunning {
log.Infof("task is done: %v", name) log.Infof("task is done: %v", name)
c.sendMsg(fmt.Sprintf(message.DownloadComplete, name)) c.sendMsg(fmt.Sprintf(message.DownloadComplete, name))
go c.postTaskProcessing(id) go c.postTaskProcessing(id)
} }
} }
}
return true
})
return nil return nil
} }
@@ -232,7 +236,7 @@ func (c *Engine) GetEpisodeIds(r *ent.History) []int {
} }
func (c *Engine) moveCompletedTask(id int) (err1 error) { func (c *Engine) moveCompletedTask(id int) (err1 error) {
torrent := c.tasks[id] torrent, _ := c.tasks.Load(id)
r := c.db.GetHistory(id) r := c.db.GetHistory(id)
// if r.Status == history.StatusUploading { // if r.Status == history.StatusUploading {
// log.Infof("task %d is already uploading, skip", id) // log.Infof("task %d is already uploading, skip", id)
@@ -258,7 +262,7 @@ func (c *Engine) moveCompletedTask(id int) (err1 error) {
c.sendMsg(fmt.Sprintf(message.ProcessingFailed, err1)) c.sendMsg(fmt.Sprintf(message.ProcessingFailed, err1))
if downloadclient.RemoveFailedDownloads { if downloadclient.RemoveFailedDownloads {
log.Debugf("task failed, remove failed torrent and files related") log.Debugf("task failed, remove failed torrent and files related")
delete(c.tasks, r.ID) c.tasks.Delete(r.ID)
torrent.Remove() torrent.Remove()
} }
} }
@@ -289,7 +293,7 @@ func (c *Engine) moveCompletedTask(id int) (err1 error) {
if downloadclient.RemoveCompletedDownloads && ok { if downloadclient.RemoveCompletedDownloads && ok {
log.Debugf("download complete,remove torrent and files related, torrent: %v, seed ratio: %v", torrentName, r1) log.Debugf("download complete,remove torrent and files related, torrent: %v, seed ratio: %v", torrentName, r1)
c.setHistoryStatus(r.ID, history.StatusSuccess) c.setHistoryStatus(r.ID, history.StatusSuccess)
delete(c.tasks, r.ID) c.tasks.Delete(r.ID)
torrent.Remove() torrent.Remove()
} else { } else {
log.Infof("task complete but still needs seeding: %v", torrentName) log.Infof("task complete but still needs seeding: %v", torrentName)

View File

@@ -2,6 +2,7 @@ package server
import ( import (
"fmt" "fmt"
"polaris/engine"
"polaris/ent" "polaris/ent"
"polaris/ent/blacklist" "polaris/ent/blacklist"
"polaris/ent/episode" "polaris/ent/episode"
@@ -31,7 +32,8 @@ func (s *Server) GetAllActivities(c *gin.Context) (interface{}, error) {
a := Activity{ a := Activity{
History: h, History: h,
} }
for id, task := range s.core.GetTasks() { tasks := s.core.GetTasks()
tasks.Range(func(id int, task *engine.Task) bool {
if h.ID == id && task.Exists() { if h.ID == id && task.Exists() {
p, err := task.Progress() p, err := task.Progress()
if err != nil { if err != nil {
@@ -49,7 +51,9 @@ func (s *Server) GetAllActivities(c *gin.Context) (interface{}, error) {
a.UploadProgress = task.UploadProgresser() a.UploadProgress = task.UploadProgresser()
} }
} }
} return true
})
activities = append(activities, a) activities = append(activities, a)
} }
} else { } else {