diff --git a/bux_suite_mocks_test.go b/bux_suite_mocks_test.go index 265407a9..4929237d 100644 --- a/bux_suite_mocks_test.go +++ b/bux_suite_mocks_test.go @@ -12,13 +12,13 @@ type taskManagerMockBase struct{} func (tm *taskManagerMockBase) Info(context.Context, string, ...interface{}) {} -func (tm *taskManagerMockBase) RegisterTask(*taskmanager.Task) error { +func (tm *taskManagerMockBase) RegisterTask(string, interface{}) error { return nil } func (tm *taskManagerMockBase) ResetCron() {} -func (tm *taskManagerMockBase) RunTask(context.Context, *taskmanager.TaskOptions) error { +func (tm *taskManagerMockBase) RunTask(context.Context, *taskmanager.TaskRunOptions) error { return nil } @@ -55,6 +55,6 @@ func (tm *taskManagerMockBase) CronJobsInit(cronJobsMap taskmanager.CronJobs) er // Sets custom task manager only for testing func withTaskManagerMockup() ClientOps { return func(c *clientOptions) { - c.taskManager.TaskManagerInterface = &taskManagerMockBase{} + c.taskManager.Tasker = &taskManagerMockBase{} } } diff --git a/bux_suite_test.go b/bux_suite_test.go index b1ef2d1a..aac0aaf9 100644 --- a/bux_suite_test.go +++ b/bux_suite_test.go @@ -326,7 +326,7 @@ func (ts *EmbeddedDBTestSuite) genericDBClient(t *testing.T, database datastore. WithAutoMigrate(&PaymailAddress{}), ) if taskManagerEnabled { - opts = append(opts, WithTaskqConfig(taskmanager.DefaultTaskQConfig(prefix+"_queue", nil))) + opts = append(opts, WithTaskqConfig(taskmanager.DefaultTaskQConfig(prefix+"_queue"))) } else { opts = append(opts, withTaskManagerMockup()) } diff --git a/bux_test.go b/bux_test.go index f958238e..f9e0432a 100644 --- a/bux_test.go +++ b/bux_test.go @@ -59,7 +59,7 @@ func (tc *TestingClient) Close(ctx context.Context) { // DefaultClientOpts will return a default set of client options required to load the new client func DefaultClientOpts(debug, shared bool) []ClientOps { - tqc := taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), nil) + tqc := taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix()) tqc.MaxNumWorker = 2 tqc.MaxNumFetcher = 2 diff --git a/client.go b/client.go index 0de28a0f..8ee35acd 100644 --- a/client.go +++ b/client.go @@ -113,10 +113,10 @@ type ( // taskManagerOptions holds the configuration for taskmanager taskManagerOptions struct { - taskmanager.TaskManagerInterface // Client for TaskManager - cronJobs taskmanager.CronJobs // List of cron jobs - options []taskmanager.ClientOps // List of options - cronCustomPeriods map[string]time.Duration // will override the default period of cronJob + taskmanager.Tasker // Client for TaskManager + cronJobs taskmanager.CronJobs // List of cron jobs + options []taskmanager.ClientOps // List of options + cronCustomPeriods map[string]time.Duration // will override the default period of cronJob } ) @@ -303,7 +303,7 @@ func (c *Client) Close(ctx context.Context) error { if err := tm.Close(ctx); err != nil { return err } - c.options.taskManager.TaskManagerInterface = nil + c.options.taskManager.Tasker = nil } return nil } @@ -445,9 +445,9 @@ func (c *Client) SetNotificationsClient(client notifications.ClientInterface) { } // Taskmanager will return the Taskmanager if it exists -func (c *Client) Taskmanager() taskmanager.TaskManagerInterface { - if c.options.taskManager != nil && c.options.taskManager.TaskManagerInterface != nil { - return c.options.taskManager.TaskManagerInterface +func (c *Client) Taskmanager() taskmanager.Tasker { + if c.options.taskManager != nil && c.options.taskManager.Tasker != nil { + return c.options.taskManager.Tasker } return nil } diff --git a/client_internal.go b/client_internal.go index 900d3481..d169e3e7 100644 --- a/client_internal.go +++ b/client_internal.go @@ -114,8 +114,8 @@ func (c *Client) loadPaymailClient() (err error) { // loadTaskmanager will load the TaskManager and start the TaskManager client func (c *Client) loadTaskmanager(ctx context.Context) (err error) { // Load if a custom interface was NOT provided - if c.options.taskManager.TaskManagerInterface == nil { - c.options.taskManager.TaskManagerInterface, err = taskmanager.NewClient( + if c.options.taskManager.Tasker == nil { + c.options.taskManager.Tasker, err = taskmanager.NewClient( ctx, c.options.taskManager.options..., ) } diff --git a/client_options.go b/client_options.go index 7924076f..4be87a7d 100644 --- a/client_options.go +++ b/client_options.go @@ -105,8 +105,8 @@ func defaultClientOptions() *clientOptions { // Blank TaskManager config taskManager: &taskManagerOptions{ - TaskManagerInterface: nil, - cronCustomPeriods: map[string]time.Duration{}, + Tasker: nil, + cronCustomPeriods: map[string]time.Duration{}, }, // Default user agent diff --git a/client_options_test.go b/client_options_test.go index 8a0b82b7..1f82dcba 100644 --- a/client_options_test.go +++ b/client_options_test.go @@ -264,7 +264,7 @@ func TestWithRedis(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix())), WithRedis(&cachestore.RedisConfig{ URL: cachestore.RedisPrefix + "localhost:6379", }), @@ -287,7 +287,7 @@ func TestWithRedis(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix())), WithRedis(&cachestore.RedisConfig{ URL: "localhost:6379", }), @@ -314,7 +314,7 @@ func TestWithRedisConnection(t *testing.T) { t.Run("using a nil connection", func(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix())), WithRedisConnection(nil), WithSQLite(tester.SQLiteTestConfig(false, true)), WithMinercraft(&chainstate.MinerCraftBase{}), @@ -335,7 +335,7 @@ func TestWithRedisConnection(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix())), WithRedisConnection(client), WithSQLite(tester.SQLiteTestConfig(false, true)), WithMinercraft(&chainstate.MinerCraftBase{}), @@ -363,7 +363,7 @@ func TestWithFreeCache(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), WithFreeCache(), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName)), WithSQLite(&datastore.SQLiteConfig{Shared: true}), WithMinercraft(&chainstate.MinerCraftBase{})) require.NoError(t, err) @@ -391,7 +391,7 @@ func TestWithFreeCacheConnection(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), WithFreeCacheConnection(nil), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName)), WithSQLite(&datastore.SQLiteConfig{Shared: true}), WithMinercraft(&chainstate.MinerCraftBase{}), WithLogger(&logger), @@ -412,7 +412,7 @@ func TestWithFreeCacheConnection(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), WithFreeCacheConnection(fc), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName)), WithSQLite(&datastore.SQLiteConfig{Shared: true}), WithMinercraft(&chainstate.MinerCraftBase{}), WithLogger(&logger), @@ -497,9 +497,7 @@ func TestWithTaskQ(t *testing.T) { tc, err := NewClient( tester.GetNewRelicCtx(t, defaultNewRelicApp, defaultNewRelicTx), WithTaskqConfig( - taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), &taskmanager.SimplifiedRedisOptions{ - Addr: "localhost:6379", - }), + taskmanager.DefaultTaskQConfig(tester.RandomTablePrefix(), taskmanager.WithRedis("localhost:6379")), ), WithRedis(&cachestore.RedisConfig{ URL: cachestore.RedisPrefix + "localhost:6379", diff --git a/examples/client/redis/redis.go b/examples/client/redis/redis.go index a3dfc6ec..6f979787 100644 --- a/examples/client/redis/redis.go +++ b/examples/client/redis/redis.go @@ -15,9 +15,8 @@ func main() { context.Background(), // Set context bux.WithRedis(&cachestore.RedisConfig{URL: redisURL}), // Cache bux.WithTaskqConfig( // Tasks - taskmanager.DefaultTaskQConfig("example_queue", &taskmanager.SimplifiedRedisOptions{ - Addr: redisURL, - })), + taskmanager.DefaultTaskQConfig("example_queue", taskmanager.WithRedis(redisURL)), + ), ) if err != nil { log.Fatalln("error: " + err.Error()) diff --git a/interface.go b/interface.go index 17034b65..fda205c6 100644 --- a/interface.go +++ b/interface.go @@ -66,7 +66,7 @@ type ClientService interface { Logger() *zerolog.Logger Notifications() notifications.ClientInterface PaymailClient() paymail.ClientInterface - Taskmanager() taskmanager.TaskManagerInterface + Taskmanager() taskmanager.Tasker } // DestinationService is the destination actions diff --git a/paymail_test.go b/paymail_test.go index 3962747f..9bf0081b 100644 --- a/paymail_test.go +++ b/paymail_test.go @@ -219,7 +219,7 @@ func Test_getCapabilities(t *testing.T) { tc, err := NewClient(context.Background(), WithRedisConnection(redisClient), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName)), WithSQLite(&datastore.SQLiteConfig{Shared: true}), WithChainstateOptions(false, false, false, false), WithDebugging(), @@ -261,7 +261,7 @@ func Test_getCapabilities(t *testing.T) { tc, err := NewClient(context.Background(), WithRedisConnection(redisClient), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName)), WithSQLite(&datastore.SQLiteConfig{Shared: true}), WithChainstateOptions(false, false, false, false), WithDebugging(), @@ -374,7 +374,7 @@ func Test_resolvePaymailAddress(t *testing.T) { tc, err := NewClient(context.Background(), WithRedisConnection(redisClient), - WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(taskmanager.DefaultTaskQConfig(testQueueName)), WithSQLite(&datastore.SQLiteConfig{Shared: true}), WithChainstateOptions(false, false, false, false), WithDebugging(), diff --git a/taskmanager/cron_jobs.go b/taskmanager/cron_jobs.go index 85f88653..f04ec636 100644 --- a/taskmanager/cron_jobs.go +++ b/taskmanager/cron_jobs.go @@ -31,23 +31,19 @@ func (tm *Client) CronJobsInit(cronJobsMap CronJobs) (err error) { for name, taskDef := range cronJobsMap { handler := taskDef.Handler - if err = tm.RegisterTask(&Task{ - Name: name, - RetryLimit: 1, - Handler: func() error { - if taskErr := handler(ctx); taskErr != nil { - if tm.options.logger != nil { - tm.options.logger.Error().Msgf("error running %v task: %v", name, taskErr.Error()) - } + if err = tm.RegisterTask(name, func() error { + if taskErr := handler(ctx); taskErr != nil { + if tm.options.logger != nil { + tm.options.logger.Error().Msgf("error running %v task: %v", name, taskErr.Error()) } - return nil - }, + } + return nil }); err != nil { return } // Run the task periodically - if err = tm.RunTask(ctx, &TaskOptions{ + if err = tm.RunTask(ctx, &TaskRunOptions{ RunEveryPeriod: taskDef.Period, TaskName: name, }); err != nil { diff --git a/taskmanager/errors.go b/taskmanager/errors.go deleted file mode 100644 index e30acbf6..00000000 --- a/taskmanager/errors.go +++ /dev/null @@ -1,15 +0,0 @@ -package taskmanager - -import "errors" - -// ErrMissingTaskQConfig is when the taskq configuration is missing prior to loading taskq -var ErrMissingTaskQConfig = errors.New("missing taskq configuration") - -// ErrMissingRedis is when the Redis connection is missing prior to loading taskq -var ErrMissingRedis = errors.New("missing redis connection") - -// ErrMissingFactory is when the factory type is missing or empty -var ErrMissingFactory = errors.New("missing factory type to load taskq") - -// ErrTaskNotFound is when a task was not found -var ErrTaskNotFound = errors.New("task not found") diff --git a/taskmanager/interface.go b/taskmanager/interface.go index d1a4d2b8..8e453e8a 100644 --- a/taskmanager/interface.go +++ b/taskmanager/interface.go @@ -6,11 +6,11 @@ import ( taskq "github.com/vmihailenco/taskq/v3" ) -// TaskManagerInterface is the taskmanager client interface -type TaskManagerInterface interface { - RegisterTask(task *Task) error +// Tasker is the taskmanager client interface +type Tasker interface { + RegisterTask(name string, handler interface{}) error ResetCron() - RunTask(ctx context.Context, options *TaskOptions) error + RunTask(ctx context.Context, options *TaskRunOptions) error Tasks() map[string]*taskq.Task CronJobsInit(cronJobsMap CronJobs) error Close(ctx context.Context) error diff --git a/taskmanager/options.go b/taskmanager/options.go index c3d95d1e..d02623e8 100644 --- a/taskmanager/options.go +++ b/taskmanager/options.go @@ -22,7 +22,7 @@ func defaultClientOptions() *clientOptions { newRelicEnabled: false, taskq: &taskqOptions{ tasks: make(map[string]*taskq.Task), - config: DefaultTaskQConfig("taskq", nil), + config: DefaultTaskQConfig("taskq"), }, } } @@ -52,6 +52,7 @@ func WithDebugging() ClientOps { } } +// WithTaskqConfig will set the taskq custom config func WithTaskqConfig(config *taskq.QueueOptions) ClientOps { return func(c *clientOptions) { if config != nil { diff --git a/taskmanager/options_test.go b/taskmanager/options_test.go index 70c6f3c8..de851473 100644 --- a/taskmanager/options_test.go +++ b/taskmanager/options_test.go @@ -61,7 +61,7 @@ func TestWithTaskQ(t *testing.T) { options := &clientOptions{ taskq: &taskqOptions{}, } - opt := WithTaskqConfig(DefaultTaskQConfig(testQueueName, nil)) + opt := WithTaskqConfig(DefaultTaskQConfig(testQueueName)) opt(options) assert.NotNil(t, options.taskq.config) }) diff --git a/taskmanager/task.go b/taskmanager/task.go deleted file mode 100644 index f67fb727..00000000 --- a/taskmanager/task.go +++ /dev/null @@ -1,51 +0,0 @@ -package taskmanager - -import ( - "time" -) - -// Task is the options for a new task (mimics TaskQ) -type Task struct { - Name string // Task name. - - // Function called to process a message. - // There are three permitted types of signature: - // 1. A zero-argument function - // 2. A function whose arguments are assignable in type from those which are passed in the message - // 3. A function which takes a single `*Message` argument - // The handler function may also optionally take a Context as a first argument and may optionally return an error. - // If the handler takes a Context, when it is invoked it will be passed the same Context as that which was passed to - // `StartConsumer`. If the handler returns a non-nil error the message processing will fail and will be retried/. - Handler interface{} - // Function called to process failed message after the specified number of retries have all failed. - // The FallbackHandler accepts the same types of function as the Handler. - FallbackHandler interface{} - - // Optional function used by Consumer with defer statement to recover from panics. - DeferFunc func() - - // Number of tries/releases after which the message fails permanently and is deleted. Default is 64 retries. - RetryLimit int - - // Minimum backoff time between retries. Default is 30 seconds. - MinBackoff time.Duration - - // Maximum backoff time between retries. Default is 30 minutes. - MaxBackoff time.Duration -} - -// TaskOptions are used for running a task -type TaskOptions struct { - Arguments []interface{} `json:"arguments"` // Arguments for the task - Delay time.Duration `json:"delay"` // Run after X delay - OnceInPeriod time.Duration `json:"once_in_period"` // Run once in X period - RunEveryPeriod time.Duration `json:"run_every_period"` // Cron job! - TaskName string `json:"task_name"` // Name of the task -} - -/* -// todo: add this functionality to the task options -OnceInPeriod(period time.Duration, args ...interface{}) -OnceWithDelay(delay time.Duration) -OnceWithSchedule(tm time.Time) -*/ diff --git a/taskmanager/taskmanager.go b/taskmanager/taskmanager.go index 64e1c5c2..fe938a32 100644 --- a/taskmanager/taskmanager.go +++ b/taskmanager/taskmanager.go @@ -40,7 +40,7 @@ type ( // // If no options are given, it will use the defaultClientOptions() // ctx may contain a NewRelic txn (or one will be created) -func NewClient(_ context.Context, opts ...ClientOps) (TaskManagerInterface, error) { +func NewClient(_ context.Context, opts ...ClientOps) (Tasker, error) { // Create a new client with defaults client := &Client{options: defaultClientOptions()} diff --git a/taskmanager/taskq.go b/taskmanager/taskq.go index 5c6230df..4fdbd8e7 100644 --- a/taskmanager/taskq.go +++ b/taskmanager/taskq.go @@ -16,54 +16,55 @@ import ( var mutex sync.Mutex -type SimplifiedRedisOptions struct { - Addr string -} - -// redisClientIfSet creates a redis client if the options are set otherwise nil which will use the memory queue -func (s *SimplifiedRedisOptions) redisClientIfSet() taskq.Redis { - if s == nil { - return nil +// TasqOps allow functional options to be supplied +type TasqOps func(options *taskq.QueueOptions) + +// WithRedis will set the redis client for the TaskQ engine +func WithRedis(addr string) TasqOps { + return func(options *taskq.QueueOptions) { + options.Redis = redis.NewClient(&redis.Options{ + Addr: strings.Replace(addr, "redis://", "", -1), + }) } - return redis.NewClient(&redis.Options{ - Addr: strings.Replace(s.Addr, "redis://", "", -1), - }) } // DefaultTaskQConfig will return a default configuration that can be modified // If redisOptions is nil, it will use the memory queue -func DefaultTaskQConfig(name string, redisOptions *SimplifiedRedisOptions) *taskq.QueueOptions { - return &taskq.QueueOptions{ - BufferSize: 10, // Size of the buffer where reserved messages are stored. - ConsumerIdleTimeout: 6 * time.Hour, // ConsumerIdleTimeout Time after which the consumer need to be deleted. - Handler: nil, // Optional message handler. The default is the global Tasks registry. - MaxNumFetcher: 0, // Maximum number of goroutines fetching messages. - MaxNumWorker: 10, // Maximum number of goroutines processing messages. - MinNumWorker: 1, // Minimum number of goroutines processing messages. - Name: name, // Queue name. - PauseErrorsThreshold: 100, // Number of consecutive failures after which queue processing is paused. - RateLimit: redis_rate.Limit{}, // Processing rate limit. - RateLimiter: nil, // Optional rate limiter. The default is to use Redis. - Redis: redisOptions.redisClientIfSet(), // Redis client that is used for storing metadata. - ReservationSize: 10, // Number of messages reserved by a fetcher in the queue in one request. - ReservationTimeout: 60 * time.Second, // Time after which the reserved message is returned to the queue. - Storage: taskq.NewLocalStorage(), // Optional storage interface. The default is to use Redis. - WaitTimeout: 3 * time.Second, // Time that a long polling receive call waits for a message to become available before returning an empty response. - WorkerLimit: 0, // Global limit of concurrently running workers across all servers. Overrides MaxNumWorker. +func DefaultTaskQConfig(name string, opts ...TasqOps) *taskq.QueueOptions { + queueOptions := &taskq.QueueOptions{ + BufferSize: 10, // Size of the buffer where reserved messages are stored. + ConsumerIdleTimeout: 6 * time.Hour, // ConsumerIdleTimeout Time after which the consumer need to be deleted. + Handler: nil, // Optional message handler. The default is the global Tasks registry. + MaxNumFetcher: 0, // Maximum number of goroutines fetching messages. + MaxNumWorker: 10, // Maximum number of goroutines processing messages. + MinNumWorker: 1, // Minimum number of goroutines processing messages. + Name: name, // Queue name. + PauseErrorsThreshold: 100, // Number of consecutive failures after which queue processing is paused. + RateLimit: redis_rate.Limit{}, // Processing rate limit. + RateLimiter: nil, // Optional rate limiter. The default is to use Redis. + Redis: nil, // Redis client that is used for storing metadata. + ReservationSize: 10, // Number of messages reserved by a fetcher in the queue in one request. + ReservationTimeout: 60 * time.Second, // Time after which the reserved message is returned to the queue. + Storage: taskq.NewLocalStorage(), // Optional storage interface. The default is to use Redis. + WaitTimeout: 3 * time.Second, // Time that a long polling receive call waits for a message to become available before returning an empty response. + WorkerLimit: 0, // Global limit of concurrently running workers across all servers. Overrides MaxNumWorker. } -} -// convertTaskToTaskQ will convert our internal task to a TaskQ struct -func convertTaskToTaskQ(task *Task) *taskq.TaskOptions { - return &taskq.TaskOptions{ - Name: task.Name, - Handler: task.Handler, - FallbackHandler: task.FallbackHandler, - DeferFunc: task.DeferFunc, - RetryLimit: task.RetryLimit, - MinBackoff: task.MinBackoff, - MaxBackoff: task.MaxBackoff, + // Overwrite defaults with any set by user + for _, opt := range opts { + opt(queueOptions) } + + return queueOptions +} + +// TaskRunOptions are the options for running a task +type TaskRunOptions struct { + Arguments []interface{} `json:"arguments"` // Arguments for the task + Delay time.Duration `json:"delay"` // Run after X delay + OnceInPeriod time.Duration `json:"once_in_period"` // Run once in X period + RunEveryPeriod time.Duration `json:"run_every_period"` // Cron job! + TaskName string `json:"task_name"` // Name of the task } // loadTaskQ will load TaskQ based on the Factory Type and configuration set by the client loading @@ -71,7 +72,7 @@ func (c *Client) loadTaskQ() error { // Check for a valid config (set on client creation) factoryType := c.Factory() if factoryType == FactoryEmpty { - return ErrMissingFactory + return fmt.Errorf("missing factory type to load taskq") } var factory taskq.Factory @@ -92,7 +93,7 @@ func (c *Client) loadTaskQ() error { } // RegisterTask will register a new task using the TaskQ engine -func (c *Client) RegisterTask(task *Task) (err error) { +func (c *Client) RegisterTask(name string, handler interface{}) (err error) { defer func() { if panicErr := recover(); panicErr != nil { err = fmt.Errorf(fmt.Sprintf("registering task panic: %v", panicErr)) @@ -102,21 +103,25 @@ func (c *Client) RegisterTask(task *Task) (err error) { mutex.Lock() defer mutex.Unlock() - if t := taskq.Tasks.Get(task.Name); t != nil { + if t := taskq.Tasks.Get(name); t != nil { // if already registered - register the task locally - c.options.taskq.tasks[task.Name] = t + c.options.taskq.tasks[name] = t } else { // Register and store the task - c.options.taskq.tasks[task.Name] = taskq.RegisterTask(convertTaskToTaskQ(task)) + c.options.taskq.tasks[name] = taskq.RegisterTask(&taskq.TaskOptions{ + Name: name, + Handler: handler, + RetryLimit: 1, + }) } // Debugging - c.DebugLog(fmt.Sprintf("registering task: %s...", c.options.taskq.tasks[task.Name].Name())) + c.DebugLog(fmt.Sprintf("registering task: %s...", c.options.taskq.tasks[name].Name())) return nil } // RunTask will run a task using TaskQ -func (c *Client) RunTask(ctx context.Context, options *TaskOptions) error { +func (c *Client) RunTask(ctx context.Context, options *TaskRunOptions) error { // Starting the execution of the task c.DebugLog(fmt.Sprintf( "executing task: %s... delay: %s arguments: %s", @@ -127,7 +132,7 @@ func (c *Client) RunTask(ctx context.Context, options *TaskOptions) error { // Try to get the task if _, ok := c.options.taskq.tasks[options.TaskName]; !ok { - return ErrTaskNotFound + return fmt.Errorf("task %s not registered", options.TaskName) } // Add arguments, and delay if set diff --git a/taskmanager/taskq_test.go b/taskmanager/taskq_test.go index ae985f7c..bcd71bdc 100644 --- a/taskmanager/taskq_test.go +++ b/taskmanager/taskq_test.go @@ -19,7 +19,7 @@ const ( func TestNewClient(t *testing.T) { c, err := NewClient( context.Background(), - WithTaskqConfig(DefaultTaskQConfig(testQueueName, nil)), + WithTaskqConfig(DefaultTaskQConfig(testQueueName)), ) require.NoError(t, err) require.NotNil(t, c) @@ -29,21 +29,15 @@ func TestNewClient(t *testing.T) { ctx := c.GetTxnCtx(context.Background()) - err = c.RegisterTask(&Task{ - Name: "task-1", - Handler: func(name string) error { - fmt.Println("TSK1 ran: " + name) - return nil - }, + err = c.RegisterTask("task-1", func(name string) error { + fmt.Println("TSK1 ran: " + name) + return nil }) require.NoError(t, err) - err = c.RegisterTask(&Task{ - Name: "task-2", - Handler: func(name string) error { - fmt.Println("TSK2 ran: " + name) - return nil - }, + err = c.RegisterTask("task-2", func(name string) error { + fmt.Println("TSK2 ran: " + name) + return nil }) require.NoError(t, err) @@ -52,19 +46,19 @@ func TestNewClient(t *testing.T) { time.Sleep(2 * time.Second) // Run tasks - err = c.RunTask(ctx, &TaskOptions{ + err = c.RunTask(ctx, &TaskRunOptions{ Arguments: []interface{}{"task #1"}, TaskName: "task-1", }) require.NoError(t, err) - err = c.RunTask(ctx, &TaskOptions{ + err = c.RunTask(ctx, &TaskRunOptions{ Arguments: []interface{}{"task #2"}, TaskName: "task-2", }) require.NoError(t, err) - err = c.RunTask(ctx, &TaskOptions{ + err = c.RunTask(ctx, &TaskRunOptions{ Arguments: []interface{}{"task #2 with delay"}, Delay: time.Second, TaskName: "task-2",