From e61e8fe736d819b8f9b2969484967f1e57a6d6f2 Mon Sep 17 00:00:00 2001 From: wxing1292 Date: Mon, 20 Dec 2021 01:52:30 -0800 Subject: [PATCH] Update DB task manager (#2310) * Add necessary functionality to db task manager & UT --- service/matching/db_task_manager.go | 230 +++++++--- service/matching/db_task_manager_test.go | 392 ++++++++++++++++++ service/matching/db_task_queue_ownership.go | 1 + .../matching/db_task_queue_ownership_mock.go | 14 + service/matching/db_task_reader.go | 22 +- service/matching/db_task_reader_mock.go | 99 +++++ service/matching/db_task_reader_test.go | 2 +- service/matching/db_task_writer.go | 30 +- service/matching/db_task_writer_mock.go | 100 +++++ service/matching/db_task_writer_test.go | 8 +- 10 files changed, 809 insertions(+), 89 deletions(-) create mode 100644 service/matching/db_task_manager_test.go create mode 100644 service/matching/db_task_reader_mock.go create mode 100644 service/matching/db_task_writer_mock.go diff --git a/service/matching/db_task_manager.go b/service/matching/db_task_manager.go index fa468dab671..0bd322f3cb3 100644 --- a/service/matching/db_task_manager.go +++ b/service/matching/db_task_manager.go @@ -30,6 +30,7 @@ import ( "time" enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" enumsspb "go.temporal.io/server/api/enums/v1" persistencespb "go.temporal.io/server/api/persistence/v1" @@ -53,21 +54,34 @@ const ( dbTaskUpdateQueueInterval = time.Minute ) +var ( + errDBTaskManagerNotReady = serviceerror.NewUnavailable("dbTaskManager is not ready") +) + type ( + taskQueueOwnershipProviderFn func() dbTaskQueueOwnership + taskReaderProviderFn func(ownership dbTaskQueueOwnership) dbTaskReader + taskWriterProviderFn func(ownership dbTaskQueueOwnership) dbTaskWriter + dbTaskManager struct { - status int32 - taskQueueKey persistence.TaskQueueKey - store persistence.TaskManager - taskQueueOwnership *dbTaskQueueOwnershipImpl - taskReader *dbTaskWriter - taskWriter *dbTaskReader - - dispatchTaskFn 
func(context.Context, *internalTask) error - finishTaskFn func(*persistencespb.AllocatedTaskInfo, error) - logger log.Logger - - shutdownChan chan struct{} - dispatchChan chan struct{} + status int32 + taskQueueKey persistence.TaskQueueKey + taskQueueKind enumspb.TaskQueueKind + taskIDRangeSize int64 + taskQueueOwnershipProvider taskQueueOwnershipProviderFn + taskReaderProvider taskReaderProviderFn + taskWriterProvider taskWriterProviderFn + dispatchTaskFn func(context.Context, *internalTask) error + store persistence.TaskManager + logger log.Logger + + dispatchChan chan struct{} + startupChan chan struct{} + shutdownChan chan struct{} + + taskQueueOwnership dbTaskQueueOwnership + taskReader dbTaskReader + taskWriter dbTaskWriter maxDeletedTaskIDInclusive int64 // in mem only } ) @@ -76,46 +90,52 @@ func newDBTaskManager( taskQueueKey persistence.TaskQueueKey, taskQueueKind enumspb.TaskQueueKind, taskIDRangeSize int64, + dispatchTaskFn func(context.Context, *internalTask) error, store persistence.TaskManager, logger log.Logger, - dispatchTaskFn func(context.Context, *internalTask) error, - finishTaskFn func(*persistencespb.AllocatedTaskInfo, error), -) (*dbTaskManager, error) { - taskOwnership := newDBTaskQueueOwnership( - taskQueueKey, - taskQueueKind, - taskIDRangeSize, - store, - logger, - ) - if err := taskOwnership.takeTaskQueueOwnership(); err != nil { - return nil, err - } - +) *dbTaskManager { return &dbTaskManager{ - status: common.DaemonStatusInitialized, - taskQueueKey: taskQueueKey, - store: store, - taskQueueOwnership: taskOwnership, - taskReader: newDBTaskWriter( - taskQueueKey, - taskOwnership, - logger, - ), - taskWriter: newDBTaskReader( - taskQueueKey, - store, - taskOwnership.getAckedTaskID(), - logger, - ), + status: common.DaemonStatusInitialized, + taskQueueKey: taskQueueKey, + taskQueueKind: taskQueueKind, + taskIDRangeSize: taskIDRangeSize, + taskQueueOwnershipProvider: func() dbTaskQueueOwnership { + return newDBTaskQueueOwnership( + 
taskQueueKey, + taskQueueKind, + taskIDRangeSize, + store, + logger, + ) + }, + taskReaderProvider: func(taskQueueOwnership dbTaskQueueOwnership) dbTaskReader { + return newDBTaskReader( + taskQueueKey, + store, + taskQueueOwnership.getAckedTaskID(), + logger, + ) + }, + taskWriterProvider: func(taskQueueOwnership dbTaskQueueOwnership) dbTaskWriter { + return newDBTaskWriter( + taskQueueKey, + taskQueueOwnership, + logger, + ) + }, dispatchTaskFn: dispatchTaskFn, - finishTaskFn: finishTaskFn, + store: store, logger: logger, - shutdownChan: make(chan struct{}), - dispatchChan: make(chan struct{}, 1), - maxDeletedTaskIDInclusive: taskOwnership.getAckedTaskID(), - }, nil + dispatchChan: make(chan struct{}, 1), + startupChan: make(chan struct{}), + shutdownChan: make(chan struct{}), + + taskQueueOwnership: nil, + taskWriter: nil, + taskReader: nil, + maxDeletedTaskIDInclusive: 0, + } } func (d *dbTaskManager) Start() { @@ -127,9 +147,10 @@ func (d *dbTaskManager) Start() { return } - d.SignalDispatch() - go d.readerEventLoop() + d.signalDispatch() + go d.acquireLoop() go d.writerEventLoop() + go d.readerEventLoop() } func (d *dbTaskManager) Stop() { @@ -144,18 +165,30 @@ func (d *dbTaskManager) Stop() { close(d.shutdownChan) } -func (d *dbTaskManager) SignalDispatch() { - select { - case d.dispatchChan <- struct{}{}: - default: // channel already has an event, don't block - } -} - func (d *dbTaskManager) isStopped() bool { return atomic.LoadInt32(&d.status) == common.DaemonStatusStopped } +func (d *dbTaskManager) acquireLoop() { + defer close(d.startupChan) + +AcquireLoop: + for !d.isStopped() { + err := d.acquireOwnership() + if err == nil { + break AcquireLoop + } + if !common.IsPersistenceTransientError(err) { + d.Stop() + break AcquireLoop + } + time.Sleep(2 * time.Second) + } +} + func (d *dbTaskManager) writerEventLoop() { + <-d.startupChan + updateQueueTicker := time.NewTicker(dbTaskUpdateQueueInterval) defer updateQueueTicker.Stop() // TODO we should impl a 
more efficient method to @@ -178,16 +211,18 @@ func (d *dbTaskManager) writerEventLoop() { case <-updateQueueTicker.C: d.persistTaskQueue() case <-flushTicker.C: - d.taskReader.flushTasks() - d.SignalDispatch() - case <-d.taskReader.notifyFlushChan(): - d.taskReader.flushTasks() - d.SignalDispatch() + d.taskWriter.flushTasks() + d.signalDispatch() + case <-d.taskWriter.notifyFlushChan(): + d.taskWriter.flushTasks() + d.signalDispatch() } } } func (d *dbTaskManager) readerEventLoop() { + <-d.startupChan + updateAckTicker := time.NewTicker(dbTaskUpdateAckInterval) defer updateAckTicker.Stop() @@ -215,19 +250,46 @@ func (d *dbTaskManager) readerEventLoop() { } } -func (d *dbTaskManager) bufferAndWriteTask( +func (d *dbTaskManager) acquireOwnership() error { + taskQueueOwnership := d.taskQueueOwnershipProvider() + if err := taskQueueOwnership.takeTaskQueueOwnership(); err != nil { + return err + } + d.taskReader = d.taskReaderProvider(taskQueueOwnership) + d.taskWriter = d.taskWriterProvider(taskQueueOwnership) + d.maxDeletedTaskIDInclusive = taskQueueOwnership.getAckedTaskID() + d.taskQueueOwnership = taskQueueOwnership + return nil +} + +func (d *dbTaskManager) signalDispatch() { + select { + case d.dispatchChan <- struct{}{}: + default: // channel already has an event, don't block + } +} + +func (d *dbTaskManager) BufferAndWriteTask( task *persistencespb.TaskInfo, ) future.Future { - return d.taskReader.appendTask(task) + select { + case <-d.startupChan: + if d.isStopped() { + return future.NewReadyFuture(nil, errDBTaskManagerNotReady) + } + return d.taskWriter.appendTask(task) + default: + return future.NewReadyFuture(nil, errDBTaskManagerNotReady) + } } func (d *dbTaskManager) readAndDispatchTasks() { - iter := d.taskWriter.taskIterator(d.taskQueueOwnership.getLastAllocatedTaskID()) + iter := d.taskReader.taskIterator(d.taskQueueOwnership.getLastAllocatedTaskID()) for iter.HasNext() { item, err := iter.Next() if err != nil { d.logger.Error("dbTaskManager 
encountered error when fetching tasks", tag.Error(err)) - d.SignalDispatch() + d.signalDispatch() return } @@ -241,13 +303,13 @@ func (d *dbTaskManager) mustDispatch( ) { for !d.isStopped() { if taskqueue.IsTaskExpired(task) { - d.taskWriter.ackTask(task.TaskId) + d.taskReader.ackTask(task.TaskId) return } err := d.dispatchTaskFn(context.Background(), newInternalTask( task, - d.finishTaskFn, + d.finishTask, enumsspb.TASK_SOURCE_DB_BACKLOG, "", false, @@ -260,7 +322,7 @@ func (d *dbTaskManager) mustDispatch( } func (d *dbTaskManager) updateAckTaskID() { - ackedTaskID := d.taskWriter.moveAckedTaskID() + ackedTaskID := d.taskReader.moveAckedTaskID() d.taskQueueOwnership.updateAckedTaskID(ackedTaskID) } @@ -289,3 +351,35 @@ func (d *dbTaskManager) persistTaskQueue() { d.logger.Error("dbTaskManager encountered unknown error", tag.Error(err)) } } + +func (d *dbTaskManager) finishTask( + info *persistencespb.AllocatedTaskInfo, + err error, +) { + if err == nil { + d.taskReader.ackTask(info.TaskId) + return + } + + // TODO @wxing1292 logic below is subject to discussion + // NOTE: logic below is legacy logic, which will move task with error + // to the end of the queue for later retry + // + // failed to start the task. + // We cannot just remove it from persistence because then it will be lost. + // We handle this by writing the task back to persistence with a higher taskID. + // This will allow subsequent tasks to make progress, and hopefully by the time this task is picked-up + // again the underlying reason for failing to start will be resolved. + // Note that RecordTaskStarted only fails after retrying for a long time, so a single task will not be + // re-written to persistence frequently. 
+ _, err = d.BufferAndWriteTask(info.Data).Get(context.Background()) + if err != nil { + d.logger.Error("dbTaskManager encountered error when moving task to end of task queue", + tag.Error(err), + tag.WorkflowTaskQueueName(d.taskQueueKey.TaskQueueName), + tag.WorkflowTaskQueueType(d.taskQueueKey.TaskQueueType)) + d.Stop() + return + } + d.taskReader.ackTask(info.TaskId) +} diff --git a/service/matching/db_task_manager_test.go b/service/matching/db_task_manager_test.go new file mode 100644 index 00000000000..476a7947914 --- /dev/null +++ b/service/matching/db_task_manager_test.go @@ -0,0 +1,392 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package matching + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" + + persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/collection" + "go.temporal.io/server/common/future" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/primitives/timestamp" +) + +type ( + dbTaskManagerSuite struct { + *require.Assertions + suite.Suite + + controller *gomock.Controller + taskQueueOwnership *MockdbTaskQueueOwnership + taskWriter *MockdbTaskWriter + taskReader *MockdbTaskReader + store *persistence.MockTaskManager + ackedTaskID int64 + lastAllocatedTaskID int64 + dispatchTaskFn func(context.Context, *internalTask) error + + namespaceID string + taskQueueName string + taskQueueType enumspb.TaskQueueType + taskQueueKind enumspb.TaskQueueKind + taskIDRangeSize int64 + + dbTaskManager *dbTaskManager + } +) + +func TestDBTaskManagerSuite(t *testing.T) { + s := new(dbTaskManagerSuite) + suite.Run(t, s) +} + +func (s *dbTaskManagerSuite) SetupSuite() { + rand.Seed(time.Now().UnixNano()) +} + +func (s *dbTaskManagerSuite) TearDownSuite() { + +} + +func (s *dbTaskManagerSuite) SetupTest() { + s.Assertions = require.New(s.T()) + + logger := log.NewTestLogger() + s.controller = gomock.NewController(s.T()) + s.taskQueueOwnership = NewMockdbTaskQueueOwnership(s.controller) + s.taskWriter = NewMockdbTaskWriter(s.controller) + s.taskReader = NewMockdbTaskReader(s.controller) + s.store = persistence.NewMockTaskManager(s.controller) + s.ackedTaskID = rand.Int63() + s.lastAllocatedTaskID = s.ackedTaskID + 100 + s.dispatchTaskFn = func(context.Context, *internalTask) error { + panic("unexpected call to dispatch function") + } + + s.namespaceID = uuid.New().String() + 
s.taskQueueName = uuid.New().String() + s.taskQueueType = enumspb.TASK_QUEUE_TYPE_WORKFLOW + s.taskQueueKind = enumspb.TASK_QUEUE_KIND_STICKY + s.taskIDRangeSize = rand.Int63() + + s.dbTaskManager = newDBTaskManager( + persistence.TaskQueueKey{ + NamespaceID: s.namespaceID, + TaskQueueName: s.taskQueueName, + TaskQueueType: s.taskQueueType, + }, + s.taskQueueKind, + s.taskIDRangeSize, + s.dispatchTaskFn, + s.store, + logger, + ) + s.dbTaskManager.taskQueueOwnershipProvider = func() dbTaskQueueOwnership { + return s.taskQueueOwnership + } + s.dbTaskManager.taskReaderProvider = func(_ dbTaskQueueOwnership) dbTaskReader { + return s.taskReader + } + s.dbTaskManager.taskWriterProvider = func(_ dbTaskQueueOwnership) dbTaskWriter { + return s.taskWriter + } +} + +func (s *dbTaskManagerSuite) TearDownTest() { + s.dbTaskManager.Stop() + s.controller.Finish() +} + +func (s *dbTaskManagerSuite) TestAcquireOwnership_Success() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID).AnyTimes() + + err := s.dbTaskManager.acquireOwnership() + s.NoError(err) + s.NotNil(s.dbTaskManager.taskWriter) + s.NotNil(s.dbTaskManager.taskReader) + s.Equal(s.ackedTaskID, s.dbTaskManager.maxDeletedTaskIDInclusive) +} + +func (s *dbTaskManagerSuite) TestAcquireOwnership_Failed() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(serviceerror.NewUnavailable("some random error")) + + err := s.dbTaskManager.acquireOwnership() + s.Error(err) + s.Nil(s.dbTaskManager.taskWriter) + s.Nil(s.dbTaskManager.taskReader) + s.Equal(int64(0), s.dbTaskManager.maxDeletedTaskIDInclusive) +} + +func (s *dbTaskManagerSuite) TestStart_Success() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID).AnyTimes() + s.taskQueueOwnership.EXPECT().getLastAllocatedTaskID().Return(s.lastAllocatedTaskID).AnyTimes() + 
s.taskQueueOwnership.EXPECT().getShutdownChan().Return(nil).AnyTimes() + s.taskReader.EXPECT().taskIterator(s.lastAllocatedTaskID).Return(collection.NewPagingIterator( + func(paginationToken []byte) ([]interface{}, []byte, error) { + return nil, nil, nil + }, + )).AnyTimes() + s.taskWriter.EXPECT().notifyFlushChan().Return(nil).AnyTimes() + + s.dbTaskManager.Start() + <-s.dbTaskManager.startupChan + s.False(s.dbTaskManager.isStopped()) +} + +func (s *dbTaskManagerSuite) TestStart_ErrorThenSuccess() { + gomock.InOrder( + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(serviceerror.NewUnavailable("some random error")), + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil), + ) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID).AnyTimes() + s.taskQueueOwnership.EXPECT().getLastAllocatedTaskID().Return(s.lastAllocatedTaskID).AnyTimes() + s.taskQueueOwnership.EXPECT().getShutdownChan().Return(nil).AnyTimes() + s.taskReader.EXPECT().taskIterator(s.lastAllocatedTaskID).Return(collection.NewPagingIterator( + func(paginationToken []byte) ([]interface{}, []byte, error) { + return nil, nil, nil + }, + )).AnyTimes() + s.taskWriter.EXPECT().notifyFlushChan().Return(nil).AnyTimes() + + s.dbTaskManager.Start() + <-s.dbTaskManager.startupChan + s.False(s.dbTaskManager.isStopped()) +} + +func (s *dbTaskManagerSuite) TestStart_Error() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(&persistence.ConditionFailedError{}) + + s.dbTaskManager.Start() + <-s.dbTaskManager.startupChan + s.True(s.dbTaskManager.isStopped()) +} + +func (s *dbTaskManagerSuite) TestBufferAndWriteTask_NotReady() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(serviceerror.NewUnavailable("some random error")).AnyTimes() + s.dbTaskManager.Start() + + taskInfo := &persistencespb.TaskInfo{} + fut := s.dbTaskManager.BufferAndWriteTask(taskInfo) + _, err := fut.Get(context.Background()) + s.Equal(errDBTaskManagerNotReady, err) +} + 
+func (s *dbTaskManagerSuite) TestBufferAndWriteTask_Ready() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID).AnyTimes() + s.taskQueueOwnership.EXPECT().getLastAllocatedTaskID().Return(s.lastAllocatedTaskID).AnyTimes() + s.taskQueueOwnership.EXPECT().getShutdownChan().Return(nil).AnyTimes() + s.taskReader.EXPECT().taskIterator(s.lastAllocatedTaskID).Return(collection.NewPagingIterator( + func(paginationToken []byte) ([]interface{}, []byte, error) { + return nil, nil, nil + }, + )).AnyTimes() + s.taskWriter.EXPECT().notifyFlushChan().Return(nil).AnyTimes() + s.dbTaskManager.Start() + <-s.dbTaskManager.startupChan + + taskInfo := &persistencespb.TaskInfo{} + taskWriterErr := serviceerror.NewInternal("random error") + s.taskWriter.EXPECT().appendTask(taskInfo).Return( + future.NewReadyFuture(nil, taskWriterErr), + ) + fut := s.dbTaskManager.BufferAndWriteTask(taskInfo) + _, err := fut.Get(context.Background()) + s.Equal(taskWriterErr, err) +} + +func (s *dbTaskManagerSuite) TestReadAndDispatchTasks_ReadSuccess_Expired() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID) + err := s.dbTaskManager.acquireOwnership() + s.NoError(err) + + // make sure no signal exists in dispatch chan + select { + case <-s.dbTaskManager.dispatchChan: + default: + } + + allocatedTaskInfo := &persistencespb.AllocatedTaskInfo{ + TaskId: s.lastAllocatedTaskID + 100, + Data: &persistencespb.TaskInfo{ + NamespaceId: uuid.New().String(), + WorkflowId: uuid.New().String(), + RunId: uuid.New().String(), + ScheduleId: rand.Int63(), + CreateTime: timestamp.TimePtr(time.Now().UTC()), + ExpiryTime: timestamp.TimePtr(time.Now().UTC().Add(-time.Minute)), + }, + } + s.taskQueueOwnership.EXPECT().getLastAllocatedTaskID().Return(s.lastAllocatedTaskID) + 
s.taskReader.EXPECT().taskIterator(s.lastAllocatedTaskID).Return(collection.NewPagingIterator( + func(paginationToken []byte) ([]interface{}, []byte, error) { + return []interface{}{allocatedTaskInfo}, nil, nil + }, + )) + s.taskReader.EXPECT().ackTask(allocatedTaskInfo.TaskId) + + s.dbTaskManager.readAndDispatchTasks() +} + +func (s *dbTaskManagerSuite) TestReadAndDispatchTasks_ReadSuccess_Dispatch() { + var dispatchedTasks []*persistencespb.AllocatedTaskInfo + s.dbTaskManager.dispatchTaskFn = func(_ context.Context, task *internalTask) error { + dispatchedTasks = append(dispatchedTasks, task.event.AllocatedTaskInfo) + return nil + } + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID) + err := s.dbTaskManager.acquireOwnership() + s.NoError(err) + + // make sure no signal exists in dispatch chan + select { + case <-s.dbTaskManager.dispatchChan: + default: + } + + allocatedTaskInfo := &persistencespb.AllocatedTaskInfo{ + TaskId: s.lastAllocatedTaskID + 100, + Data: &persistencespb.TaskInfo{ + NamespaceId: uuid.New().String(), + WorkflowId: uuid.New().String(), + RunId: uuid.New().String(), + ScheduleId: rand.Int63(), + CreateTime: timestamp.TimePtr(time.Now().UTC()), + ExpiryTime: timestamp.TimePtr(time.Unix(0, 0)), + }, + } + s.taskQueueOwnership.EXPECT().getLastAllocatedTaskID().Return(s.lastAllocatedTaskID) + s.taskReader.EXPECT().taskIterator(s.lastAllocatedTaskID).Return(collection.NewPagingIterator( + func(paginationToken []byte) ([]interface{}, []byte, error) { + return []interface{}{allocatedTaskInfo}, nil, nil + }, + )) + + s.dbTaskManager.readAndDispatchTasks() + s.Equal([]*persistencespb.AllocatedTaskInfo{allocatedTaskInfo}, dispatchedTasks) +} + +func (s *dbTaskManagerSuite) TestReadAndDispatchTasks_ReadFailure() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID) + err := 
s.dbTaskManager.acquireOwnership() + s.NoError(err) + + // make sure no signal exists in dispatch chan + select { + case <-s.dbTaskManager.dispatchChan: + default: + } + + s.taskQueueOwnership.EXPECT().getLastAllocatedTaskID().Return(s.lastAllocatedTaskID) + s.taskReader.EXPECT().taskIterator(s.lastAllocatedTaskID).Return(collection.NewPagingIterator( + func(paginationToken []byte) ([]interface{}, []byte, error) { + return nil, nil, serviceerror.NewUnavailable("random error") + }, + )) + + s.dbTaskManager.readAndDispatchTasks() + select { + case <-s.dbTaskManager.dispatchChan: + // noop + default: + s.Fail("dispatch channel should contain one signal") + } +} + +func (s *dbTaskManagerSuite) TestUpdateAckTaskID() { + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID) + err := s.dbTaskManager.acquireOwnership() + s.NoError(err) + + ackedTaskID := rand.Int63() + s.taskReader.EXPECT().moveAckedTaskID().Return(ackedTaskID) + s.taskQueueOwnership.EXPECT().updateAckedTaskID(ackedTaskID) + + s.dbTaskManager.updateAckTaskID() +} + +func (s *dbTaskManagerSuite) TestDeleteAckedTasks_Success() { + maxDeletedTaskIDInclusive := s.ackedTaskID - 100 + s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID).AnyTimes() + err := s.dbTaskManager.acquireOwnership() + s.NoError(err) + s.dbTaskManager.maxDeletedTaskIDInclusive = maxDeletedTaskIDInclusive + + s.store.EXPECT().CompleteTasksLessThan(&persistence.CompleteTasksLessThanRequest{ + NamespaceID: s.namespaceID, + TaskQueueName: s.taskQueueName, + TaskType: s.taskQueueType, + TaskID: s.ackedTaskID, + Limit: 100000, + }).Return(0, nil) + + s.dbTaskManager.deleteAckedTasks() + s.Equal(s.ackedTaskID, s.dbTaskManager.maxDeletedTaskIDInclusive) +} + +func (s *dbTaskManagerSuite) TestDeleteAckedTasks_Failed() { + maxDeletedTaskIDInclusive := s.ackedTaskID - 100 + 
s.taskQueueOwnership.EXPECT().takeTaskQueueOwnership().Return(nil) + s.taskQueueOwnership.EXPECT().getAckedTaskID().Return(s.ackedTaskID).AnyTimes() + err := s.dbTaskManager.acquireOwnership() + s.NoError(err) + s.dbTaskManager.maxDeletedTaskIDInclusive = maxDeletedTaskIDInclusive + + s.store.EXPECT().CompleteTasksLessThan(&persistence.CompleteTasksLessThanRequest{ + NamespaceID: s.namespaceID, + TaskQueueName: s.taskQueueName, + TaskType: s.taskQueueType, + TaskID: s.ackedTaskID, + Limit: 100000, + }).Return(0, serviceerror.NewUnavailable("random error")) + + s.dbTaskManager.deleteAckedTasks() + s.Equal(maxDeletedTaskIDInclusive, s.dbTaskManager.maxDeletedTaskIDInclusive) +} + +// TODO @wxing1292 add necessary tests +// once there is consensus about whether to keep the `task move to end` behavior +func (s *dbTaskManagerSuite) TestFinishTask_Success() {} + +func (s *dbTaskManagerSuite) TestFinishTask_Error() {} diff --git a/service/matching/db_task_queue_ownership.go b/service/matching/db_task_queue_ownership.go index 4c230bd0dcc..d1e0f93ac13 100644 --- a/service/matching/db_task_queue_ownership.go +++ b/service/matching/db_task_queue_ownership.go @@ -51,6 +51,7 @@ type ( dbTaskQueueOwnershipStatus int dbTaskQueueOwnership interface { + takeTaskQueueOwnership() error getShutdownChan() <-chan struct{} getAckedTaskID() int64 updateAckedTaskID(taskID int64) diff --git a/service/matching/db_task_queue_ownership_mock.go b/service/matching/db_task_queue_ownership_mock.go index 4b1faa1580c..bc4952bdfaa 100644 --- a/service/matching/db_task_queue_ownership_mock.go +++ b/service/matching/db_task_queue_ownership_mock.go @@ -132,6 +132,20 @@ func (mr *MockdbTaskQueueOwnershipMockRecorder) persistTaskQueue() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "persistTaskQueue", reflect.TypeOf((*MockdbTaskQueueOwnership)(nil).persistTaskQueue)) } +// takeTaskQueueOwnership mocks base method. 
+func (m *MockdbTaskQueueOwnership) takeTaskQueueOwnership() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "takeTaskQueueOwnership") + ret0, _ := ret[0].(error) + return ret0 +} + +// takeTaskQueueOwnership indicates an expected call of takeTaskQueueOwnership. +func (mr *MockdbTaskQueueOwnershipMockRecorder) takeTaskQueueOwnership() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "takeTaskQueueOwnership", reflect.TypeOf((*MockdbTaskQueueOwnership)(nil).takeTaskQueueOwnership)) +} + // updateAckedTaskID mocks base method. func (m *MockdbTaskQueueOwnership) updateAckedTaskID(taskID int64) { m.ctrl.T.Helper() diff --git a/service/matching/db_task_reader.go b/service/matching/db_task_reader.go index ee23e1be690..085eb184b09 100644 --- a/service/matching/db_task_reader.go +++ b/service/matching/db_task_reader.go @@ -37,8 +37,16 @@ const ( dbTaskReaderPageSize = 100 ) +//go:generate mockgen -copyright_file ../../LICENSE -package $GOPACKAGE -source $GOFILE -destination db_task_reader_mock.go + type ( - dbTaskReader struct { + dbTaskReader interface { + taskIterator(maxTaskID int64) collection.Iterator + ackTask(taskID int64) + moveAckedTaskID() int64 + } + + dbTaskReaderImpl struct { taskQueueKey persistence.TaskQueueKey store persistence.TaskManager logger log.Logger @@ -55,8 +63,8 @@ func newDBTaskReader( store persistence.TaskManager, ackedTaskID int64, logger log.Logger, -) *dbTaskReader { - return &dbTaskReader{ +) *dbTaskReaderImpl { + return &dbTaskReaderImpl{ taskQueueKey: taskQueueKey, store: store, logger: logger, @@ -67,13 +75,13 @@ func newDBTaskReader( } } -func (t *dbTaskReader) taskIterator( +func (t *dbTaskReaderImpl) taskIterator( maxTaskID int64, ) collection.Iterator { return collection.NewPagingIterator(t.getPaginationFn(maxTaskID)) } -func (t *dbTaskReader) ackTask(taskID int64) { +func (t *dbTaskReaderImpl) ackTask(taskID int64) { t.Lock() defer t.Unlock() @@ -91,7 +99,7 @@ func (t *dbTaskReader) 
ackTask(taskID int64) { // 12 -> true // 15 -> false // the acked task ID can be set to 12, meaning task with ID <= 12 are finished -func (t *dbTaskReader) moveAckedTaskID() int64 { +func (t *dbTaskReaderImpl) moveAckedTaskID() int64 { t.Lock() defer t.Unlock() @@ -111,7 +119,7 @@ func (t *dbTaskReader) moveAckedTaskID() int64 { return t.ackedTaskID } -func (t *dbTaskReader) getPaginationFn( +func (t *dbTaskReaderImpl) getPaginationFn( maxTaskID int64, ) collection.PaginationFn { t.Lock() diff --git a/service/matching/db_task_reader_mock.go b/service/matching/db_task_reader_mock.go new file mode 100644 index 00000000000..b98b51b11a9 --- /dev/null +++ b/service/matching/db_task_reader_mock.go @@ -0,0 +1,99 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Code generated by MockGen. 
DO NOT EDIT. +// Source: db_task_reader.go + +// Package matching is a generated GoMock package. +package matching + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + collection "go.temporal.io/server/common/collection" +) + +// MockdbTaskReader is a mock of dbTaskReader interface. +type MockdbTaskReader struct { + ctrl *gomock.Controller + recorder *MockdbTaskReaderMockRecorder +} + +// MockdbTaskReaderMockRecorder is the mock recorder for MockdbTaskReader. +type MockdbTaskReaderMockRecorder struct { + mock *MockdbTaskReader +} + +// NewMockdbTaskReader creates a new mock instance. +func NewMockdbTaskReader(ctrl *gomock.Controller) *MockdbTaskReader { + mock := &MockdbTaskReader{ctrl: ctrl} + mock.recorder = &MockdbTaskReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockdbTaskReader) EXPECT() *MockdbTaskReaderMockRecorder { + return m.recorder +} + +// ackTask mocks base method. +func (m *MockdbTaskReader) ackTask(taskID int64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ackTask", taskID) +} + +// ackTask indicates an expected call of ackTask. +func (mr *MockdbTaskReaderMockRecorder) ackTask(taskID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ackTask", reflect.TypeOf((*MockdbTaskReader)(nil).ackTask), taskID) +} + +// moveAckedTaskID mocks base method. +func (m *MockdbTaskReader) moveAckedTaskID() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "moveAckedTaskID") + ret0, _ := ret[0].(int64) + return ret0 +} + +// moveAckedTaskID indicates an expected call of moveAckedTaskID. +func (mr *MockdbTaskReaderMockRecorder) moveAckedTaskID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "moveAckedTaskID", reflect.TypeOf((*MockdbTaskReader)(nil).moveAckedTaskID)) +} + +// taskIterator mocks base method. 
+func (m *MockdbTaskReader) taskIterator(maxTaskID int64) collection.Iterator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "taskIterator", maxTaskID) + ret0, _ := ret[0].(collection.Iterator) + return ret0 +} + +// taskIterator indicates an expected call of taskIterator. +func (mr *MockdbTaskReaderMockRecorder) taskIterator(maxTaskID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "taskIterator", reflect.TypeOf((*MockdbTaskReader)(nil).taskIterator), maxTaskID) +} diff --git a/service/matching/db_task_reader_test.go b/service/matching/db_task_reader_test.go index 337634c1ff3..1460967146f 100644 --- a/service/matching/db_task_reader_test.go +++ b/service/matching/db_task_reader_test.go @@ -56,7 +56,7 @@ type ( ackedTaskID int64 maxTaskID int64 - taskTracker *dbTaskReader + taskTracker *dbTaskReaderImpl } ) diff --git a/service/matching/db_task_writer.go b/service/matching/db_task_writer.go index e9c0d4d8415..2385fe62111 100644 --- a/service/matching/db_task_writer.go +++ b/service/matching/db_task_writer.go @@ -38,13 +38,25 @@ const ( dbTaskFlusherBufferSize = dbTaskFlusherBatchSize * 4 ) +var ( + errDBTaskWriterBufferFull = serviceerror.NewUnavailable("dbTaskWriter encountered task buffer full") +) + +//go:generate mockgen -copyright_file ../../LICENSE -package $GOPACKAGE -source $GOFILE -destination db_task_writer_mock.go + type ( + dbTaskWriter interface { + appendTask(task *persistencespb.TaskInfo) future.Future + flushTasks() + notifyFlushChan() <-chan struct{} + } + dbTaskInfo struct { task *persistencespb.TaskInfo future *future.FutureImpl // nil, error } - dbTaskWriter struct { + dbTaskWriterImpl struct { taskQueueKey persistence.TaskQueueKey ownership dbTaskQueueOwnership logger log.Logger @@ -58,8 +70,8 @@ func newDBTaskWriter( taskQueueKey persistence.TaskQueueKey, ownership dbTaskQueueOwnership, logger log.Logger, -) *dbTaskWriter { - return &dbTaskWriter{ +) *dbTaskWriterImpl { + return 
&dbTaskWriterImpl{ taskQueueKey: taskQueueKey, ownership: ownership, logger: logger, @@ -69,7 +81,7 @@ func newDBTaskWriter( } } -func (f *dbTaskWriter) appendTask( +func (f *dbTaskWriterImpl) appendTask( task *persistencespb.TaskInfo, ) future.Future { if len(f.taskBuffer) >= dbTaskFlusherBatchSize { @@ -85,18 +97,18 @@ func (f *dbTaskWriter) appendTask( // noop default: // busy - fut.Set(nil, serviceerror.NewUnavailable("dbTaskWriter encountered task buffer full")) + fut.Set(nil, errDBTaskWriterBufferFull) } return fut } -func (f *dbTaskWriter) flushTasks() { +func (f *dbTaskWriterImpl) flushTasks() { for len(f.taskBuffer) > 0 { f.flushTasksOnce() } } -func (f *dbTaskWriter) flushTasksOnce() { +func (f *dbTaskWriterImpl) flushTasksOnce() { tasks := make([]*persistencespb.TaskInfo, 0, dbTaskFlusherBatchSize) futures := make([]*future.FutureImpl, 0, len(tasks)) @@ -120,7 +132,7 @@ FlushLoop: } } -func (f *dbTaskWriter) notifyFlush() { +func (f *dbTaskWriterImpl) notifyFlush() { select { case f.flushSignalChan <- struct{}{}: default: @@ -128,6 +140,6 @@ func (f *dbTaskWriter) notifyFlush() { } } -func (f *dbTaskWriter) notifyFlushChan() <-chan struct{} { +func (f *dbTaskWriterImpl) notifyFlushChan() <-chan struct{} { return f.flushSignalChan } diff --git a/service/matching/db_task_writer_mock.go b/service/matching/db_task_writer_mock.go new file mode 100644 index 00000000000..2a890ffd4eb --- /dev/null +++ b/service/matching/db_task_writer_mock.go @@ -0,0 +1,100 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: db_task_writer.go + +// Package matching is a generated GoMock package. +package matching + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + persistence "go.temporal.io/server/api/persistence/v1" + future "go.temporal.io/server/common/future" +) + +// MockdbTaskWriter is a mock of dbTaskWriter interface. +type MockdbTaskWriter struct { + ctrl *gomock.Controller + recorder *MockdbTaskWriterMockRecorder +} + +// MockdbTaskWriterMockRecorder is the mock recorder for MockdbTaskWriter. +type MockdbTaskWriterMockRecorder struct { + mock *MockdbTaskWriter +} + +// NewMockdbTaskWriter creates a new mock instance. 
+func NewMockdbTaskWriter(ctrl *gomock.Controller) *MockdbTaskWriter { + mock := &MockdbTaskWriter{ctrl: ctrl} + mock.recorder = &MockdbTaskWriterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockdbTaskWriter) EXPECT() *MockdbTaskWriterMockRecorder { + return m.recorder +} + +// appendTask mocks base method. +func (m *MockdbTaskWriter) appendTask(task *persistence.TaskInfo) future.Future { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "appendTask", task) + ret0, _ := ret[0].(future.Future) + return ret0 +} + +// appendTask indicates an expected call of appendTask. +func (mr *MockdbTaskWriterMockRecorder) appendTask(task interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "appendTask", reflect.TypeOf((*MockdbTaskWriter)(nil).appendTask), task) +} + +// flushTasks mocks base method. +func (m *MockdbTaskWriter) flushTasks() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "flushTasks") +} + +// flushTasks indicates an expected call of flushTasks. +func (mr *MockdbTaskWriterMockRecorder) flushTasks() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "flushTasks", reflect.TypeOf((*MockdbTaskWriter)(nil).flushTasks)) +} + +// notifyFlushChan mocks base method. +func (m *MockdbTaskWriter) notifyFlushChan() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "notifyFlushChan") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// notifyFlushChan indicates an expected call of notifyFlushChan. 
+func (mr *MockdbTaskWriterMockRecorder) notifyFlushChan() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "notifyFlushChan", reflect.TypeOf((*MockdbTaskWriter)(nil).notifyFlushChan)) +} diff --git a/service/matching/db_task_writer_test.go b/service/matching/db_task_writer_test.go index 911b83e64bb..00e9ae25185 100644 --- a/service/matching/db_task_writer_test.go +++ b/service/matching/db_task_writer_test.go @@ -54,9 +54,9 @@ type ( namespaceID string taskQueueName string - taskType enumspb.TaskQueueType + taskQueueType enumspb.TaskQueueType - taskFlusher *dbTaskWriter + taskFlusher *dbTaskWriterImpl } ) @@ -81,13 +81,13 @@ func (s *dbTaskWriterSuite) SetupTest() { s.namespaceID = uuid.New().String() s.taskQueueName = uuid.New().String() - s.taskType = enumspb.TASK_QUEUE_TYPE_ACTIVITY + s.taskQueueType = enumspb.TASK_QUEUE_TYPE_ACTIVITY s.taskFlusher = newDBTaskWriter( persistence.TaskQueueKey{ NamespaceID: s.namespaceID, TaskQueueName: s.taskQueueName, - TaskQueueType: s.taskType, + TaskQueueType: s.taskQueueType, }, s.taskOwnership, log.NewTestLogger(),