diff --git a/assembler.go b/assembler.go index 142226c..addd28d 100644 --- a/assembler.go +++ b/assembler.go @@ -13,7 +13,7 @@ type taskAssembler interface { } type taskAssemblerImp struct { - config *TaskConfig + *options } func (s *taskAssemblerImp) AssembleTask(ctxIn context.Context, taskDef *TaskDefinition, arg interface{}) (*Task, error) { @@ -39,7 +39,7 @@ func (s *taskAssemblerImp) AssembleTask(ctxIn context.Context, taskDef *TaskDefi task.Argument = argBytes } if ctxIn != nil { - ctxBytes, err := taskDef.ctxMarshaler(s.config).MarshalCtx(ctxIn) + ctxBytes, err := taskDef.ctxMarshaler(s.ctxMarshaler).MarshalCtx(ctxIn) if err != nil { return nil, fmt.Errorf("get ctxBytes failed, err: %w", err) } @@ -50,7 +50,7 @@ func (s *taskAssemblerImp) AssembleTask(ctxIn context.Context, taskDef *TaskDefi } func (s *taskAssemblerImp) DisassembleTask(taskDef *TaskDefinition, task *Task) (context.Context, interface{}, error) { - ctxIn, err := taskDef.ctxMarshaler(s.config).UnmarshalCtx(task.Context) + ctxIn, err := taskDef.ctxMarshaler(s.ctxMarshaler).UnmarshalCtx(task.Context) if err != nil { return nil, nil, fmt.Errorf("unmarshal task context error: %w", err) } diff --git a/assembler_test.go b/assembler_test.go index 0aac0e6..30fb79b 100644 --- a/assembler_test.go +++ b/assembler_test.go @@ -20,7 +20,7 @@ func (e *testErrCtxMarshaler) UnmarshalCtx(bytes []byte) (context.Context, error func Test_taskAssemblerImp_AssembleTask(t *testing.T) { convey.Convey("Test_taskAssemblerImp_AssembleTask", t, func() { - tass := taskAssemblerImp{config: &TaskConfig{CtxMarshaler: &defaultCtxMarshaler{}}} + tass := taskAssemblerImp{options: &options{ctxMarshaler: &defaultCtxMarshaler{}}} convey.Convey("normal", func() { convey.Convey("nil arg type", func() { task, err := tass.AssembleTask(context.TODO(), &TaskDefinition{}, nil) @@ -59,7 +59,7 @@ func Test_taskAssemblerImp_AssembleTask(t *testing.T) { func Test_taskAssemblerImp_DisassembleTask(t *testing.T) { convey.Convey("Test_taskAssemblerImp_DisassembleTask", t, func() { - tass := taskAssemblerImp{config: &TaskConfig{CtxMarshaler: &defaultCtxMarshaler{}}} + tass := taskAssemblerImp{options: &options{ctxMarshaler: &defaultCtxMarshaler{}}} convey.Convey("normal", func() { convey.Convey("nil arg type", func() { taskDef := &TaskDefinition{} diff --git a/config.go b/config.go deleted file mode 100644 index e1bb968..0000000 --- a/config.go +++ /dev/null @@ -1,238 +0,0 @@ -package gta - -import ( - "context" - "time" - - "gorm.io/gorm" -) - -// TaskConfig contains all options of a TaskManager. -type TaskConfig struct { - // must provide, db for async task table - DB *gorm.DB - // must provide, async task table name - Table string - - // optional, context for the task mansger - Context context.Context - // optional, logger factory - LoggerFactory func(ctx context.Context) Logger - // optional, determine when a normal task can be cleaned - StorageTimeout time.Duration - // optional, determine whether a initialized task is abnormal - InitializedTimeout time.Duration - // optional, determine whether a running task is abnormal - RunningTimeout time.Duration - // optional, wait timeout in Stop() process - WaitTimeout time.Duration - // optional, scan interval - ScanInterval time.Duration - // optional, instant scan interval - InstantScanInterval time.Duration - // optional, context marshaler to store or recover a context - CtxMarshaler CtxMarshaler - // optional, callback function for abnormal tasks - CheckCallback func(logger Logger, abnormalTasks []Task) - // optional, flag for dry run mode - DryRun bool - // optional, goroutine pool size for scheduling tasks - PoolSize int - - // inner use - taskRegister taskRegister - cancelFunc context.CancelFunc -} - -func (s *TaskConfig) init() error { - if s.DB == nil { - return ErrConfigNilDB - } - if s.Table == "" { - return ErrConfigEmptyTable - } - - // default value for optional config - if s.Context == nil { - s.Context = defaultContextFactory() - } - if s.LoggerFactory == nil { - s.LoggerFactory = defaultLoggerFactory - } - if s.StorageTimeout <= 0 { - s.StorageTimeout = defaultStorageTimeout - } - if s.WaitTimeout <= 0 { - s.WaitTimeout = defaultWaitTimeout - } - if s.ScanInterval <= 0 { - s.ScanInterval = defaultScanInterval - } - if s.InstantScanInterval <= 0 { - s.InstantScanInterval = defaultInstantScanInvertal - } - if s.RunningTimeout <= 0 { - s.RunningTimeout = defaultRunningTimeout - } - if s.InitializedTimeout <= 0 { - s.InitializedTimeout = defaultInitializedTimeout - } - if s.CtxMarshaler == nil { - s.CtxMarshaler = &defaultCtxMarshaler{} - } - if s.CheckCallback == nil { - s.CheckCallback = defaultCheckCallback - } - if s.PoolSize <= 0 { - s.PoolSize = defaultPoolSize - } - if s.taskRegister == nil { - s.taskRegister = &taskRegisterImp{} - } - - // check - if s.RunningTimeout > s.StorageTimeout { - return ErrConfigInvalidRunningTimeout - } - if s.InitializedTimeout > s.StorageTimeout { - return ErrConfigInvalidInitializeTimeout - } - if s.ScanInterval > s.StorageTimeout || s.ScanInterval > s.InitializedTimeout || s.ScanInterval > s.RunningTimeout { - return ErrConfigInvalidScanInterval - } - if s.InstantScanInterval > s.ScanInterval { - return ErrConfigInvalidInstantScanInterval - } - - // generate context with cancel - s.Context, s.cancelFunc = context.WithCancel(s.Context) - - return nil -} - -func (s *TaskConfig) load(options ...Option) *TaskConfig { - for _, option := range options { - option(s) - } - return s -} - -func (s *TaskConfig) logger() Logger { - return s.LoggerFactory(s.Context) -} - -func (s *TaskConfig) done() <-chan struct{} { - return s.Context.Done() -} - -func (s *TaskConfig) cancel() { - s.cancelFunc() -} - -func (s *TaskConfig) db() *gorm.DB { - return s.DB -} - -func newConfig(db *gorm.DB, table string, options ...Option) (*TaskConfig, error) { - c := (&TaskConfig{}).load(options...).load(withDB(db)).load(withTable(table)) - if err := c.init(); err != nil { - return nil, err - } - return c, nil -} - -// Logger is a logging interface for logging necessary messages. -type Logger interface { - Printf(format string, args ...interface{}) - Infof(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) -} - -// CtxMarshaler is used to marshal or unmarshal context. -type CtxMarshaler interface { - MarshalCtx(ctx context.Context) ([]byte, error) - UnmarshalCtx(bytes []byte) (context.Context, error) -} - -// Option represents the optional function. -type Option func(c *TaskConfig) - -// WithConfig set the whole config. -func WithConfig(config TaskConfig) Option { - return func(c *TaskConfig) { *c = config } -} - -// WithContext set the Context option. -func WithContext(ctx context.Context) Option { - return func(c *TaskConfig) { c.Context = ctx } -} - -// WithLoggerFactory set the LoggerFactory option. -func WithLoggerFactory(f func(ctx context.Context) Logger) Option { - return func(c *TaskConfig) { c.LoggerFactory = f } -} - -// WithStorageTimeout set the StorageTimeout option. -func WithStorageTimeout(d time.Duration) Option { - return func(c *TaskConfig) { c.StorageTimeout = d } -} - -// WithInitializedTimeout set the InitializedTimeout option. -func WithInitializedTimeout(d time.Duration) Option { - return func(c *TaskConfig) { c.InitializedTimeout = d } -} - -// WithRunningTimeout set the RunningTimeout option. -func WithRunningTimeout(d time.Duration) Option { - return func(c *TaskConfig) { c.RunningTimeout = d } -} - -// WithWaitTimeout set the WaitTimeout option. -func WithWaitTimeout(d time.Duration) Option { - return func(c *TaskConfig) { c.WaitTimeout = d } -} - -// WithScanInterval set the ScanInterval option. -func WithScanInterval(d time.Duration) Option { - return func(c *TaskConfig) { c.ScanInterval = d } -} - -// WithInstantScanInterval set the InstantScanInterval option. -func WithInstantScanInterval(d time.Duration) Option { - return func(c *TaskConfig) { c.InstantScanInterval = d } -} - -// WithCtxMarshaler set the CtxMarshaler option. -func WithCtxMarshaler(m CtxMarshaler) Option { - return func(c *TaskConfig) { c.CtxMarshaler = m } -} - -// WithCheckCallback set the CheckCallback option. -func WithCheckCallback(f func(logger Logger, abnormalTasks []Task)) Option { - return func(c *TaskConfig) { c.CheckCallback = f } -} - -// WithDryRun set the DryRun option. -func WithDryRun(flag bool) Option { - return func(c *TaskConfig) { c.DryRun = flag } -} - -// WithPoolSize set the PoolSize option. -func WithPoolSize(size int) Option { - return func(c *TaskConfig) { c.PoolSize = size } -} - -func withDB(db *gorm.DB) Option { - return func(c *TaskConfig) { c.DB = db } -} - -func withTable(table string) Option { - return func(c *TaskConfig) { c.Table = table } -} - -func withTaskRegister(tr taskRegister) Option { - return func(c *TaskConfig) { - c.taskRegister = tr - } -} diff --git a/config_test.go b/config_test.go deleted file mode 100644 index e44d957..0000000 --- a/config_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package gta - -import ( - "testing" - "time" - - "github.com/smartystreets/goconvey/convey" - "gorm.io/gorm" -) - -func Test_Config_init(t *testing.T) { - defaultDB := &gorm.DB{} - defaultTable := "tasks" - - convey.Convey("Test_Config_init", t, func() { - convey.Convey("normal process", func() { - tc := TaskConfig{DB: defaultDB, Table: defaultTable} - convey.So(tc.init(), convey.ShouldBeNil) - }) - - convey.Convey("empty db factory", func() { - tc := TaskConfig{} - convey.So(tc.init(), convey.ShouldNotBeNil) - }) - - convey.Convey("empty table name", func() { - tc := TaskConfig{DB: defaultDB} - convey.So(tc.init(), convey.ShouldNotBeNil) - }) - - convey.Convey("invalid running timeout", func() { - tc := TaskConfig{DB: defaultDB, Table: defaultTable, RunningTimeout: time.Hour * 365} - convey.So(tc.init(), convey.ShouldNotBeNil) - }) - - convey.Convey("invalid initialized timeout", func() { - tc := TaskConfig{DB: defaultDB, Table: defaultTable, InitializedTimeout: time.Hour * 365} - convey.So(tc.init(), convey.ShouldNotBeNil) - }) - - convey.Convey("invalid scan interval", func() { - tc := TaskConfig{DB: defaultDB, Table: defaultTable, ScanInterval: time.Hour * 365} - convey.So(tc.init(), convey.ShouldNotBeNil) - }) - - convey.Convey("invalid instant scan interval", func() { - tc := TaskConfig{DB: defaultDB, Table: defaultTable, InstantScanInterval: time.Hour * 365} - convey.So(tc.init(), convey.ShouldNotBeNil) - }) - - }) -} diff --git a/dal.go b/dal.go index 1b395e3..2a2f777 100644 --- a/dal.go +++ b/dal.go @@ -24,11 +24,11 @@ type taskDAL interface { } type taskDALImp struct { - config *TaskConfig + *options } func (s *taskDALImp) tabledDB(tx *gorm.DB) *gorm.DB { - return tx.Table(s.config.Table) + return tx.Table(s.table) } func (s *taskDALImp) Create(tx *gorm.DB, task *Task) error { diff --git a/dal_test.go b/dal_test.go index 82baeef..fa9aa01 100644 --- a/dal_test.go +++ b/dal_test.go @@ -11,7 +11,7 @@ func Test_taskDALImp_GetInitialized(t *testing.T) { convey.Convey("Test_taskDALImp_GetInitialized", t, func() { db := testDB("Test_taskDALImp_GetInitialized") convey.Convey("normal", func() { - tdal := taskDALImp{config: &TaskConfig{DB: db, Table: "tasks"}} + tdal := taskDALImp{options: &options{db: db, table: "tasks"}} convey.Convey("only has sensitive keys", func() { convey.Convey("normal time", func() { _ = tdal.Create(db, &Task{TaskKey: "t1", TaskStatus: TaskStatusInitialized, CreatedAt: time.Now(), UpdatedAt: time.Now()}) @@ -28,7 +28,7 @@ func Test_taskDALImp_GetInitialized(t *testing.T) { }) }) convey.Convey("error", func() { - tdal := taskDALImp{config: &TaskConfig{DB: db, Table: "not exist"}} + tdal := taskDALImp{options: &options{db: db, table: "not exist"}} _, err := tdal.GetInitialized(db, nil, time.Second, nil) convey.So(err, convey.ShouldNotBeNil) }) @@ -39,7 +39,7 @@ func Test_taskDALImp_Get(t *testing.T) { convey.Convey("Test_taskDALImp_Get", t, func() { convey.Convey("error", func() { db := testDB("Test_taskDALImp_Get") - tdal := taskDALImp{config: &TaskConfig{DB: db, Table: "not exist"}} + tdal := taskDALImp{options: &options{db: db, table: "not exist"}} _, err := tdal.Get(db, 1) convey.So(err, convey.ShouldNotBeNil) }) @@ -50,7 +50,7 @@ func Test_taskDALImp_GetForUpdate(t *testing.T) { convey.Convey("Test_taskDALImp_GetForUpdate", t, func() { convey.Convey("error", func() { db := testDB("Test_taskDALImp_GetForUpdate") - tdal := taskDALImp{config: &TaskConfig{DB: db, Table: "not exist"}} + tdal := taskDALImp{options: &options{db: db, table: "not exist"}} _, err := tdal.GetForUpdate(db, 1) convey.So(err, convey.ShouldNotBeNil) }) diff --git a/default.go b/default.go deleted file mode 100644 index 4e99f06..0000000 --- a/default.go +++ /dev/null @@ -1,48 +0,0 @@ -package gta - -import ( - "context" - "time" - - "github.com/panjf2000/ants/v2" - "github.com/sirupsen/logrus" -) - -var ( - defaultStorageTimeout = time.Hour * 7 * 24 - defaultWaitTimeout = time.Second * 0 - defaultScanInterval = time.Second * 5 - defaultInstantScanInvertal = time.Millisecond * 100 - defaultRunningTimeout = time.Minute * 30 - defaultInitializedTimeout = time.Minute * 5 - defaultPoolSize = ants.DefaultAntsPoolSize - defaultRetryInterval = time.Second -) - -type defaultCtxMarshaler struct{} - -func (s *defaultCtxMarshaler) MarshalCtx(ctx context.Context) ([]byte, error) { - return nil, nil -} - -func (s *defaultCtxMarshaler) UnmarshalCtx(bytes []byte) (context.Context, error) { - return context.Background(), nil -} - -func defaultContextFactory() context.Context { - return context.Background() -} - -func defaultLoggerFactory(ctx context.Context) Logger { - return logrus.NewEntry(logrus.New()) -} - -func defaultCheckCallback(logger Logger, abnormalTasks []Task) { - if len(abnormalTasks) == 0 { - return - } - logger.Errorf("[defaultCheckCallback] abnormal tasks found, total[%v]", len(abnormalTasks)) - for _, at := range abnormalTasks { - logger.Warnf("[defaultCheckCallback] abnormal task found, id[%v], task_key[%v], task_status[%v]", at.ID, at.TaskKey, at.TaskStatus) - } -} diff --git a/definition.go b/definition.go index cb371b7..0236ab9 100644 --- a/definition.go +++ b/definition.go @@ -56,11 +56,11 @@ func (s *TaskDefinition) init(key TaskKey) error { return nil } -func (s *TaskDefinition) ctxMarshaler(config *TaskConfig) CtxMarshaler { +func (s *TaskDefinition) ctxMarshaler(global CtxMarshaler) CtxMarshaler { if m := s.CtxMarshaler; m != nil { return m } - return config.CtxMarshaler + return global } func (s *TaskDefinition) retryInterval(times int) time.Duration { diff --git a/definition_test.go b/definition_test.go index 714816a..5f3f5ca 100644 --- a/definition_test.go +++ b/definition_test.go @@ -47,13 +47,13 @@ func TestTaskDefinition_ctxMarshaler(t *testing.T) { convey.Convey("TestTaskDefinition_ctxMarshaler", t, func() { convey.Convey("empty ctxMarshal in taskDef", func() { taskDef := &TaskDefinition{} - cm := taskDef.ctxMarshaler(&TaskConfig{CtxMarshaler: &defaultCtxMarshaler{}}) + cm := taskDef.ctxMarshaler(&defaultCtxMarshaler{}) convey.So(cm, convey.ShouldNotBeNil) }) convey.Convey("specify ctxMarshal in taskDef", func() { taskDef := &TaskDefinition{CtxMarshaler: &defaultCtxMarshaler{}} - cm := taskDef.ctxMarshaler(&TaskConfig{}) + cm := taskDef.ctxMarshaler(nil) convey.So(cm, convey.ShouldNotBeNil) }) }) diff --git a/error.go b/error.go index 8319502..d7a970b 100644 --- a/error.go +++ b/error.go @@ -10,18 +10,8 @@ var ( // ErrTaskNotFound represents certain task not found. ErrTaskNotFound = errors.New("task not found") - // ErrConfigEmptyTable represents TableName in the config is empty. - ErrConfigEmptyTable = errors.New("config table is empty") - // ErrConfigNilDB represents DB in the config is nil. - ErrConfigNilDB = errors.New("config db is nil") - // ErrConfigInvalidRunningTimeout represents RunningTimeout in the config is invalid. - ErrConfigInvalidRunningTimeout = errors.New("config running timeout is invalid") - // ErrConfigInvalidInitializeTimeout represents InitializeTimeout in the config is invalid. - ErrConfigInvalidInitializeTimeout = errors.New("config initialize timeout is invalid") - // ErrConfigInvalidScanInterval represents ScanInterval in the config is invalid. - ErrConfigInvalidScanInterval = errors.New("config scan interval is invalid") - // ErrConfigInvalidInstantScanInterval represents InstantScanInterval in the config is invalid. - ErrConfigInvalidInstantScanInterval = errors.New("config instant scan interval is invalid") + // ErrOption represents option is invalid. + ErrOption = errors.New("option invalid") // ErrDefNilHandler represents Handler in the task definition is nil. ErrDefNilHandler = errors.New("definition handler is nil") diff --git a/manager.go b/manager.go index 10ce681..3dd45de 100644 --- a/manager.go +++ b/manager.go @@ -11,7 +11,7 @@ import ( // TaskManager is the overall processor of task, which includes scheduler, scanner and other components type TaskManager struct { - tc *TaskConfig + *options tr taskRegister tass taskAssembler tsch taskScheduler @@ -26,7 +26,7 @@ type TaskManager struct { // Start starts the TaskManager. This function should be called before any other functions in a TaskManager is called. func (s *TaskManager) Start() { s.startOnce.Do(func() { - if s.tc.DryRun { + if s.dryRun { // don't start scan and monitor process in dry run mode return } @@ -58,7 +58,7 @@ func (s *TaskManager) Register(key TaskKey, definition TaskDefinition) { // If the retry times exceeds the maximum config value, the task is marked 'failed' in the database with error logs // recorded. In these cases, maybe a manual operation is essential. // -// The context passed in should be consistent with the 'CtxMarshaler' value defined in the overall configuration or the +// The context passed in should be consistent with the 'ctxMarshaler' value defined in the overall configuration or the // task definition. func (s *TaskManager) Run(ctx context.Context, key TaskKey, arg interface{}) error { return s.Transaction(func(tx *gorm.DB) error { return s.RunWithTx(tx, ctx, key, arg) }) @@ -106,9 +106,9 @@ func (s *TaskManager) Transaction(fc func(tx *gorm.DB) error) (err error) { // The wait parameter determines whether to wait for all running tasks to complete. func (s *TaskManager) Stop(wait bool) { s.stopOnce.Do(func() { - if !s.tc.DryRun { + if !s.dryRun { // send global cancel signal - s.tc.cancel() + s.cancel() } s.tsch.Stop(wait) }) @@ -128,12 +128,12 @@ func (s *TaskManager) Stop(wait bool) { // ForceRerunTasks changes specific tasks to 'initialized'. func (s *TaskManager) ForceRerunTasks(taskIDs []uint64, status TaskStatus) (int64, error) { - return s.tdal.UpdateStatusByIDs(s.tc.db(), taskIDs, status, TaskStatusInitialized) + return s.tdal.UpdateStatusByIDs(s.getDB(), taskIDs, status, TaskStatusInitialized) } // QueryUnsuccessfulTasks checks initialized, running or failed tasks. func (s *TaskManager) QueryUnsuccessfulTasks(limit, offset int) ([]Task, error) { - return s.tdal.GetSliceExcludeSucceeded(s.tc.db(), s.tr.GetBuiltInKeys(), limit, offset) + return s.tdal.GetSliceExcludeSucceeded(s.getDB(), s.tr.GetBuiltInKeys(), limit, offset) } func (s *TaskManager) registerBuiltinTasks() { @@ -146,19 +146,19 @@ func (s *TaskManager) registerBuiltinTasks() { // The database and task table must be provided because this tool relies heavily on the database. For more information // about the table schema, please refer to 'model.sql'. func NewTaskManager(db *gorm.DB, table string, options ...Option) *TaskManager { - tc, err := newConfig(db, table, options...) + opts, err := newOptions(db, table, options...) if err != nil { panic(err) } - tr := tc.taskRegister - tdal := &taskDALImp{config: tc} - tass := &taskAssemblerImp{config: tc} - pool, err := ants.NewPool(tc.PoolSize, ants.WithLogger(tc.logger()), ants.WithNonblocking(true)) + tr := opts.taskRegister + tdal := &taskDALImp{options: opts} + tass := &taskAssemblerImp{options: opts} + pool, err := ants.NewPool(opts.poolSize, ants.WithLogger(opts.logger()), ants.WithNonblocking(true)) if err != nil { panic(err) } - tsch := &taskSchedulerImp{config: tc, register: tr, dal: tdal, assembler: tass, pool: pool} - tmon := &taskMonitorImp{config: tc, register: tr, dal: tdal, assembler: tass} - tscn := &taskScannerImp{config: tc, register: tr, dal: tdal, scheduler: tsch} - return &TaskManager{tc: tc, tr: tr, tass: tass, tsch: tsch, tdal: tdal, tmon: tmon, tscn: tscn} + tsch := &taskSchedulerImp{options: opts, register: tr, dal: tdal, assembler: tass, pool: pool} + tmon := &taskMonitorImp{options: opts, register: tr, dal: tdal, assembler: tass} + tscn := &taskScannerImp{options: opts, register: tr, dal: tdal, scheduler: tsch} + return &TaskManager{options: opts, tr: tr, tass: tass, tsch: tsch, tdal: tdal, tmon: tmon, tscn: tscn} } diff --git a/manager_test.go b/manager_test.go index 4f384b6..178bded 100644 --- a/manager_test.go +++ b/manager_test.go @@ -36,7 +36,7 @@ func TestTaskManager_Start(t *testing.T) { m.Start() defer m.Stop(false) convey.So(m.tr.GetBuiltInKeys(), convey.ShouldHaveLength, 2) - task, err := m.tdal.Get(m.tc.db(), taskCheckAbnormalID) + task, err := m.tdal.Get(m.getDB(), taskCheckAbnormalID) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) }) @@ -46,7 +46,7 @@ func TestTaskManager_Start(t *testing.T) { m.Start() defer m.Stop(false) convey.So(m.tr.GetBuiltInKeys(), convey.ShouldHaveLength, 0) - task, err := m.tdal.Get(m.tc.db(), taskCheckAbnormalID) + task, err := m.tdal.Get(m.getDB(), taskCheckAbnormalID) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldBeNil) }) @@ -97,7 +97,7 @@ func TestTaskManager_Run(t *testing.T) { m.Stop(true) convey.So(err, convey.ShouldBeNil) convey.So(t1Run, convey.ShouldEqual, 1) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskStatus, convey.ShouldEqual, TaskStatusSucceeded) @@ -105,7 +105,7 @@ func TestTaskManager_Run(t *testing.T) { convey.Convey("with options", func() { m := NewTaskManager(testDB("TestTaskManager_Run"), "tasks") - convey.Convey("with CtxMarshaler", func() { + convey.Convey("with ctxMarshaler", func() { var t1Run int64 m.Register("t1", TaskDefinition{ Handler: testWrappedHandler(func(ctx context.Context, arg interface{}) (err error) { @@ -173,7 +173,7 @@ func TestTaskManager_Run(t *testing.T) { m.Stop(true) convey.So(err, convey.ShouldBeNil) convey.So(t1Run, convey.ShouldEqual, 1) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldBeNil) }) @@ -181,7 +181,7 @@ func TestTaskManager_Run(t *testing.T) { convey.Convey("with InitTimeoutSensitive", func() { var t1Run int64 m.Register("t1", TaskDefinition{Handler: testCountHandler(&t1Run), InitTimeoutSensitive: true}) - err1 := m.tdal.Create(m.tc.db(), &Task{ + err1 := m.tdal.Create(m.getDB(), &Task{ ID: 10001, TaskKey: "t1", TaskStatus: TaskStatusInitialized, @@ -191,7 +191,7 @@ func TestTaskManager_Run(t *testing.T) { CreatedAt: time.Now().Add(-time.Hour), UpdatedAt: time.Now().Add(-time.Hour), }) - err2 := m.tdal.Create(m.tc.db(), &Task{ + err2 := m.tdal.Create(m.getDB(), &Task{ ID: 10002, TaskKey: "t1", TaskStatus: TaskStatusInitialized, @@ -207,7 +207,7 @@ func TestTaskManager_Run(t *testing.T) { time.Sleep(time.Second) m.Stop(true) convey.So(t1Run, convey.ShouldEqual, 1) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) @@ -222,7 +222,7 @@ func TestTaskManager_Run(t *testing.T) { m.Stop(false) err := m.Run(context.TODO(), "t1", nil) convey.So(err, convey.ShouldBeNil) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskKey, convey.ShouldEqual, "t1") @@ -254,7 +254,7 @@ func TestTaskManager_Run(t *testing.T) { m.Stop(true) convey.So(errSlice, convey.ShouldHaveLength, 0) convey.So(t1Run, convey.ShouldBeLessThan, 10) - task, err := m.tdal.Get(m.tc.db(), 10010) + task, err := m.tdal.Get(m.getDB(), 10010) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskKey, convey.ShouldEqual, "t1") @@ -292,7 +292,7 @@ func TestTaskManager_Run(t *testing.T) { m.Stop(true) convey.So(err, convey.ShouldBeNil) convey.So(t1Run, convey.ShouldEqual, 1) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskStatus, convey.ShouldEqual, TaskStatusFailed) @@ -309,7 +309,7 @@ func TestTaskManager_Run(t *testing.T) { m.Stop(true) convey.So(err, convey.ShouldBeNil) convey.So(t1Run, convey.ShouldEqual, 1) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskStatus, convey.ShouldEqual, TaskStatusFailed) @@ -375,10 +375,10 @@ func TestTaskManager_RunWithTx(t *testing.T) { convey.So(err, convey.ShouldNotBeNil) convey.So(t1Run, convey.ShouldEqual, 0) convey.So(t2Run, convey.ShouldEqual, 0) - task1, err := m.tdal.Get(m.tc.db(), 10001) + task1, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task1, convey.ShouldBeNil) - task2, err := m.tdal.Get(m.tc.db(), 10002) + task2, err := m.tdal.Get(m.getDB(), 10002) convey.So(err, convey.ShouldBeNil) convey.So(task2, convey.ShouldBeNil) }) @@ -387,7 +387,7 @@ func TestTaskManager_RunWithTx(t *testing.T) { convey.Convey("not builtin transaction", func() { convey.Convey("transaction succeeded", func() { m.Start() - err := m.tc.db().Transaction(func(tx *gorm.DB) error { + err := m.getDB().Transaction(func(tx *gorm.DB) error { if err := m.RunWithTx(tx, context.TODO(), "t1", nil); err != nil { return err } @@ -404,7 +404,7 @@ func TestTaskManager_RunWithTx(t *testing.T) { }) convey.Convey("transaction failed", func() { m.Start() - err := m.tc.db().Transaction(func(tx *gorm.DB) error { + err := m.getDB().Transaction(func(tx *gorm.DB) error { if err := m.RunWithTx(tx, context.TODO(), "t1", nil); err != nil { return err } @@ -414,10 +414,10 @@ func TestTaskManager_RunWithTx(t *testing.T) { convey.So(err, convey.ShouldNotBeNil) convey.So(t1Run, convey.ShouldEqual, 0) convey.So(t2Run, convey.ShouldEqual, 0) - task1, err := m.tdal.Get(m.tc.db(), 10001) + task1, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task1, convey.ShouldBeNil) - task2, err := m.tdal.Get(m.tc.db(), 10002) + task2, err := m.tdal.Get(m.getDB(), 10002) convey.So(err, convey.ShouldBeNil) convey.So(task2, convey.ShouldBeNil) }) @@ -450,7 +450,7 @@ func TestTaskManager_Stop(t *testing.T) { err := m.Run(context.TODO(), "t1", nil) m.Stop(true) convey.So(err, convey.ShouldBeNil) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskStatus, convey.ShouldEqual, TaskStatusSucceeded) @@ -465,7 +465,7 @@ func TestTaskManager_Stop(t *testing.T) { err := m.Run(context.TODO(), "t1", nil) m.Stop(false) convey.So(err, convey.ShouldBeNil) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldNotBeNil) convey.So(task.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) @@ -482,7 +482,7 @@ func TestTaskManager_Stop(t *testing.T) { err := m.Run(context.TODO(), "t1", nil) m.Stop(true) convey.So(err, convey.ShouldBeNil) - task, err := m.tdal.Get(m.tc.db(), 10001) + task, err := m.tdal.Get(m.getDB(), 10001) convey.So(err, convey.ShouldBeNil) convey.So(task, convey.ShouldBeNil) }) @@ -494,7 +494,7 @@ func TestTaskManager_ForceRerunTasks(t *testing.T) { var t1Run int64 m.Register("t1", TaskDefinition{Handler: testCountHandler(&t1Run)}) convey.Convey("TestTaskManager_ForceRerunTasks", t, func() { - _ = m.tdal.Create(m.tc.db(), &Task{ + _ = m.tdal.Create(m.getDB(), &Task{ ID: 10001, TaskKey: "t1", TaskStatus: TaskStatusFailed, @@ -515,7 +515,7 @@ func TestTaskManager_QueryUnsuccessfulTasks(t *testing.T) { var t1Run int64 m.Register("t1", TaskDefinition{Handler: testCountHandler(&t1Run)}) convey.Convey("TestTaskManager_QueryUnsuccessfulTasks", t, func() { - _ = m.tdal.Create(m.tc.db(), &Task{ + _ = m.tdal.Create(m.getDB(), &Task{ ID: 10001, TaskKey: "t1", TaskStatus: TaskStatusFailed, @@ -578,7 +578,6 @@ func TestTaskManager_RaceWithMySQL(t *testing.T) { db, _ := gorm.Open(mysql.Open("root@(127.0.0.1:3306)/test_db?charset=utf8&parseTime=True&loc=UTC")) tmFactory := func(tag string) *TaskManager { m := NewTaskManager(db, "tasks", - WithConfig(TaskConfig{}), WithLoggerFactory(func(ctx context.Context) Logger { fields := logrus.Fields{ "manager": tag, @@ -629,7 +628,7 @@ func TestTaskManager_RaceWithMySQL(t *testing.T) { // run with tx - not builtin for i := 0; i < 5; i++ { ctx := mockTaskContext(fmt.Sprintf("request_tx_nonbuiltin_%d", i), fmt.Sprintf("user%d", i)) - err := m0.tc.db().Transaction(func(tx *gorm.DB) error { + err := m0.getDB().Transaction(func(tx *gorm.DB) error { if err := m0.RunWithTx(tx, ctx, "t", &testTaskArg{A: i, B: 2 * i}); err != nil { return err } diff --git a/monitor.go b/monitor.go index ff2faf0..4c9d67d 100644 --- a/monitor.go +++ b/monitor.go @@ -11,14 +11,14 @@ type taskMonitor interface { } type taskMonitorImp struct { - config *TaskConfig + *options register taskRegister dal taskDAL assembler taskAssembler } func (s *taskMonitorImp) GoMonitorBuiltinTasks() { - logger := s.config.logger() + logger := s.logger() for _, key := range s.register.GetBuiltInKeys() { taskDef, _ := s.register.GetDefinition(key) s.goMonitorBuiltinTask(taskDef) @@ -31,7 +31,7 @@ func (s *taskMonitorImp) goMonitorBuiltinTask(taskDef *TaskDefinition) { defer panicHandler() for { select { - case <-s.config.done(): + case <-s.done(): return default: s.monitorBuiltinTask(taskDef) @@ -42,16 +42,16 @@ func (s *taskMonitorImp) goMonitorBuiltinTask(taskDef *TaskDefinition) { } func (s *taskMonitorImp) monitorBuiltinTask(taskDef *TaskDefinition) { - logger := s.config.logger() + logger := s.logger() - newTask, err := s.assembler.AssembleTask(s.config.Context, taskDef, taskDef.argument) + newTask, err := s.assembler.AssembleTask(s.context, taskDef, taskDef.argument) if err != nil { logger.Errorf("[monitorBuiltinTask] assemble buitin task failed, err[%v], task_key[%v]", err, taskDef.key) return } newTask.TaskStatus = TaskStatusInitialized - if err := s.config.db().Transaction(func(tx *gorm.DB) error { + if err := s.getDB().Transaction(func(tx *gorm.DB) error { if task, err := s.dal.GetForUpdate(tx, taskDef.taskID); err != nil { return err } else if task == nil { @@ -70,7 +70,7 @@ func (s *taskMonitorImp) monitorBuiltinTask(taskDef *TaskDefinition) { }); err == ErrTaskNotFound { // need create, ignore primary key conflict // TODO: distinguish primary key conflict error - _ = s.dal.Create(s.config.db(), newTask) + _ = s.dal.Create(s.getDB(), newTask) return } else if err != nil { logger.Errorf("[monitorBuiltinTask] update transaction failed, err[%v], task_key[%v]", err, taskDef.key) @@ -82,7 +82,7 @@ func (s *taskMonitorImp) needLoopBuiltinTask(task *Task, taskDef *TaskDefinition // normal loop if task_status is succeeded or failed needNormalLoop := time.Since(task.UpdatedAt) >= taskDef.loopInterval && (task.TaskStatus == TaskStatusSucceeded || task.TaskStatus == TaskStatusFailed) // force loop if abnormal running found - needForceLoop := time.Since(task.UpdatedAt) >= s.config.RunningTimeout && task.TaskStatus == TaskStatusRunning + needForceLoop := time.Since(task.UpdatedAt) >= s.runningTimeout && task.TaskStatus == TaskStatusRunning return needNormalLoop || needForceLoop } diff --git a/monitor_test.go b/monitor_test.go index 3251134..f19704e 100644 --- a/monitor_test.go +++ b/monitor_test.go @@ -12,13 +12,13 @@ func Test_taskMonitorImp_monitorBuiltinTask(t *testing.T) { convey.Convey("Test_taskMonitorImp_monitorBuiltinTask", t, func() { convey.Convey("error", func() { convey.Convey("assemble task error", func() { - tc, _ := newConfig(&gorm.DB{}, "tasks") - mon := &taskMonitorImp{config: tc, assembler: &taskAssemblerImp{config: tc}} + tc, _ := newOptions(&gorm.DB{}, "tasks") + mon := &taskMonitorImp{options: tc, assembler: &taskAssemblerImp{options: tc}} convey.So(func() { mon.monitorBuiltinTask(&TaskDefinition{ArgType: reflect.TypeOf(""), argument: 0}) }, convey.ShouldNotPanic) }) convey.Convey("dal error", func() { - tc, _ := newConfig(testDB("Test_taskMonitorImp_monitorBuiltinTask"), "not exist") - mon := &taskMonitorImp{config: tc, assembler: &taskAssemblerImp{config: tc}, dal: &taskDALImp{config: tc}} + tc, _ := newOptions(testDB("Test_taskMonitorImp_monitorBuiltinTask"), "not exist") + mon := &taskMonitorImp{options: tc, assembler: &taskAssemblerImp{options: tc}, dal: &taskDALImp{options: tc}} convey.So(func() { mon.monitorBuiltinTask(&TaskDefinition{ArgType: reflect.TypeOf(""), argument: ""}) }, convey.ShouldNotPanic) }) }) diff --git a/option.go b/option.go new file mode 100644 index 0000000..d34204b --- /dev/null +++ b/option.go @@ -0,0 +1,391 @@ +package gta + +import ( + "context" + "fmt" + "time" + + "github.com/panjf2000/ants/v2" + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +func newOptions(db *gorm.DB, table string, opts ...Option) (*options, error) { + options, err := newDefaultOptions().apply(append(optionGroup{withDB(db), withTable(table)}, opts...)) + if err != nil { + return nil, err + } + return options, nil +} + +// options contains all options of a TaskManager. +type options struct { + // must provide, db for async task table + db *gorm.DB + // must provide, async task table name + table string + + // optional, context for the task mansger + context context.Context + // optional, logger factory + loggerFactory func(ctx context.Context) Logger + // optional, grouped time options + groupedTimeOptions + // optional, wait timeout in Stop() process + waitTimeout time.Duration + // optional, context marshaler to store or recover a context + ctxMarshaler CtxMarshaler + // optional, callback function for abnormal tasks + checkCallback func(logger Logger, abnormalTasks []Task) + // optional, flag for dry run mode + dryRun bool + // optional, goroutine pool size for scheduling tasks + poolSize int + + // optional, task register + taskRegister taskRegister + // global cancel function + cancelFunc context.CancelFunc +} + +func (s *options) apply(opts ...Option) (*options, error) { + for _, opt := range opts { + opt.apply(s) + } + for _, opt := range opts { + if err := opt.verify(s); err != nil { + return s, err + } + } + return s, nil +} + +func (s *options) logger() Logger { + return s.loggerFactory(s.context) +} + +func (s *options) done() <-chan struct{} { + return s.context.Done() +} + +func (s *options) cancel() { + s.cancelFunc() +} + +func (s *options) getDB() *gorm.DB { + return s.db +} + +// Logger is a logging interface for logging necessary messages. +type Logger interface { + Printf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Warnf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + +// CtxMarshaler is used to marshal or unmarshal context. +type CtxMarshaler interface { + MarshalCtx(ctx context.Context) ([]byte, error) + UnmarshalCtx(bytes []byte) (context.Context, error) +} + +type groupedTimeOptions struct { + // optional, determine when a normal task can be cleaned + storageTimeout time.Duration + // optional, determine whether a initialized task is abnormal + initializedTimeout time.Duration + // optional, determine whether a running task is abnormal + runningTimeout time.Duration + // optional, scan interval + scanInterval time.Duration + // optional, instant scan interval + instantScanInterval time.Duration +} + +func (t groupedTimeOptions) verify() error { + if t.storageTimeout <= 0 || t.initializedTimeout <= 0 || t.runningTimeout <= 0 || t.scanInterval <= 0 || t.instantScanInterval <= 0 { + return fmt.Errorf("%w: groupedTimeOptions #1", ErrOption) + } + if t.runningTimeout > t.storageTimeout { + return fmt.Errorf("%w: groupedTimeOptions #2", ErrOption) + } + if t.initializedTimeout > t.storageTimeout { + return fmt.Errorf("%w: groupedTimeOptions #3", ErrOption) + } + if t.scanInterval > t.storageTimeout || t.scanInterval > t.initializedTimeout || t.scanInterval > t.runningTimeout { + return fmt.Errorf("%w: groupedTimeOptions #4", ErrOption) + } + if t.instantScanInterval > t.scanInterval { + return fmt.Errorf("%w: groupedTimeOptions #5", ErrOption) + } + return nil +} + +// Option is a interface. +type Option interface { + apply(opts *options) + verify(opts *options) error +} + +type option struct { + applyFunc func(opts *options) + verifyFunc func(opts *options) error +} + +func (o option) apply(opts *options) { + o.applyFunc(opts) +} + +func (o option) verify(opts *options) error { + if o.verifyFunc == nil { + return nil + } + return o.verifyFunc(opts) +} + +type optionGroup []Option + +func (g optionGroup) apply(opts *options) { + for _, opt := range g { + opt.apply(opts) + } +} + +func (g optionGroup) verify(opts *options) error { + for _, opt := range g { + if err := opt.verify(opts); err != nil { + return err + } + } + return nil +} + +// WithContext set the context option. +func WithContext(ctx context.Context) Option { + return &option{ + applyFunc: func(opts *options) { opts.context, opts.cancelFunc = context.WithCancel(ctx) }, + } +} + +// WithLoggerFactory set the loggerFactory option. +func WithLoggerFactory(f func(ctx context.Context) Logger) Option { + return &option{ + applyFunc: func(opts *options) { opts.loggerFactory = f }, + verifyFunc: func(opts *options) error { + if opts.loggerFactory == nil { + return fmt.Errorf("%w: loggerFactory", ErrOption) + } + return nil + }, + } +} + +// WithStorageTimeout set the storageTimeout option. +func WithStorageTimeout(d time.Duration) Option { + return &option{ + applyFunc: func(opts *options) { opts.storageTimeout = d }, + verifyFunc: func(opts *options) error { + return opts.groupedTimeOptions.verify() + }, + } +} + +// WithInitializedTimeout set the initializedTimeout option. +func WithInitializedTimeout(d time.Duration) Option { + return &option{ + applyFunc: func(opts *options) { opts.initializedTimeout = d }, + verifyFunc: func(opts *options) error { + return opts.groupedTimeOptions.verify() + }, + } +} + +// WithRunningTimeout set the runningTimeout option. +func WithRunningTimeout(d time.Duration) Option { + return &option{ + applyFunc: func(opts *options) { opts.runningTimeout = d }, + verifyFunc: func(opts *options) error { + return opts.groupedTimeOptions.verify() + }, + } +} + +// WithScanInterval set the scanInterval option. +func WithScanInterval(d time.Duration) Option { + return &option{ + applyFunc: func(opts *options) { opts.scanInterval = d }, + verifyFunc: func(opts *options) error { + return opts.groupedTimeOptions.verify() + }, + } +} + +// WithInstantScanInterval set the instantScanInterval option. +func WithInstantScanInterval(d time.Duration) Option { + return &option{ + applyFunc: func(opts *options) { opts.instantScanInterval = d }, + verifyFunc: func(opts *options) error { + return opts.groupedTimeOptions.verify() + }, + } +} + +// WithWaitTimeout set the waitTimeout option. +func WithWaitTimeout(d time.Duration) Option { + return &option{ + applyFunc: func(opts *options) { opts.waitTimeout = d }, + verifyFunc: func(opts *options) error { + if opts.waitTimeout <= 0 { + return fmt.Errorf("%w: waitTimeout", ErrOption) + } + return nil + }, + } +} + +// WithCtxMarshaler set the ctxMarshaler option. +func WithCtxMarshaler(m CtxMarshaler) Option { + return &option{ + applyFunc: func(opts *options) { opts.ctxMarshaler = m }, + verifyFunc: func(opts *options) error { + if opts.ctxMarshaler == nil { + return fmt.Errorf("%w: ctxMarshaler", ErrOption) + } + return nil + }, + } +} + +// WithCheckCallback set the checkCallback option. +func WithCheckCallback(f func(logger Logger, abnormalTasks []Task)) Option { + return &option{ + applyFunc: func(opts *options) { opts.checkCallback = f }, + verifyFunc: func(opts *options) error { + if opts.checkCallback == nil { + return fmt.Errorf("%w: checkCallback", ErrOption) + } + return nil + }, + } +} + +// WithDryRun set the dryRun option. +func WithDryRun(flag bool) Option { + return &option{ + applyFunc: func(opts *options) { opts.dryRun = flag }, + } +} + +// WithPoolSize set the poolSize option. +func WithPoolSize(size int) Option { + return &option{ + applyFunc: func(opts *options) { opts.poolSize = size }, + verifyFunc: func(opts *options) error { + if opts.poolSize <= 0 { + return fmt.Errorf("%w: poolSize", ErrOption) + } + return nil + }, + } +} + +func withDB(db *gorm.DB) Option { + return &option{ + applyFunc: func(opts *options) { opts.db = db }, + verifyFunc: func(opts *options) error { + if opts.db == nil { + return fmt.Errorf("%w: db", ErrOption) + } + return nil + }, + } +} + +func withTable(table string) Option { + return &option{ + applyFunc: func(opts *options) { opts.table = table }, + verifyFunc: func(opts *options) error { + if opts.table == "" { + return fmt.Errorf("%w: table", ErrOption) + } + return nil + }, + } +} + +func withTaskRegister(tr taskRegister) Option { + return &option{ + applyFunc: func(opts *options) { opts.taskRegister = tr }, + verifyFunc: func(opts *options) error { + if opts.taskRegister == nil { + return fmt.Errorf("%w: taskRegister", ErrOption) + } + return nil + }, + } +} + +// newDefaultOptions generate default options +func newDefaultOptions() *options { + ctx, cancelFunc := context.WithCancel(defaultContextFactory()) + return &options{ + db: nil, + table: "", + context: ctx, + loggerFactory: defaultLoggerFactory, + groupedTimeOptions: groupedTimeOptions{ + storageTimeout: defaultStorageTimeout, + initializedTimeout: defaultInitializedTimeout, + runningTimeout: defaultRunningTimeout, + scanInterval: defaultScanInterval, + instantScanInterval: defaultInstantScanInterval, + }, + waitTimeout: defaultWaitTimeout, + ctxMarshaler: &defaultCtxMarshaler{}, + checkCallback: defaultCheckCallback, + dryRun: false, + poolSize: defaultPoolSize, + taskRegister: &taskRegisterImp{}, + cancelFunc: cancelFunc, + } +} + +var ( + defaultStorageTimeout = time.Hour * 7 * 24 + defaultWaitTimeout = time.Second * 0 + defaultScanInterval = time.Second * 5 + defaultInstantScanInterval = time.Millisecond * 100 + defaultRunningTimeout = time.Minute * 30 + defaultInitializedTimeout = time.Minute * 5 + defaultPoolSize = ants.DefaultAntsPoolSize + defaultRetryInterval = time.Second +) + +type defaultCtxMarshaler struct{} + +func (s *defaultCtxMarshaler) MarshalCtx(ctx context.Context) ([]byte, error) { + return nil, nil +} + +func (s *defaultCtxMarshaler) UnmarshalCtx(bytes []byte) (context.Context, error) { + return context.Background(), nil +} + +func defaultContextFactory() context.Context { + return context.Background() +} + +func defaultLoggerFactory(ctx context.Context) Logger { + return logrus.NewEntry(logrus.New()) +} + +func defaultCheckCallback(logger Logger, abnormalTasks []Task) { + if len(abnormalTasks) == 0 { + return + } + logger.Errorf("[defaultCheckCallback] abnormal tasks found, total[%v]", len(abnormalTasks)) + for _, at := range abnormalTasks { + logger.Warnf("[defaultCheckCallback] abnormal task found, id[%v], task_key[%v], task_status[%v]", at.ID, at.TaskKey, at.TaskStatus) + } +} diff --git a/option_test.go b/option_test.go new file mode 100644 index 0000000..d44bece --- /dev/null +++ b/option_test.go @@ -0,0 +1,79 @@ +package gta + +import ( + "testing" + "time" + + "github.com/smartystreets/goconvey/convey" + "gorm.io/gorm" +) + +func Test_newOptions(t *testing.T) { + defaultDB := &gorm.DB{} + defaultTable := "tasks" + + convey.Convey("Test_newOptions", t, func() { + convey.Convey("normal process", func() { + _, err := newOptions(defaultDB, defaultTable) + convey.So(err, convey.ShouldBeNil) + }) + + convey.Convey("nil db", func() { + _, err := newOptions(defaultDB, defaultTable, withDB(nil)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("empty table name", func() { + _, err := newOptions(defaultDB, defaultTable, withTable("")) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("nil logger factory", func() { + _, err := newOptions(defaultDB, defaultTable, WithLoggerFactory(nil)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("nil ctxMashaler", func() { + _, err := newOptions(defaultDB, defaultTable, WithCtxMarshaler(nil)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("invalid running timeout", func() { + _, err := newOptions(defaultDB, defaultTable, WithRunningTimeout(time.Hour*365)) + convey.So(err, convey.ShouldNotBeNil) + + _, err = newOptions(defaultDB, defaultTable, WithRunningTimeout(-time.Hour*365)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("invalid initialized timeout", func() { + _, err := newOptions(defaultDB, defaultTable, WithInitializedTimeout(time.Hour*365)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("invalid scan interval", func() { + _, err := newOptions(defaultDB, defaultTable, WithScanInterval(time.Hour*365)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("invalid instant scan interval", func() { + _, err := newOptions(defaultDB, defaultTable, WithInstantScanInterval(time.Hour*365)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("invalid wait timeout", func() { + _, err := newOptions(defaultDB, defaultTable, WithWaitTimeout(-time.Hour*365)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("invalid pool size", func() { + _, err := newOptions(defaultDB, defaultTable, WithPoolSize(-1)) + convey.So(err, convey.ShouldNotBeNil) + }) + + convey.Convey("nil task register", func() { + _, err := newOptions(defaultDB, defaultTable, withTaskRegister(nil)) + convey.So(err, convey.ShouldNotBeNil) + }) + }) +} diff --git a/scanner.go b/scanner.go index 611029c..58132a2 100644 --- a/scanner.go +++ b/scanner.go @@ -10,7 +10,7 @@ type taskScanner interface { } type taskScannerImp struct { - config *TaskConfig + *options register taskRegister dal taskDAL scheduler taskScheduler @@ -18,29 +18,28 @@ type taskScannerImp struct { } func (s *taskScannerImp) GoScanAndSchedule() { - logger := s.config.logger() - logger.Infof("[GoScanAndSchedule] scan and run start, scan interval[%v], instant scan interval[%v]", - s.config.ScanInterval, s.config.InstantScanInterval) + logger := s.logger() + logger.Infof("[GoScanAndSchedule] scan and run start, scan interval[%v], instant scan interval[%v]", s.scanInterval, s.instantScanInterval) go func() { defer panicHandler() for { select { - case <-s.config.done(): + case <-s.done(): return default: s.scanAndSchedule() - time.Sleep(s.scanInterval()) + time.Sleep(s.randomScanInterval()) } } }() } func (s *taskScannerImp) scanAndSchedule() { - logger := s.config.logger() + logger := s.logger() if !s.scheduler.CanSchedule() { // the schedule has reached its capacity limit - s.switchOffInstantScan() + s.swishOffInstantScan() return } @@ -50,20 +49,20 @@ func (s *taskScannerImp) scanAndSchedule() { if err != ErrTaskNotFound { logger.Errorf("[scanAndSchedule] claim task err, err[%v]", err) } - s.switchOffInstantScan() + s.swishOffInstantScan() return } else if task != nil { s.scheduler.GoScheduleTask(task) } - s.switchOnInstantScan() + s.swishOnInstantScan() } -func (s *taskScannerImp) scanInterval() time.Duration { +func (s *taskScannerImp) randomScanInterval() time.Duration { if s.needInstantScan() { - return randomInterval(s.config.InstantScanInterval) + return randomInterval(s.instantScanInterval) } - return randomInterval(s.config.ScanInterval) + return randomInterval(s.scanInterval) } func (s *taskScannerImp) needInstantScan() bool { @@ -74,19 +73,17 @@ func (s *taskScannerImp) needInstantScan() bool { return iv.(bool) } -func (s *taskScannerImp) switchOffInstantScan() { +func (s *taskScannerImp) swishOffInstantScan() { s.instantScan.Store(false) } -func (s *taskScannerImp) switchOnInstantScan() { +func (s *taskScannerImp) swishOnInstantScan() { s.instantScan.Store(true) } func (s *taskScannerImp) claimInitializedTask() (*Task, error) { - tc := s.config - sensitiveKeys, insensitiveKeys := s.register.GroupKeysByInitTimeoutSensitivity() - task, err := s.dal.GetInitialized(tc.db(), sensitiveKeys, tc.InitializedTimeout, insensitiveKeys) + task, err := s.dal.GetInitialized(s.getDB(), sensitiveKeys, s.initializedTimeout, insensitiveKeys) if err != nil { return nil, err } else if task == nil { @@ -95,11 +92,11 @@ func (s *taskScannerImp) claimInitializedTask() (*Task, error) { } select { - case <-tc.done(): + case <-s.done(): // abort claim when cancel signal received return nil, nil default: - if rowsAffected, err := s.dal.UpdateStatusByIDs(tc.db(), []uint64{task.ID}, task.TaskStatus, TaskStatusRunning); err != nil { + if rowsAffected, err := s.dal.UpdateStatusByIDs(s.getDB(), []uint64{task.ID}, task.TaskStatus, TaskStatusRunning); err != nil { return nil, err } else if rowsAffected == 0 { // task is claimed by others, ignore error diff --git a/scanner_test.go b/scanner_test.go index 63b4b4d..0f0cba1 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -10,12 +10,12 @@ import ( func Test_taskScannerImp_claimInitializedTask(t *testing.T) { convey.Convey("Test_taskScannerImp_claimInitializedTask", t, func() { convey.Convey("ctx cancelled", func() { - tc, _ := newConfig(testDB("Test_taskScannerImp_claimInitializedTask"), "tasks") + tc, _ := newOptions(testDB("Test_taskScannerImp_claimInitializedTask"), "tasks") tr := &taskRegisterImp{} - tdal := &taskDALImp{config: tc} - tscn := &taskScannerImp{config: tc, register: tr, dal: tdal} + tdal := &taskDALImp{options: tc} + tscn := &taskScannerImp{options: tc, register: tr, dal: tdal} _ = tr.Register("t1", TaskDefinition{Handler: testWrappedHandler()}) - _ = tdal.Create(tc.db(), &Task{TaskKey: "t1", TaskStatus: TaskStatusInitialized}) + _ = tdal.Create(tc.getDB(), &Task{TaskKey: "t1", TaskStatus: TaskStatusInitialized}) tc.cancel() task, err := tscn.claimInitializedTask() convey.So(err, convey.ShouldBeNil) @@ -23,10 +23,10 @@ func Test_taskScannerImp_claimInitializedTask(t *testing.T) { }) convey.Convey("error", func() { - tc, _ := newConfig(testDB("Test_taskScannerImp_claimInitializedTask"), "not exist") + tc, _ := newOptions(testDB("Test_taskScannerImp_claimInitializedTask"), "not exist") tr := &taskRegisterImp{} - tdal := &taskDALImp{config: tc} - tscn := &taskScannerImp{config: tc, register: tr, dal: tdal} + tdal := &taskDALImp{options: tc} + tscn := &taskScannerImp{options: tc, register: tr, dal: tdal} _, err := tscn.claimInitializedTask() convey.So(err, convey.ShouldNotBeNil) }) @@ -37,12 +37,12 @@ func Test_taskScannerImp_claimInitializedTask(t *testing.T) { func Test_taskScannerImp_scanAndSchedule(t *testing.T) { convey.Convey("Test_taskScannerImp_scanAndSchedule", t, func() { convey.Convey("error", func() { - tc, _ := newConfig(testDB("Test_taskScannerImp_scanAndSchedule"), "not exist") + tc, _ := newOptions(testDB("Test_taskScannerImp_scanAndSchedule"), "not exist") tr := &taskRegisterImp{} - tdal := &taskDALImp{config: tc} + tdal := &taskDALImp{options: tc} pool, _ := ants.NewPool(1) tsch := &taskSchedulerImp{pool: pool} - tscn := &taskScannerImp{config: tc, register: tr, dal: tdal, scheduler: tsch} + tscn := &taskScannerImp{options: tc, register: tr, dal: tdal, scheduler: tsch} convey.So(func() { tscn.scanAndSchedule() }, convey.ShouldNotPanic) }) }) diff --git a/scheduler.go b/scheduler.go index ae3b35e..a1ba213 100644 --- a/scheduler.go +++ b/scheduler.go @@ -25,7 +25,7 @@ type taskScheduler interface { } type taskSchedulerImp struct { - config *TaskConfig + *options register taskRegister dal taskDAL assembler taskAssembler @@ -34,7 +34,7 @@ type taskSchedulerImp struct { } func (s *taskSchedulerImp) Transaction(fc func(tx *gorm.DB) error) error { - db := s.config.db().Set(transactionKey, &sync.Map{}) + db := s.getDB().Set(transactionKey, &sync.Map{}) if err := db.Transaction(fc); err != nil { return err @@ -50,7 +50,7 @@ func (s *taskSchedulerImp) Transaction(fc func(tx *gorm.DB) error) error { } func (s *taskSchedulerImp) CreateTask(tx *gorm.DB, ctxIn context.Context, key TaskKey, arg interface{}) error { - logger := s.config.LoggerFactory(ctxIn) + logger := s.loggerFactory(ctxIn) taskDef, err := s.register.GetDefinition(key) if err != nil { @@ -62,7 +62,7 @@ func (s *taskSchedulerImp) CreateTask(tx *gorm.DB, ctxIn context.Context, key Ta } select { - case <-s.config.done(): + case <-s.done(): // may still accept create task requests when cancel signal is received if err := s.createInitializedTask(tx, task); err != nil { return err @@ -70,7 +70,7 @@ func (s *taskSchedulerImp) CreateTask(tx *gorm.DB, ctxIn context.Context, key Ta default: if toScheduleTasks, ok := tx.Get(transactionKey); ok { // buitin transaction, try to create running task - if !s.config.DryRun { + if !s.dryRun { if s.CanSchedule() { if err := s.createRunningTask(tx, task); err != nil { return err @@ -90,7 +90,7 @@ func (s *taskSchedulerImp) CreateTask(tx *gorm.DB, ctxIn context.Context, key Ta } } else { // not builtin transaction, create initialized task - if !s.config.DryRun { + if !s.dryRun { if err := s.createInitializedTask(tx, task); err != nil { return err } @@ -114,7 +114,7 @@ func (s *taskSchedulerImp) CreateTask(tx *gorm.DB, ctxIn context.Context, key Ta func (s *taskSchedulerImp) Stop(wait bool) { defer s.pool.Release() - logger := s.config.logger() + logger := s.logger() // first check, if tasks len is zero, return immediately taskIDs := s.runningTaskIDs() @@ -131,10 +131,10 @@ func (s *taskSchedulerImp) Stop(wait bool) { if len(taskIDs) <= 0 { logger.Infof("[Stop] current running tasks finished") return - } else if !wait || (s.config.WaitTimeout > 0 && time.Since(waitStart) > s.config.WaitTimeout) { - if !s.config.DryRun { + } else if !wait || (s.waitTimeout > 0 && time.Since(waitStart) > s.waitTimeout) { + if !s.dryRun { // change remaining tasks status to initialized - rowsAffected, err := s.dal.UpdateStatusByIDs(s.config.db(), taskIDs, TaskStatusRunning, TaskStatusInitialized) + rowsAffected, err := s.dal.UpdateStatusByIDs(s.getDB(), taskIDs, TaskStatusRunning, TaskStatusInitialized) if err != nil { logger.Errorf("[Stop] update task status from running to initialized failed, err[%v]", err) return @@ -149,7 +149,7 @@ func (s *taskSchedulerImp) Stop(wait bool) { } func (s *taskSchedulerImp) GoScheduleTask(task *Task) { - logger := s.config.logger() + logger := s.logger() if task.TaskStatus != TaskStatusRunning { logger.Errorf("[GoScheduleTask] invalid task status, task_key[%v], task_status[%v]", task.TaskKey, task.TaskStatus) @@ -179,7 +179,7 @@ func (s *taskSchedulerImp) CanSchedule() bool { } func (s *taskSchedulerImp) scheduleTask(task *Task) { - logger := s.config.logger() + logger := s.logger() taskDef, _ := s.register.GetDefinition(task.TaskKey) succeeded := false startTime := time.Now() @@ -215,7 +215,7 @@ func (s *taskSchedulerImp) scheduleTask(task *Task) { } func (s *taskSchedulerImp) executeTask(taskDef *TaskDefinition, task *Task) (err error) { - logger := s.config.logger() + logger := s.logger() startTime := time.Now() defer func() { @@ -245,15 +245,15 @@ func (s *taskSchedulerImp) executeTask(taskDef *TaskDefinition, task *Task) (err } func (s *taskSchedulerImp) stopRunning(task *Task, taskDef *TaskDefinition, toStatus TaskStatus) error { - if !s.config.DryRun { + if !s.dryRun { if taskDef.CleanSucceeded && toStatus == TaskStatusSucceeded { - if rowsAffected, err := s.dal.DeleteByIDAndStatus(s.config.db(), task.ID, task.TaskStatus); err != nil { + if rowsAffected, err := s.dal.DeleteByIDAndStatus(s.getDB(), task.ID, task.TaskStatus); err != nil { return err } else if rowsAffected == 0 { return ErrZeroRowsAffected } } else { - if rowsAffected, err := s.dal.UpdateStatusByIDs(s.config.db(), []uint64{task.ID}, task.TaskStatus, toStatus); err != nil { + if rowsAffected, err := s.dal.UpdateStatusByIDs(s.getDB(), []uint64{task.ID}, task.TaskStatus, toStatus); err != nil { return err } else if rowsAffected == 0 { return ErrZeroRowsAffected diff --git a/scheduler_test.go b/scheduler_test.go index 1e02b41..82ae9ff 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -13,12 +13,12 @@ import ( func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.Convey("Test_taskSchedulerImp_CreateTask", t, func() { - tc, _ := newConfig(testDB("Test_taskSchedulerImp_CreateTask"), "tasks") + tc, _ := newOptions(testDB("Test_taskSchedulerImp_CreateTask"), "tasks") tr := tc.taskRegister - tdal := &taskDALImp{config: tc} - tass := &taskAssemblerImp{config: tc} - pool, _ := ants.NewPool(tc.PoolSize, ants.WithLogger(tc.logger()), ants.WithNonblocking(true)) - tsch := &taskSchedulerImp{config: tc, register: tr, dal: tdal, assembler: tass, pool: pool} + tdal := &taskDALImp{options: tc} + tass := &taskAssemblerImp{options: tc} + pool, _ := ants.NewPool(tc.poolSize, ants.WithLogger(tc.logger()), ants.WithNonblocking(true)) + tsch := &taskSchedulerImp{options: tc, register: tr, dal: tdal, assembler: tass, pool: pool} _ = tr.Register("t1", TaskDefinition{Handler: func(ctx context.Context, arg interface{}) (err error) { return nil }}) _ = tr.Register("t2", TaskDefinition{Handler: func(ctx context.Context, arg interface{}) (err error) { return ErrUnexpected }}) @@ -26,7 +26,7 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.Convey("built in transaction", func() { convey.Convey("transaction succeeded", func() { convey.Convey("not full pool", func() { - db := tc.db().Set(transactionKey, &sync.Map{}) + db := tc.getDB().Set(transactionKey, &sync.Map{}) err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -45,10 +45,10 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { t2, ok2 := m.(*sync.Map).Load(uint64(2)) convey.So(ok2, convey.ShouldBeTrue) convey.So(t2.(*Task).TaskKey, convey.ShouldEqual, "t2") - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1.TaskKey, convey.ShouldEqual, "t1") convey.So(task1.TaskStatus, convey.ShouldEqual, TaskStatusRunning) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2.TaskKey, convey.ShouldEqual, "t2") convey.So(task2.TaskStatus, convey.ShouldEqual, TaskStatusRunning) }) @@ -56,7 +56,7 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { pool, _ := ants.NewPool(1, ants.WithLogger(tc.logger()), ants.WithNonblocking(true)) _ = pool.Submit(func() { time.Sleep(time.Second * 10) }) tsch.pool = pool - db := tc.db().Set(transactionKey, &sync.Map{}) + db := tc.getDB().Set(transactionKey, &sync.Map{}) err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -73,16 +73,16 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.So(ok1, convey.ShouldBeFalse) _, ok2 := m.(*sync.Map).Load(uint64(2)) convey.So(ok2, convey.ShouldBeFalse) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1.TaskKey, convey.ShouldEqual, "t1") convey.So(task1.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2.TaskKey, convey.ShouldEqual, "t2") convey.So(task2.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) }) }) convey.Convey("transaction failed", func() { - db := tc.db().Set(transactionKey, &sync.Map{}) + db := tc.getDB().Set(transactionKey, &sync.Map{}) err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -94,15 +94,15 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { }) convey.So(err, convey.ShouldNotBeNil) convey.So(tsch.runningTaskIDs(), convey.ShouldHaveLength, 0) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1, convey.ShouldBeNil) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2, convey.ShouldBeNil) }) }) convey.Convey("non built in transaction", func() { convey.Convey("transaction succeeded", func() { - db := tc.db() + db := tc.getDB() err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -116,15 +116,15 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.So(tsch.runningTaskIDs(), convey.ShouldHaveLength, 0) _, ok := db.Get(transactionKey) convey.So(ok, convey.ShouldBeFalse) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1.TaskKey, convey.ShouldEqual, "t1") convey.So(task1.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2.TaskKey, convey.ShouldEqual, "t2") convey.So(task2.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) }) convey.Convey("transaction failed", func() { - db := tc.db() + db := tc.getDB() err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -138,9 +138,9 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.So(ok, convey.ShouldBeFalse) convey.So(err, convey.ShouldNotBeNil) convey.So(tsch.runningTaskIDs(), convey.ShouldHaveLength, 0) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1, convey.ShouldBeNil) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2, convey.ShouldBeNil) }) }) @@ -149,19 +149,19 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.Convey("ctx cancelled", func() { tc.cancel() - err := tsch.CreateTask(tc.db(), context.TODO(), "t1", nil) + err := tsch.CreateTask(tc.getDB(), context.TODO(), "t1", nil) convey.So(err, convey.ShouldBeNil) convey.So(tsch.runningTaskIDs(), convey.ShouldHaveLength, 0) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1.TaskKey, convey.ShouldEqual, "t1") convey.So(task1.TaskStatus, convey.ShouldEqual, TaskStatusInitialized) }) convey.Convey("dry run mode", func() { - tc.DryRun = true + tc.dryRun = true convey.Convey("built in transaction", func() { - db := tc.db().Set(transactionKey, &sync.Map{}) + db := tc.getDB().Set(transactionKey, &sync.Map{}) err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -181,13 +181,13 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { return true }) convey.So(count, convey.ShouldEqual, 2) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1, convey.ShouldBeNil) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2, convey.ShouldBeNil) }) convey.Convey("non built in transaction", func() { - db := tc.db() + db := tc.getDB() err := db.Transaction(func(tx *gorm.DB) error { if err := tsch.CreateTask(tx, context.TODO(), "t1", nil); err != nil { return err @@ -201,9 +201,9 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { convey.So(tsch.runningTaskIDs(), convey.ShouldHaveLength, 0) _, ok := db.Get(transactionKey) convey.So(ok, convey.ShouldBeFalse) - task1, _ := tdal.Get(tc.db(), 1) + task1, _ := tdal.Get(tc.getDB(), 1) convey.So(task1, convey.ShouldBeNil) - task2, _ := tdal.Get(tc.db(), 2) + task2, _ := tdal.Get(tc.getDB(), 2) convey.So(task2, convey.ShouldBeNil) }) }) @@ -216,12 +216,12 @@ func Test_taskSchedulerImp_CreateTask(t *testing.T) { func Test_taskSchedulerImp_GoScheduleTask(t *testing.T) { convey.Convey("Test_taskSchedulerImp_GoScheduleTask", t, func() { - tc, _ := newConfig(testDB("Test_taskSchedulerImp_GoScheduleTask"), "tasks", WithPoolSize(1)) + tc, _ := newOptions(testDB("Test_taskSchedulerImp_GoScheduleTask"), "tasks", WithPoolSize(1)) tr := tc.taskRegister - tdal := &taskDALImp{config: tc} - tass := &taskAssemblerImp{config: tc} - pool, _ := ants.NewPool(tc.PoolSize, ants.WithLogger(tc.logger()), ants.WithNonblocking(true)) - tsch := &taskSchedulerImp{config: tc, register: tr, dal: tdal, assembler: tass, pool: pool} + tdal := &taskDALImp{options: tc} + tass := &taskAssemblerImp{options: tc} + pool, _ := ants.NewPool(tc.poolSize, ants.WithLogger(tc.logger()), ants.WithNonblocking(true)) + tsch := &taskSchedulerImp{options: tc, register: tr, dal: tdal, assembler: tass, pool: pool} var t1Run int64 _ = tr.Register("t1", TaskDefinition{Handler: testCountHandler(&t1Run)}) diff --git a/task_check_abnormal.go b/task_check_abnormal.go index 027a31c..40e98d8 100644 --- a/task_check_abnormal.go +++ b/task_check_abnormal.go @@ -19,33 +19,31 @@ type checkAbnormalTaskReq struct { } func registerCheckAbnormalTask(tm *TaskManager) { - tc := tm.tc tm.Register(taskCheckAbnormal, TaskDefinition{ Handler: checkAbnormalHandler(tm), ArgType: reflect.TypeOf(checkAbnormalTaskReq{}), builtin: true, taskID: taskCheckAbnormalID, argument: checkAbnormalTaskReq{ - StorageTimeout: tc.StorageTimeout, - RunningTimeout: tc.RunningTimeout, - InitializedTimeout: tc.InitializedTimeout, + StorageTimeout: tm.storageTimeout, + RunningTimeout: tm.runningTimeout, + InitializedTimeout: tm.initializedTimeout, }, loopInterval: time.Duration( - minInt64(int64(tc.InitializedTimeout)/2, int64(tc.RunningTimeout)/2, int64(tc.ScanInterval)*15), + minInt64(int64(tm.initializedTimeout)/2, int64(tm.runningTimeout)/2, int64(tm.scanInterval)*15), ), }) } func checkAbnormalHandler(tm *TaskManager) TaskHandler { - tc := tm.tc return func(ctx context.Context, arg interface{}) (err error) { req := arg.(checkAbnormalTaskReq) - abnormalRunning, err := tm.tdal.GetSliceByOffsetsAndStatus(tc.db(), req.StorageTimeout, + abnormalRunning, err := tm.tdal.GetSliceByOffsetsAndStatus(tm.getDB(), req.StorageTimeout, req.RunningTimeout, TaskStatusRunning) if err != nil { return fmt.Errorf("check abnormal running failed, err[%w]", err) } - abnormalInitilized, err := tm.tdal.GetSliceByOffsetsAndStatus(tc.db(), req.StorageTimeout, + abnormalInitilized, err := tm.tdal.GetSliceByOffsetsAndStatus(tm.getDB(), req.StorageTimeout, req.InitializedTimeout, TaskStatusInitialized) if err != nil { return fmt.Errorf("check abnormal running failed, err[%w]", err) @@ -71,7 +69,7 @@ func checkAbnormalHandler(tm *TaskManager) TaskHandler { abnormalTasks = append(abnormalTasks, t) } - tc.CheckCallback(tc.logger(), abnormalTasks) + tm.checkCallback(tm.logger(), abnormalTasks) return nil } } diff --git a/task_clean_up.go b/task_clean_up.go index 1845deb..d1fa40d 100644 --- a/task_clean_up.go +++ b/task_clean_up.go @@ -16,23 +16,21 @@ type cleanUpReq struct { } func registerCleanUpTask(tm *TaskManager) { - tc := tm.tc tm.Register(taskCleanUp, TaskDefinition{ Handler: cleanUpHandler(tm), ArgType: reflect.TypeOf(cleanUpReq{}), builtin: true, taskID: taskCleanUpID, - argument: cleanUpReq{StorageTimeout: tc.StorageTimeout}, - loopInterval: tc.StorageTimeout / 2, + argument: cleanUpReq{StorageTimeout: tm.storageTimeout}, + loopInterval: tm.storageTimeout / 2, }) } func cleanUpHandler(tm *TaskManager) TaskHandler { - tc := tm.tc return func(ctx context.Context, arg interface{}) (err error) { - logger := tc.logger() + logger := tm.logger() storageTimeout := arg.(cleanUpReq).StorageTimeout - rowsAffected, err := tm.tdal.DeleteSucceededByOffset(tc.db(), storageTimeout, tm.tr.GetBuiltInKeys()) + rowsAffected, err := tm.tdal.DeleteSucceededByOffset(tm.getDB(), storageTimeout, tm.tr.GetBuiltInKeys()) if err != nil { return err } else if rowsAffected > 0 {