diff --git a/pkg/lib/watcher/README.md b/pkg/lib/watcher/README.md index 83069abce4..93c2435c65 100644 --- a/pkg/lib/watcher/README.md +++ b/pkg/lib/watcher/README.md @@ -20,7 +20,7 @@ The Watcher Library is an internal component of the Bacalhau project that provid ## Key Components -1. **Registry**: Manages multiple watchers and provides methods to create and manage watchers. +1. **Manager**: Manages multiple watchers and provides methods to create, lookup, and stop watchers. 2. **Watcher**: Represents a single event watcher that processes events sequentially. 3. **EventStore**: Responsible for storing and retrieving events, with BoltDB as the default implementation. 4. **EventHandler**: Interface for handling individual events. @@ -42,9 +42,9 @@ An `Event` represents a single occurrence in the system. It has the following pr The `EventStore` is responsible for persisting events and providing methods to retrieve them. It uses BoltDB as the underlying storage engine and supports features like caching, checkpointing, and garbage collection. -### Registry +### Manager -The `Registry` manages multiple watchers. It's the main entry point for components that want to subscribe to events. +The `Manager` manages multiple watchers and provides methods to create, lookup, and stop watchers. ### Watcher @@ -70,9 +70,9 @@ db, _ := bbolt.Open("events.db", 0600, nil) store, _ := boltdb.NewEventStore(db) ``` -2. Create a Registry: +2. Create a manager: ```go -registry := watcher.NewRegistry(store) +manager := watcher.NewManager(store) ``` 3. Implement an EventHandler: @@ -86,9 +86,32 @@ func (h *MyHandler) HandleEvent(ctx context.Context, event watcher.Event) error ``` -4. Start watching for events: +4. Create a watcher and set handler: + +There are two main approaches to create and configure a watcher with a handler: + +a. Two-Step Creation (Handler After Creation): +```go +// Create watcher +w, _ := manager.Create(ctx, "my-watcher", + watcher.WithFilter(watcher.EventFilter{ + ObjectTypes: []string{"Job", "Execution"}, + Operations: []watcher.Operation{watcher.OperationCreate, watcher.OperationUpdate}, + }), +) + +// Set handler +err = w.SetHandler(&MyHandler{}) + +// Start watching +err = w.Start(ctx) +``` + +b. One-Step Creation (With Auto-Start): ```go -watcher, _ := registry.Watch(ctx, "my-watcher", &MyHandler{}, +w, _ := manager.Create(ctx, "my-watcher", + watcher.WithHandler(&MyHandler{}), + watcher.WithAutoStart(), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{"Job", "Execution"}, Operations: []watcher.Operation{watcher.OperationCreate, watcher.OperationUpdate}, @@ -109,6 +132,8 @@ store.StoreEvent(ctx, watcher.OperationCreate, "Job", jobData) When creating a watcher, you can configure it with various options: - `WithInitialEventIterator(iterator EventIterator)`: Sets the starting position for watching if no checkpoint is found. +- `WithHandler(handler EventHandler)`: Sets the event handler for the watcher. +- `WithAutoStart()`: Enables automatic start of the watcher after creation. - `WithFilter(filter EventFilter)`: Sets the event filter for watching. - `WithBufferSize(size int)`: Sets the size of the event buffer. - `WithBatchSize(size int)`: Sets the number of events to fetch in each batch. @@ -120,8 +145,10 @@ When creating a watcher, you can configure it with various options: Example: ```go -watcher, err := registry.Watch(ctx, "my-watcher", &MyHandler{}, +w, err := manager.Create(ctx, "my-watcher", watcher.WithInitialEventIterator(watcher.TrimHorizonIterator()), + watcher.WithHandler(&MyHandler{}), + watcher.WithAutoStart(), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{"Job", "Execution"}, Operations: []watcher.Operation{watcher.OperationCreate, watcher.OperationUpdate}, diff --git a/pkg/lib/watcher/errors.go b/pkg/lib/watcher/errors.go index 1efd745f7e..c13bf3a6b7 100644 --- a/pkg/lib/watcher/errors.go +++ b/pkg/lib/watcher/errors.go @@ -9,6 +9,8 @@ var ( ErrWatcherAlreadyExists = errors.New("watcher already exists") ErrWatcherNotFound = errors.New("watcher not found") ErrCheckpointNotFound = errors.New("checkpoint not found") + ErrNoHandler = errors.New("no handler configured") + ErrHandlerExists = errors.New("handler already exists") ) // WatcherError represents an error related to a specific watcher diff --git a/pkg/lib/watcher/manager.go b/pkg/lib/watcher/manager.go new file mode 100644 index 0000000000..82bd7e6f51 --- /dev/null +++ b/pkg/lib/watcher/manager.go @@ -0,0 +1,124 @@ +package watcher + +import ( + "context" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +const ( + DefaultShutdownTimeout = 30 * time.Second +) + +// manager handles lifecycle of multiple watchers with shared resources +type manager struct { + store EventStore + watchers map[string]Watcher + mu sync.RWMutex +} + +// NewManager creates a new Manager with the given EventStore. +// +// Example usage: +// +// store := // initialize your event store +// manager := NewManager(store) +// defer manager.Stop(context.Background()) +// +// watcher, err := manager.Create(context.Background(), "myWatcher") +// if err != nil { +// // handle error +// } +func NewManager(store EventStore) Manager { + return &manager{ + store: store, + watchers: make(map[string]Watcher), + } +} + +// Create creates a new watcher. SetHandler must be called before Start can be called successfully, +// or pass WithHandler option to Create to set the handler at creation time. +func (m *manager) Create(ctx context.Context, watcherID string, opts ...WatchOption) (Watcher, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if a watcher with this ID already exists + if _, exists := m.watchers[watcherID]; exists { + return nil, NewWatcherError(watcherID, ErrWatcherAlreadyExists) + } + + w, err := New(ctx, watcherID, m.store, opts...) + if err != nil { + return nil, err + } + + m.watchers[w.ID()] = w + return w, nil +} + +// Lookup retrieves a specific watcher by ID +func (m *manager) Lookup(_ context.Context, watcherID string) (Watcher, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + w, exists := m.watchers[watcherID] + if !exists { + return nil, NewWatcherError(watcherID, ErrWatcherNotFound) + } + + return w, nil +} + +// Stop gracefully shuts down the manager and all its watchers +func (m *manager) Stop(ctx context.Context) error { + log.Ctx(ctx).Debug().Msg("Shutting down manager") + + // Create a timeout context if the parent context doesn't have a deadline + timeoutCtx := ctx + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + timeoutCtx, cancel = context.WithTimeout(ctx, DefaultShutdownTimeout) + defer cancel() + } + + var wg sync.WaitGroup + + // Take a snapshot of watchers under lock + m.mu.RLock() + watchers := make([]Watcher, 0, len(m.watchers)) + for _, w := range m.watchers { + watchers = append(watchers, w) + } + m.mu.RUnlock() + + // Stop all watchers concurrently + for i := range watchers { + w := watchers[i] + wg.Add(1) + go func(w Watcher) { + defer wg.Done() + w.Stop(timeoutCtx) + }(w) + } + + // Wait for completion or timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + log.Ctx(ctx).Debug().Msg("manager shutdown complete") + return nil + case <-timeoutCtx.Done(): + log.Ctx(ctx).Warn().Msg("manager shutdown timed out") + return timeoutCtx.Err() + } +} + +// compile time check for interface implementation +var _ Manager = &manager{} diff --git a/pkg/lib/watcher/manager_test.go b/pkg/lib/watcher/manager_test.go new file mode 100644 index 0000000000..0312d07b3e --- /dev/null +++ b/pkg/lib/watcher/manager_test.go @@ -0,0 +1,266 @@ +//go:build unit || !integration + +package watcher_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/boltdb" + watchertest "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/test" +) + +type ManagerTestSuite struct { + suite.Suite + ctrl *gomock.Controller + mockStore *watchertest.EventStoreWrapper + mockHandler *watcher.MockEventHandler + manager watcher.Manager +} + +func (s *ManagerTestSuite) SetupTest() { + boltdbEventStore, err := boltdb.NewEventStore(watchertest.CreateBoltDB(s.T())) + s.Require().NoError(err) + + s.ctrl = gomock.NewController(s.T()) + s.mockStore = watchertest.NewEventStoreWrapper(boltdbEventStore) + s.mockHandler = watcher.NewMockEventHandler(s.ctrl) + s.manager = watcher.NewManager(s.mockStore) +} + +func (s *ManagerTestSuite) TearDownTest() { + s.Require().NoError(s.manager.Stop(context.Background()), "failed to stop manager in teardown") + s.ctrl.Finish() +} + +func (s *ManagerTestSuite) TestCreate() { + ctx := context.Background() + watcherID := "test-watcher" + + // Create watcher + w, err := s.manager.Create(ctx, watcherID) + s.Require().NoError(err) + s.Require().NotNil(w) + s.Equal(watcherID, w.ID()) + + // Set handler and start watcher + err = w.SetHandler(s.mockHandler) + s.Require().NoError(err) + s.startAndWait(ctx, w) + + // Stop the manager and ensure the watcher is stopped + err = s.manager.Stop(ctx) + s.Require().NoError(err) + s.Require().Equal(watcher.StateStopped, w.Stats().State) +} + +func (s *ManagerTestSuite) TestCreateDuplicateWatcher() { + ctx := context.Background() + watcherID := "test-watcher" + + _, err := s.manager.Create(ctx, watcherID) + s.Require().NoError(err) + + // Try to create another watcher with same ID + _, err = s.manager.Create(ctx, watcherID) + s.Require().Error(err) + s.Contains(err.Error(), "watcher already exists") +} + +func (s *ManagerTestSuite) TestLookup() { + ctx := context.Background() + watcherID := "test-watcher" + + w, err := s.manager.Create(ctx, watcherID) + s.Require().NoError(err) + + // Lookup existing watcher + retrievedWatcher, err := s.manager.Lookup(context.Background(), watcherID) + s.Require().NoError(err) + s.Require().NotNil(retrievedWatcher) + s.Equal(watcherID, retrievedWatcher.ID()) + s.Equal(w, retrievedWatcher) +} + +func (s *ManagerTestSuite) TestLookupNonExistentWatcher() { + _, err := s.manager.Lookup(context.Background(), "non-existent") + s.Require().Error(err) + s.Contains(err.Error(), "watcher not found") +} + +func (s *ManagerTestSuite) TestStop() { + ctx := context.Background() + + // create a started watcher + w1, err := s.manager.Create(ctx, "watcher-1") + s.Require().NoError(err) + s.Require().NoError(w1.SetHandler(s.mockHandler)) + s.startAndWait(ctx, w1) + + // create a stopped watcher + w2, err := s.manager.Create(ctx, "watcher-2") + s.Require().NoError(err) + s.Require().NoError(w2.SetHandler(s.mockHandler)) + s.startAndWait(ctx, w2) + w2.Stop(ctx) + + // create a non-started watcher + w3, err := s.manager.Create(ctx, "watcher-3") + s.Require().NoError(err) + + err = s.manager.Stop(ctx) + s.Require().NoError(err) + s.Require().Equal(watcher.StateStopped, w1.Stats().State) + s.Require().Equal(watcher.StateStopped, w2.Stats().State) + s.Require().Equal(watcher.StateStopped, w3.Stats().State) +} + +func (s *ManagerTestSuite) TestStopWithTimeout() { + ctx := context.Background() + watcherID := "test-watcher" + + // Create a channel to control GetEvents + getEventsCh := make(chan struct{}) + + //// Set up the mockStore to block on GetEvents + s.mockStore.WithGetEventsInterceptor(func() error { + <-getEventsCh + return nil + }) + + w, err := s.manager.Create(ctx, watcherID) + s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(ctx, w) + + // Create a very short timeout + ctxWithTimeout, cancel := context.WithTimeout(ctx, 1*time.Nanosecond) + defer cancel() + + err = s.manager.Stop(ctxWithTimeout) + s.Require().Error(err) + s.Equal(context.DeadlineExceeded, err) + + // Ensure the watcher is stopping + s.Require().Eventually(func() bool { + return w.Stats().State == watcher.StateStopping + }, 200*time.Millisecond, 10*time.Millisecond) + + // verify that the watcher is still stopping + time.Sleep(100 * time.Millisecond) + s.Require().Equal(watcher.StateStopping, w.Stats().State) + + // Unblock GetEvents + close(getEventsCh) + + // Ensure the watcher is stopped + s.Require().Eventually(func() bool { + return w.Stats().State == watcher.StateStopped + }, 200*time.Millisecond, 10*time.Millisecond) +} + +func (s *ManagerTestSuite) TestWatcherProcessesEvents() { + ctx := context.Background() + watcherID := "test-watcher" + + events := []watcher.StoreEventRequest{ + {Operation: watcher.OperationCreate, ObjectType: "TestObject", Object: "test1"}, + {Operation: watcher.OperationUpdate, ObjectType: "TestObject", Object: "test2"}, + } + + for _, event := range events { + err := s.mockStore.StoreEvent(ctx, event) + s.Require().NoError(err) + } + + s.mockHandler.EXPECT().HandleEvent(gomock.Any(), gomock.Any()).Return(nil).Times(2) + + w, err := s.manager.Create(ctx, watcherID) + s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(ctx, w) + + // Wait for events to be processed + s.Require().Eventually(func() bool { + return w.Stats().LastProcessedSeqNum == 2 + }, 200*time.Millisecond, 10*time.Millisecond) + + err = s.manager.Stop(ctx) + s.Require().NoError(err) +} + +func (s *ManagerTestSuite) TestMultipleWatchers() { + ctx := context.Background() + watcherID1 := "test-watcher-1" + watcherID2 := "test-watcher-2" + + events := []watcher.StoreEventRequest{ + {Operation: watcher.OperationCreate, ObjectType: "TestObject", Object: "test1"}, + {Operation: watcher.OperationUpdate, ObjectType: "TestObject", Object: "test2"}, + } + + for _, event := range events { + err := s.mockStore.StoreEvent(ctx, event) + s.Require().NoError(err) + } + + s.mockHandler.EXPECT().HandleEvent(gomock.Any(), gomock.Any()).Return(nil).Times(4) + + // Create and start first watcher + w1, err := s.manager.Create(ctx, watcherID1) + s.Require().NoError(err) + s.Require().NoError(w1.SetHandler(s.mockHandler)) + s.startAndWait(ctx, w1) + + // Create and start second watcher + w2, err := s.manager.Create(ctx, watcherID2) + s.Require().NoError(err) + s.Require().NoError(w2.SetHandler(s.mockHandler)) + s.startAndWait(ctx, w2) + + // Wait for events to be processed + s.Require().Eventually(func() bool { + return w1.Stats().LastProcessedSeqNum == 2 && w2.Stats().LastProcessedSeqNum == 2 + }, 200*time.Millisecond, 10*time.Millisecond) + + // Stop one watcher and ensure the other is still running + w1.Stop(ctx) + s.Require().Eventually(func() bool { return w1.Stats().State == watcher.StateStopped }, time.Second, 10*time.Millisecond) + s.Equal(watcher.StateRunning, w2.Stats().State) + + // Stop the manager and ensure the second watcher is stopped + err = s.manager.Stop(ctx) + s.Require().NoError(err) + s.Require().Equal(watcher.StateStopped, w1.Stats().State) + s.Require().Equal(watcher.StateStopped, w2.Stats().State) + +} + +func (s *ManagerTestSuite) TestStoppingWatcherMultipleTimes() { + ctx := context.Background() + + err := s.manager.Stop(ctx) + s.Require().NoError(err) + + // Stopping an already stopped manager should not cause issues + err = s.manager.Stop(ctx) + s.Require().NoError(err) +} + +func (s *ManagerTestSuite) startAndWait(ctx context.Context, w watcher.Watcher) { + s.Require().NoError(w.Start(ctx)) + + // Ensure the watcher is running + s.Require().Eventually(func() bool { + return w.Stats().State == watcher.StateRunning + }, time.Second, 10*time.Millisecond) +} + +func TestManagerSuite(t *testing.T) { + suite.Run(t, new(ManagerTestSuite)) +} diff --git a/pkg/lib/watcher/mocks.go b/pkg/lib/watcher/mocks.go index 55d9c8c0e0..99a4e4cf53 100644 --- a/pkg/lib/watcher/mocks.go +++ b/pkg/lib/watcher/mocks.go @@ -76,6 +76,34 @@ func (mr *MockWatcherMockRecorder) SeekToOffset(ctx, eventSeqNum interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SeekToOffset", reflect.TypeOf((*MockWatcher)(nil).SeekToOffset), ctx, eventSeqNum) } +// SetHandler mocks base method. +func (m *MockWatcher) SetHandler(handler EventHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetHandler", handler) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetHandler indicates an expected call of SetHandler. +func (mr *MockWatcherMockRecorder) SetHandler(handler interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandler", reflect.TypeOf((*MockWatcher)(nil).SetHandler), handler) +} + +// Start mocks base method. +func (m *MockWatcher) Start(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockWatcherMockRecorder) Start(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockWatcher)(nil).Start), ctx) +} + // Stats mocks base method. func (m *MockWatcher) Stats() Stats { m.ctrl.T.Helper() @@ -139,76 +167,76 @@ func (mr *MockEventHandlerMockRecorder) HandleEvent(ctx, event interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleEvent", reflect.TypeOf((*MockEventHandler)(nil).HandleEvent), ctx, event) } -// MockRegistry is a mock of Registry interface. -type MockRegistry struct { +// MockManager is a mock of Manager interface. +type MockManager struct { ctrl *gomock.Controller - recorder *MockRegistryMockRecorder + recorder *MockManagerMockRecorder } -// MockRegistryMockRecorder is the mock recorder for MockRegistry. -type MockRegistryMockRecorder struct { - mock *MockRegistry +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager } -// NewMockRegistry creates a new mock instance. -func NewMockRegistry(ctrl *gomock.Controller) *MockRegistry { - mock := &MockRegistry{ctrl: ctrl} - mock.recorder = &MockRegistryMockRecorder{mock} +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRegistry) EXPECT() *MockRegistryMockRecorder { +func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } -// GetWatcher mocks base method. -func (m *MockRegistry) GetWatcher(watcherID string) (Watcher, error) { +// Create mocks base method. +func (m *MockManager) Create(ctx context.Context, watcherID string, opts ...WatchOption) (Watcher, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWatcher", watcherID) + varargs := []interface{}{ctx, watcherID} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) ret0, _ := ret[0].(Watcher) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetWatcher indicates an expected call of GetWatcher. -func (mr *MockRegistryMockRecorder) GetWatcher(watcherID interface{}) *gomock.Call { +// Create indicates an expected call of Create. +func (mr *MockManagerMockRecorder) Create(ctx, watcherID interface{}, opts ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWatcher", reflect.TypeOf((*MockRegistry)(nil).GetWatcher), watcherID) + varargs := append([]interface{}{ctx, watcherID}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockManager)(nil).Create), varargs...) } -// Stop mocks base method. -func (m *MockRegistry) Stop(ctx context.Context) error { +// Lookup mocks base method. +func (m *MockManager) Lookup(ctx context.Context, watcherID string) (Watcher, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stop", ctx) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Lookup", ctx, watcherID) + ret0, _ := ret[0].(Watcher) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// Stop indicates an expected call of Stop. -func (mr *MockRegistryMockRecorder) Stop(ctx interface{}) *gomock.Call { +// Lookup indicates an expected call of Lookup. +func (mr *MockManagerMockRecorder) Lookup(ctx, watcherID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockRegistry)(nil).Stop), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lookup", reflect.TypeOf((*MockManager)(nil).Lookup), ctx, watcherID) } -// Watch mocks base method. -func (m *MockRegistry) Watch(ctx context.Context, watcherID string, handler EventHandler, opts ...WatchOption) (Watcher, error) { +// Stop mocks base method. +func (m *MockManager) Stop(ctx context.Context) error { m.ctrl.T.Helper() - varargs := []interface{}{ctx, watcherID, handler} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Watch", varargs...) - ret0, _ := ret[0].(Watcher) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Stop", ctx) + ret0, _ := ret[0].(error) + return ret0 } -// Watch indicates an expected call of Watch. -func (mr *MockRegistryMockRecorder) Watch(ctx, watcherID, handler interface{}, opts ...interface{}) *gomock.Call { +// Stop indicates an expected call of Stop. +func (mr *MockManagerMockRecorder) Stop(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, watcherID, handler}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockRegistry)(nil).Watch), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop), ctx) } // MockEventStore is a mock of EventStore interface. @@ -264,18 +292,18 @@ func (mr *MockEventStoreMockRecorder) GetCheckpoint(ctx, watcherID interface{}) } // GetEvents mocks base method. -func (m *MockEventStore) GetEvents(ctx context.Context, params GetEventsRequest) (*GetEventsResponse, error) { +func (m *MockEventStore) GetEvents(ctx context.Context, request GetEventsRequest) (*GetEventsResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEvents", ctx, params) + ret := m.ctrl.Call(m, "GetEvents", ctx, request) ret0, _ := ret[0].(*GetEventsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // GetEvents indicates an expected call of GetEvents. -func (mr *MockEventStoreMockRecorder) GetEvents(ctx, params interface{}) *gomock.Call { +func (mr *MockEventStoreMockRecorder) GetEvents(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvents", reflect.TypeOf((*MockEventStore)(nil).GetEvents), ctx, params) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvents", reflect.TypeOf((*MockEventStore)(nil).GetEvents), ctx, request) } // GetLatestEventNum mocks base method. @@ -308,17 +336,17 @@ func (mr *MockEventStoreMockRecorder) StoreCheckpoint(ctx, watcherID, eventSeqNu } // StoreEvent mocks base method. -func (m *MockEventStore) StoreEvent(ctx context.Context, event StoreEventRequest) error { +func (m *MockEventStore) StoreEvent(ctx context.Context, request StoreEventRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StoreEvent", ctx, event) + ret := m.ctrl.Call(m, "StoreEvent", ctx, request) ret0, _ := ret[0].(error) return ret0 } // StoreEvent indicates an expected call of StoreEvent. -func (mr *MockEventStoreMockRecorder) StoreEvent(ctx, event interface{}) *gomock.Call { +func (mr *MockEventStoreMockRecorder) StoreEvent(ctx, request interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoreEvent", reflect.TypeOf((*MockEventStore)(nil).StoreEvent), ctx, event) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StoreEvent", reflect.TypeOf((*MockEventStore)(nil).StoreEvent), ctx, request) } // MockSerializer is a mock of Serializer interface. diff --git a/pkg/lib/watcher/options.go b/pkg/lib/watcher/options.go index 043cb6c6cb..067a650082 100644 --- a/pkg/lib/watcher/options.go +++ b/pkg/lib/watcher/options.go @@ -21,22 +21,30 @@ type WatchOption func(*watchOptions) // watchOptions holds configuration options for watching events type watchOptions struct { initialEventIterator EventIterator // starting position for watching if no checkpoint is found + handler EventHandler // event handler filter EventFilter // filter for events batchSize int // number of events to fetch in each batch initialBackoff time.Duration // initial backoff duration for retries maxBackoff time.Duration // maximum backoff duration for retries maxRetries int retryStrategy RetryStrategy + autoStart bool } // validate checks all options for validity func (o *watchOptions) validate() error { - return errors.Join( + err := errors.Join( validate.IsGreaterThanZero(o.batchSize, "batchSize must be greater than zero"), validate.IsGreaterOrEqualToZero(o.initialBackoff, "initialBackoff cannot be negative"), validate.IsGreaterOrEqualToZero(o.maxBackoff, "maxBackoff cannot be negative"), validate.IsGreaterOrEqualToZero(o.maxRetries, "maxRetries cannot be negative"), validate.IsGreaterOrEqual(o.maxBackoff, o.initialBackoff, "maxBackoff must be greater than or equal to initialBackoff")) + + // validate handler is set if autoStart is enabled + if o.autoStart && o.handler == nil { + err = errors.Join(err, errors.New("handler must be set when autoStart is enabled")) + } + return err } // defaultWatchOptions returns the default watch options @@ -51,6 +59,13 @@ func defaultWatchOptions() *watchOptions { } } +// WithAutoStart enables auto-start for the watcher right after creation +func WithAutoStart() WatchOption { + return func(o *watchOptions) { + o.autoStart = true + } +} + // WithInitialEventIterator sets the starting position for watching if no checkpoint is found func WithInitialEventIterator(iterator EventIterator) WatchOption { return func(o *watchOptions) { @@ -58,6 +73,13 @@ func WithInitialEventIterator(iterator EventIterator) WatchOption { } } +// WithHandler sets the event handler for watching +func WithHandler(handler EventHandler) WatchOption { + return func(o *watchOptions) { + o.handler = handler + } +} + // WithFilter sets the event filter for watching func WithFilter(filter EventFilter) WatchOption { return func(o *watchOptions) { diff --git a/pkg/lib/watcher/registry.go b/pkg/lib/watcher/registry.go deleted file mode 100644 index 2f9ada6afb..0000000000 --- a/pkg/lib/watcher/registry.go +++ /dev/null @@ -1,126 +0,0 @@ -package watcher - -import ( - "context" - "sync" - "time" - - "github.com/rs/zerolog/log" -) - -const ( - DefaultShutdownTimeout = 30 * time.Second -) - -// registry manages multiple event watchers -type registry struct { - store EventStore - watchers map[string]*watcher - mu sync.RWMutex -} - -// NewRegistry creates a new Registry with the given EventStore. -// -// Example usage: -// -// store := // initialize your event store -// registry := NewRegistry(store) -// defer registry.Stop(context.Background()) -// -// watcher, err := registry.Watch(context.Background(), "myWatcher", myEventHandler) -// if err != nil { -// // handle error -// } -func NewRegistry(store EventStore) Registry { - return ®istry{ - store: store, - watchers: make(map[string]*watcher), - } -} - -// Watch starts watching for events with the given options -func (r *registry) Watch(ctx context.Context, watcherID string, handler EventHandler, opts ...WatchOption) (Watcher, error) { - r.mu.Lock() - defer r.mu.Unlock() - - // Check if a watcher with this ID already exists - if w, exists := r.watchers[watcherID]; exists { - if w.Stats().State != StateStopped { - return nil, NewWatcherError(watcherID, ErrWatcherAlreadyExists) - } - } - - w, err := newWatcher(ctx, watcherID, handler, r.store, opts...) - if err != nil { - return nil, err - } - - r.watchers[w.ID()] = w - go w.Start() - return w, nil -} - -// GetWatcher retrieves a specific watcher by ID -func (r *registry) GetWatcher(watcherID string) (Watcher, error) { - r.mu.RLock() - defer r.mu.RUnlock() - - w, exists := r.watchers[watcherID] - if !exists { - return nil, NewWatcherError(watcherID, ErrWatcherNotFound) - } - - return w, nil -} - -// Stop gracefully shuts down the registry and all its watchers -func (r *registry) Stop(ctx context.Context) error { - log.Ctx(ctx).Debug().Msg("Shutting down registry") - - // Create a timeout context if the parent context doesn't have a deadline - timeoutCtx := ctx - if _, hasDeadline := ctx.Deadline(); !hasDeadline { - var cancel context.CancelFunc - timeoutCtx, cancel = context.WithTimeout(ctx, DefaultShutdownTimeout) - defer cancel() - } - - var wg sync.WaitGroup - - // Take a snapshot of watchers under lock - r.mu.RLock() - watchers := make([]Watcher, 0, len(r.watchers)) - for _, w := range r.watchers { - watchers = append(watchers, w) - } - r.mu.RUnlock() - - // Stop all watchers concurrently - for i := range watchers { - w := watchers[i] - wg.Add(1) - go func(w Watcher) { - defer wg.Done() - w.Stop(timeoutCtx) - }(w) - } - - // Wait for completion or timeout - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - log.Ctx(ctx).Debug().Msg("registry shutdown complete") - return nil - case <-timeoutCtx.Done(): - log.Ctx(ctx).Warn().Msg("registry shutdown timed out") - return timeoutCtx.Err() - } -} - -// compile time check for interface implementation -var _ Registry = ®istry{} diff --git a/pkg/lib/watcher/registry_test.go b/pkg/lib/watcher/registry_test.go deleted file mode 100644 index 76bb5aa241..0000000000 --- a/pkg/lib/watcher/registry_test.go +++ /dev/null @@ -1,267 +0,0 @@ -//go:build unit || !integration - -package watcher_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/suite" - "go.uber.org/mock/gomock" - - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" - "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/boltdb" - watchertest "github.com/bacalhau-project/bacalhau/pkg/lib/watcher/test" -) - -type RegistryTestSuite struct { - suite.Suite - ctrl *gomock.Controller - mockStore *watchertest.EventStoreWrapper - mockHandler *watcher.MockEventHandler - registry watcher.Registry -} - -func (s *RegistryTestSuite) SetupTest() { - boltdbEventStore, err := boltdb.NewEventStore(watchertest.CreateBoltDB(s.T())) - s.Require().NoError(err) - - s.ctrl = gomock.NewController(s.T()) - s.mockStore = watchertest.NewEventStoreWrapper(boltdbEventStore) - s.mockHandler = watcher.NewMockEventHandler(s.ctrl) - s.registry = watcher.NewRegistry(s.mockStore) -} - -func (s *RegistryTestSuite) TearDownTest() { - s.ctrl.Finish() -} - -func (s *RegistryTestSuite) TestWatch() { - ctx := context.Background() - watcherID := "test-watcher" - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - s.Require().NotNil(l) - s.Equal(watcherID, l.ID()) - - // Stop the watcher to prevent further GetEvents calls - l.Stop(ctx) -} - -func (s *RegistryTestSuite) TestWatchDuplicateWatcher() { - ctx := context.Background() - watcherID := "test-watcher" - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - defer l.Stop(ctx) - - _, err = s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().Error(err) - s.Contains(err.Error(), "watcher already exists") -} - -func (s *RegistryTestSuite) TestGetWatcher() { - ctx := context.Background() - watcherID := "test-watcher" - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - defer l.Stop(ctx) - - retrievedWatcher, err := s.registry.GetWatcher(watcherID) - s.Require().NoError(err) - s.Require().NotNil(retrievedWatcher) - s.Equal(watcherID, retrievedWatcher.ID()) -} - -func (s *RegistryTestSuite) TestGetNonExistentWatcher() { - _, err := s.registry.GetWatcher("non-existent") - s.Require().Error(err) - s.Contains(err.Error(), "watcher not found") -} - -func (s *RegistryTestSuite) TestStop() { - ctx := context.Background() - watcherID := "test-watcher" - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - - // Ensure the watcher is running - s.Eventually(func() bool { - return l.Stats().State == watcher.StateRunning - }, 1*time.Second, 10*time.Millisecond) - - err = s.registry.Stop(ctx) - s.Require().NoError(err) - s.Require().Equal(watcher.StateStopped, l.Stats().State) -} - -func (s *RegistryTestSuite) TestStopWithTimeout() { - ctx := context.Background() - watcherID := "test-watcher" - - // Create a channel to control GetEvents - getEventsCh := make(chan struct{}) - - //// Set up the mockStore to block on GetEvents - s.mockStore.WithGetEventsInterceptor(func() error { - <-getEventsCh - return nil - }) - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - - // Ensure the watcher is running - s.Eventually(func() bool { - return l.Stats().State == watcher.StateRunning - }, 200*time.Millisecond, 10*time.Millisecond) - - // Create a very short timeout - ctxWithTimeout, cancel := context.WithTimeout(ctx, 1*time.Nanosecond) - defer cancel() - - err = s.registry.Stop(ctxWithTimeout) - s.Require().Error(err) - s.Equal(context.DeadlineExceeded, err) - - // Ensure the watcher is stopping - s.Eventually(func() bool { - return l.Stats().State == watcher.StateStopping - }, 200*time.Millisecond, 10*time.Millisecond) - - // sleep and verify that the watcher is still stopping - time.Sleep(100 * time.Millisecond) - s.Require().Equal(watcher.StateStopping, l.Stats().State) - - // Unblock GetEvents - close(getEventsCh) - - // Ensure the watcher is stopped - s.Eventually(func() bool { - return l.Stats().State == watcher.StateStopped - }, 200*time.Millisecond, 10*time.Millisecond) -} - -func (s *RegistryTestSuite) TestWatcherProcessesEvents() { - ctx := context.Background() - watcherID := "test-watcher" - - events := []watcher.StoreEventRequest{ - {Operation: watcher.OperationCreate, ObjectType: "TestObject", Object: "test1"}, - {Operation: watcher.OperationUpdate, ObjectType: "TestObject", Object: "test2"}, - } - - for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) - s.Require().NoError(err) - } - - s.mockHandler.EXPECT().HandleEvent(gomock.Any(), gomock.Any()).Return(nil).Times(2) - - _, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - - // Wait for events to be processed - time.Sleep(100 * time.Millisecond) - - err = s.registry.Stop(ctx) - s.Require().NoError(err) -} - -func (s *RegistryTestSuite) TestMultipleWatchers() { - ctx := context.Background() - watcherID1 := "test-watcher-1" - watcherID2 := "test-watcher-2" - - events := []watcher.StoreEventRequest{ - {Operation: watcher.OperationCreate, ObjectType: "TestObject", Object: "test1"}, - {Operation: watcher.OperationUpdate, ObjectType: "TestObject", Object: "test2"}, - } - - for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) - s.Require().NoError(err) - } - - s.mockHandler.EXPECT().HandleEvent(gomock.Any(), gomock.Any()).Return(nil).Times(4) - - l1, err := s.registry.Watch(ctx, watcherID1, s.mockHandler) - s.Require().NoError(err) - l2, err := s.registry.Watch(ctx, watcherID2, s.mockHandler) - s.Require().NoError(err) - - time.Sleep(100 * time.Millisecond) - - // Stop one watcher and ensure the other is still running - l1.Stop(ctx) - s.Eventually(func() bool { return l1.Stats().State == watcher.StateStopped }, time.Second, 10*time.Millisecond) - s.Equal(watcher.StateRunning, l2.Stats().State) - - l2.Stop(ctx) -} - -func (s *RegistryTestSuite) TestEventStoreErrors() { - ctx := context.Background() - watcherID := "test-watcher" - - // Test GetCheckpoint error - s.mockStore.WithGetCheckpointInterceptor(func() error { - return errors.New("checkpoint error") - }) - _, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().Error(err) - s.Contains(err.Error(), "checkpoint error") - - // Reset checkpoint error - s.mockStore.WithGetCheckpointInterceptor(nil) - - // Test GetEvents error - s.mockStore.WithGetEventsInterceptor(func() error { - return errors.New("et events error") - }) - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - time.Sleep(100 * time.Millisecond) - s.Equal(watcher.StateRunning, l.Stats().State) // The watcher should keep running despite errors -} - -func (s *RegistryTestSuite) TestRestartStoppedWatcher() { - ctx := context.Background() - watcherID := "test-watcher" - - l, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - - // Wait for the watcher to start - s.Eventually(func() bool { return l.Stats().State == watcher.StateRunning }, 200*time.Millisecond, 10*time.Millisecond) - - l.Stop(ctx) - s.Eventually(func() bool { return l.Stats().State == watcher.StateStopped }, 200*time.Second, 10*time.Millisecond) - - // Try to create a new watcher with the same ID - newL, err := s.registry.Watch(ctx, watcherID, s.mockHandler) - s.Require().NoError(err) - s.NotEqual(l, newL) -} - -func (s *RegistryTestSuite) TestStoppingWatcherMultipleTimes() { - ctx := context.Background() - - err := s.registry.Stop(ctx) - s.Require().NoError(err) - - // Stopping an already stopped registry should not cause issues - err = s.registry.Stop(ctx) - s.Require().NoError(err) -} - -func TestRegistrySuite(t *testing.T) { - suite.Run(t, new(RegistryTestSuite)) -} diff --git a/pkg/lib/watcher/serializer.go b/pkg/lib/watcher/serializer.go index 78e6d0d837..1c4f5af391 100644 --- a/pkg/lib/watcher/serializer.go +++ b/pkg/lib/watcher/serializer.go @@ -28,7 +28,7 @@ func NewJSONSerializer() *JSONSerializer { } } -// RegisterType adds a new type to the serializer's type registry +// RegisterType adds a new type to the serializer's type manager // It returns an error if the type is already registered or if the provided type is invalid func (s *JSONSerializer) RegisterType(name string, t reflect.Type) error { if _, exists := s.typeRegistry[name]; exists { @@ -44,7 +44,7 @@ func (s *JSONSerializer) RegisterType(name string, t reflect.Type) error { return nil } -// IsTypeRegistered checks if a type is registered in the serializer's type registry +// IsTypeRegistered checks if a type is registered in the serializer's type manager func (s *JSONSerializer) IsTypeRegistered(name string) bool { _, exists := s.typeRegistry[name] return exists diff --git a/pkg/lib/watcher/types.go b/pkg/lib/watcher/types.go index 6173644f8f..002c52d8ea 100644 --- a/pkg/lib/watcher/types.go +++ b/pkg/lib/watcher/types.go @@ -19,6 +19,7 @@ type Stats struct { ID string State State NextEventIterator EventIterator // Next event iterator for the watcher + CheckpointIterator EventIterator // Checkpoint iterator for the watcher LastProcessedSeqNum uint64 // SeqNum of the last event processed by this watcher LastProcessedEventTime time.Time // timestamp of the last processed event LastListenTime time.Time // timestamp of the last successful listen operation @@ -39,6 +40,15 @@ type Watcher interface { // Stats returns the current statistics for the watcher. Stats() Stats + // SetHandler sets the event handler for the watcher. Must be set before calling Start. + // Will fail if the handler is already set. + // Returns error if handler is nil or already configured. + SetHandler(handler EventHandler) error + + // Start begins processing events. + // Returns error if no handler configured or already running. + Start(ctx context.Context) error + // Stop gracefully stops the watcher. Stop(ctx context.Context) @@ -46,6 +56,7 @@ type Watcher interface { Checkpoint(ctx context.Context, eventSeqNum uint64) error // SeekToOffset moves the watcher to a specific event sequence number. + // Will stop and restart the watcher if running. SeekToOffset(ctx context.Context, eventSeqNum uint64) error } @@ -57,16 +68,16 @@ type EventHandler interface { HandleEvent(ctx context.Context, event Event) error } -// Registry manages multiple event watchers and provides methods to watch for events. -type Registry interface { - // Watch starts watching for events with the given options. - // It returns a Watcher that can be used to receive events. - Watch(ctx context.Context, watcherID string, handler EventHandler, opts ...WatchOption) (Watcher, error) +// Manager handles lifecycle of multiple watchers with shared resources +type Manager interface { + // Create creates a new not-started watcher with the given ID and options. + // The watcher must be configured with a handler before it can start watching. + Create(ctx context.Context, watcherID string, opts ...WatchOption) (Watcher, error) - // GetWatcher retrieves an existing watcher by its ID. - GetWatcher(watcherID string) (Watcher, error) + // Lookup retrieves an existing watcher by its ID. + Lookup(ctx context.Context, watcherID string) (Watcher, error) - // Stop gracefully shuts down the registry and all its watchers. + // Stop gracefully shuts down the manager and all its watchers. Stop(ctx context.Context) error } diff --git a/pkg/lib/watcher/watcher.go b/pkg/lib/watcher/watcher.go index 19a7695d02..3b99391d7c 100644 --- a/pkg/lib/watcher/watcher.go +++ b/pkg/lib/watcher/watcher.go @@ -3,6 +3,7 @@ package watcher import ( "context" "errors" + "fmt" mathgo "math" "sync" "time" @@ -19,7 +20,8 @@ type watcher struct { store EventStore // event store for fetching events and checkpoints options *watchOptions - nextEventIterator EventIterator + nextEventIterator EventIterator // for processing + checkpointIterator EventIterator // for confirmed checkpoints lastProcessedSeqNum uint64 lastProcessedEventTime time.Time lastListenTime time.Time @@ -30,8 +32,8 @@ type watcher struct { mu sync.RWMutex } -// newWatcher creates a new watcher with the given parameters -func newWatcher(ctx context.Context, id string, handler EventHandler, store EventStore, opts ...WatchOption) (*watcher, error) { +// New creates a new watcher with the given parameters +func New(ctx context.Context, id string, store EventStore, opts ...WatchOption) (Watcher, error) { options := defaultWatchOptions() for _, opt := range opts { opt(options) @@ -43,19 +45,37 @@ func newWatcher(ctx context.Context, id string, handler EventHandler, store Even w := &watcher{ id: id, - handler: handler, store: store, options: options, state: StateIdle, + stopped: make(chan struct{}), } + // Initially close stopped channel since the watcher starts in idle/stopped state + close(w.stopped) + // Determine the starting iterator iterator, err := w.determineStartingIterator(ctx, options.initialEventIterator) if err != nil { return nil, NewWatcherError(id, err) } + w.checkpointIterator = iterator w.nextEventIterator = iterator + // set the handler if provided + if options.handler != nil { + if err = w.SetHandler(options.handler); err != nil { + return nil, err + } + } + + // Auto-start if requested and handler is set + if options.autoStart { + if err = w.Start(ctx); err != nil { + return nil, NewWatcherError(id, fmt.Errorf("failed to auto-start watcher: %w", err)) + } + } + return w, nil } @@ -95,26 +115,47 @@ func (w *watcher) Stats() Stats { ID: w.id, State: w.state, NextEventIterator: w.nextEventIterator, + CheckpointIterator: w.checkpointIterator, LastProcessedSeqNum: w.lastProcessedSeqNum, LastProcessedEventTime: w.lastProcessedEventTime, LastListenTime: w.lastListenTime, } } +// SetHandler sets the event handler for this watcher +func (w *watcher) SetHandler(handler EventHandler) error { + if handler == nil { + return errors.New("handler cannot be nil") + } + + w.mu.Lock() + defer w.mu.Unlock() + + if w.handler != nil { + return ErrHandlerExists + } + + w.handler = handler + return nil +} + // Start begins the event listening process -func (w *watcher) Start() { +func (w *watcher) Start(ctx context.Context) error { w.mu.Lock() if w.state != StateIdle && w.state != StateStopped { - log.Warn().Str("watcher_id", w.id).Str("state", string(w.state)). - Msg("watcher already running/stopped, skipping start") w.mu.Unlock() - return + return NewWatcherError(w.id, fmt.Errorf("cannot start watcher in state %s", w.state)) } - var ctx context.Context - ctx, w.cancel = context.WithCancel(context.Background()) - w.stopped = make(chan struct{}, 1) + if w.handler == nil { + w.mu.Unlock() + return NewWatcherError(w.id, ErrNoHandler) + } + + ctx, w.cancel = context.WithCancel(ctx) + w.stopped = make(chan struct{}) w.state = StateRunning + w.nextEventIterator = w.checkpointIterator log.Ctx(ctx).Debug(). Str("watcher_id", w.ID()). Str("starting_at", w.nextEventIterator.String()). @@ -122,11 +163,17 @@ func (w *watcher) Start() { Msg("starting watcher") w.mu.Unlock() + go w.run(ctx) + return nil +} + +// run is the main event processing loop +func (w *watcher) run(ctx context.Context) { defer func() { w.mu.Lock() w.state = StateStopped w.mu.Unlock() - w.stopped <- struct{}{} + close(w.stopped) }() for { @@ -136,6 +183,9 @@ func (w *watcher) Start() { default: response, err := w.fetchWithBackoff(ctx) if err != nil { + if errors.Is(err, context.Canceled) { + return + } continue } @@ -230,12 +280,12 @@ func (w *watcher) updateLastProcessedEvent(event Event) { // Stop gracefully stops the watcher func (w *watcher) Stop(ctx context.Context) { w.mu.Lock() - if w.state != StateRunning { - log.Warn().Str("watcher_id", w.id).Str("state", string(w.state)). - Msg("watcher not running, skipping stop") + if w.state == StateStopped || w.state == StateIdle { + w.state = StateStopped w.mu.Unlock() return } + w.state = StateStopping w.mu.Unlock() @@ -254,25 +304,36 @@ func (w *watcher) Stop(ctx context.Context) { // Checkpoint saves the current progress of the watcher func (w *watcher) Checkpoint(ctx context.Context, eventSeqNum uint64) error { - return w.store.StoreCheckpoint(ctx, w.id, eventSeqNum) + if err := w.store.StoreCheckpoint(ctx, w.id, eventSeqNum); err != nil { + return err + } + log.Ctx(ctx).Trace().Str("watcher_id", w.id).Uint64("event_seq", eventSeqNum). + Msg("checkpoint saved") + + // Update checkpoint iterator after successful store + w.mu.Lock() + w.checkpointIterator = AfterSequenceNumberIterator(eventSeqNum) + w.mu.Unlock() + + return nil } // SeekToOffset moves the watcher to a specific event sequence number func (w *watcher) SeekToOffset(ctx context.Context, eventSeqNum uint64) error { + log.Ctx(ctx).Debug().Str("watcher_id", w.id).Uint64("event_seq", eventSeqNum). + Msg("seeking to event sequence number") // stop the watcher so that it doesn't process events while we're updating the offset w.Stop(ctx) - // update the offset - w.nextEventIterator = AfterSequenceNumberIterator(eventSeqNum) - // persist the offset so that the watcher resumes at the correct position if started - if err := w.store.StoreCheckpoint(ctx, w.id, eventSeqNum); err != nil { - log.Ctx(ctx).Error().Err(err).Str("watcher_id", w.id). - Msg("seek failed to persist offset. Watcher might not resume at the correct position") + if err := w.Checkpoint(ctx, eventSeqNum); err != nil { + return NewCheckpointError(w.id, fmt.Errorf("failed to persist seek offset: %w", err)) } - // restart the watcher - go w.Start() + // Restart watcher + if err := w.Start(ctx); err != nil { + return NewWatcherError(w.id, fmt.Errorf("failed to restart watcher after seek: %w", err)) + } return nil } diff --git a/pkg/lib/watcher/watcher_test.go b/pkg/lib/watcher/watcher_test.go index 5b9086728a..97220b7d7c 100644 --- a/pkg/lib/watcher/watcher_test.go +++ b/pkg/lib/watcher/watcher_test.go @@ -21,9 +21,10 @@ import ( type WatcherTestSuite struct { suite.Suite ctrl *gomock.Controller + ctx context.Context + cancel context.CancelFunc mockStore *watchertest.EventStoreWrapper mockHandler *watcher.MockEventHandler - registry watcher.Registry } func (s *WatcherTestSuite) SetupTest() { @@ -38,24 +39,30 @@ func (s *WatcherTestSuite) SetupTest() { ) s.Require().NoError(err) + s.ctx, s.cancel = context.WithTimeout(context.Background(), 5*time.Second) s.ctrl = gomock.NewController(s.T()) s.mockStore = watchertest.NewEventStoreWrapper(boltdbEventStore) s.mockHandler = watcher.NewMockEventHandler(s.ctrl) - s.registry = watcher.NewRegistry(s.mockStore) } func (s *WatcherTestSuite) TearDownTest() { s.ctrl.Finish() - s.registry.Stop(context.Background()) + s.cancel() } func (s *WatcherTestSuite) TestCreateWatcher() { - ctx := context.Background() - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler) + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) s.Require().NoError(err) s.Require().NotNil(w) s.Equal("test-watcher", w.ID()) - s.Eventually(func() bool { return w.Stats().State == watcher.StateRunning }, 200*time.Millisecond, 10*time.Millisecond) + + // Set handler after creation + err = w.SetHandler(s.mockHandler) + s.Require().NoError(err) + + // Start watcher async + s.Require().NoError(w.Start(s.ctx)) + s.Require().Eventually(func() bool { return w.Stats().State == watcher.StateRunning }, 200*time.Millisecond, 10*time.Millisecond) // verify stats stats := w.Stats() @@ -65,11 +72,34 @@ func (s *WatcherTestSuite) TestCreateWatcher() { s.Equal(time.Time{}, stats.LastProcessedEventTime) // Stop the watcher - w.Stop(ctx) - s.Eventually(func() bool { return w.Stats().State == watcher.StateStopped }, 200*time.Millisecond, 10*time.Millisecond) + w.Stop(s.ctx) + s.Require().Eventually(func() bool { return w.Stats().State == watcher.StateStopped }, 200*time.Millisecond, 10*time.Millisecond) +} + +func (s *WatcherTestSuite) TestSetHandlerErrors() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + + // Test setting nil handler + err = w.SetHandler(nil) + s.Error(err) + + // Test setting handler twice + err = w.SetHandler(s.mockHandler) + s.NoError(err) + err = w.SetHandler(s.mockHandler) + s.Equal(watcher.ErrHandlerExists, err) } + +func (s *WatcherTestSuite) TestStartWithoutHandler() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + + s.Require().Error(w.Start(s.ctx)) + s.Never(func() bool { return w.Stats().State == watcher.StateRunning }, 200*time.Millisecond, 10*time.Millisecond) +} + func (s *WatcherTestSuite) TestDetermineStartingIterator() { - ctx := context.Background() testCases := []struct { name string @@ -145,7 +175,7 @@ func (s *WatcherTestSuite) TestDetermineStartingIterator() { if tc.setupLatestEvent != nil { // Store events up to the desired sequence number for i := uint64(1); i <= *tc.setupLatestEvent; i++ { - err := s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + err := s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: fmt.Sprintf("test%d", i), @@ -155,7 +185,7 @@ func (s *WatcherTestSuite) TestDetermineStartingIterator() { } if tc.setupCheckpoint != nil { - err := s.mockStore.StoreCheckpoint(ctx, "test-watcher", *tc.setupCheckpoint) + err := s.mockStore.StoreCheckpoint(s.ctx, "test-watcher", *tc.setupCheckpoint) s.Require().NoError(err) } @@ -171,7 +201,7 @@ func (s *WatcherTestSuite) TestDetermineStartingIterator() { }) } - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithInitialEventIterator(tc.initialIter)) if tc.expectedError { @@ -183,19 +213,22 @@ func (s *WatcherTestSuite) TestDetermineStartingIterator() { "Iterator mismatch - expected: %s, got: %s", tc.expectedIter.String(), w.Stats().NextEventIterator.String()) + s.Equal(tc.expectedIter, w.Stats().CheckpointIterator, + "Checkpoint iterator mismatch - expected: %s, got: %s", + tc.expectedIter.String(), + w.Stats().CheckpointIterator.String()) }) } } func (s *WatcherTestSuite) TestWatcherProcessEvents() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -204,14 +237,15 @@ func (s *WatcherTestSuite) TestWatcherProcessEvents() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler) + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.Require().NoError(w.Start(s.ctx)) - s.waitAndStop(ctx, w, 2) + s.waitAndStop(s.ctx, w, 2) } func (s *WatcherTestSuite) TestWithStartSeqNum() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, @@ -219,7 +253,7 @@ func (s *WatcherTestSuite) TestWithStartSeqNum() { } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -228,15 +262,15 @@ func (s *WatcherTestSuite) TestWithStartSeqNum() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(3)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithInitialEventIterator(watcher.AfterSequenceNumberIterator(1))) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) - s.waitAndStop(ctx, w, 3) + s.startWaitAndStop(s.ctx, w, 3) } func (s *WatcherTestSuite) TestWithFilter() { - ctx := context.Background() filter := watcher.EventFilter{ ObjectTypes: []string{"StringObject"}, Operations: []watcher.Operation{watcher.OperationCreate, watcher.OperationUpdate}, @@ -250,7 +284,7 @@ func (s *WatcherTestSuite) TestWithFilter() { } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -260,22 +294,53 @@ func (s *WatcherTestSuite) TestWithFilter() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(4)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, watcher.WithFilter(filter)) + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithFilter(filter)) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) - s.waitAndStop(ctx, w, 5) + s.startWaitAndStop(s.ctx, w, 5) s.Equal(uint64(4), w.Stats().LastProcessedSeqNum) // last event not processed } +func (s *WatcherTestSuite) TestWithHandlerAndAutoStart() { + s.Run("WithAutoStart requires handler", func() { + // Should fail because no handler is set + _, err := watcher.New(s.ctx, "test-watcher", s.mockStore, + watcher.WithAutoStart()) + s.Require().Error(err) + s.Contains(err.Error(), "handler must be set when autoStart is enabled") + }) + + s.Run("WithHandler and WithAutoStart starts automatically", func() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, + watcher.WithHandler(s.mockHandler), + watcher.WithAutoStart()) + s.Require().NoError(err) + + // Should be automatically running + s.Require().Eventually(func() bool { + return w.Stats().State == watcher.StateRunning + }, 200*time.Millisecond, 10*time.Millisecond) + + w.Stop(s.ctx) + }) + + s.Run("WithHandler only does not auto-start", func() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, + watcher.WithHandler(s.mockHandler)) + s.Require().NoError(err) + s.Equal(watcher.StateIdle, w.Stats().State) + }) +} + func (s *WatcherTestSuite) TestCheckpoint() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -283,34 +348,77 @@ func (s *WatcherTestSuite) TestCheckpoint() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(1)).Return(nil).Times(1), s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler) + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) - s.wait(ctx, w, 2) + s.startAndWait(s.ctx, w, 2) // Manually checkpoint - err = w.Checkpoint(ctx, 1) + err = w.Checkpoint(s.ctx, 1) s.Require().NoError(err) - w.Stop(ctx) - - // Verify the checkpoint - checkpoint, err := s.mockStore.GetCheckpoint(ctx, w.ID()) + // Verify both checkpoint and checkpointIterator + checkpoint, err := s.mockStore.GetCheckpoint(s.ctx, w.ID()) s.Require().NoError(err) s.Equal(uint64(1), checkpoint) + s.Equal(watcher.AfterSequenceNumberIterator(1), w.Stats().CheckpointIterator) + + w.Stop(s.ctx) - // Start a new watcher and verify that it starts from the checkpoint + // Start a new watcher and verify it starts from checkpoint newHandler := watcher.NewMockEventHandler(s.ctrl) newHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1) - w, err = s.registry.Watch(ctx, "test-watcher", newHandler) + w, err = watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + s.Require().NoError(w.SetHandler(newHandler)) + + // Verify both iterators start at checkpoint + s.Equal(watcher.AfterSequenceNumberIterator(1), w.Stats().CheckpointIterator) + s.Equal(watcher.AfterSequenceNumberIterator(1), w.Stats().NextEventIterator) + + s.startWaitAndStop(s.ctx, w, 2) +} + +func (s *WatcherTestSuite) TestRestartFromCheckpoint() { + events := []watcher.StoreEventRequest{ + {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, + {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, + } + + for _, event := range events { + err := s.mockStore.StoreEvent(s.ctx, event) + s.Require().NoError(err) + } + + // First start and process events + gomock.InOrder( + s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(1)).Return(nil).Times(1), + s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1), + ) + + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.Require().NoError(w.Start(s.ctx)) - s.waitAndStop(ctx, w, 2) + // Wait and checkpoint at 1 + s.wait(s.ctx, w, 2) + s.Require().NoError(w.Checkpoint(s.ctx, 1)) + w.Stop(s.ctx) + + // Restart and verify it starts from checkpoint + s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1) + + s.Require().NoError(w.Start(s.ctx)) + s.Equal(watcher.AfterSequenceNumberIterator(1), w.Stats().NextEventIterator) + s.Equal(watcher.AfterSequenceNumberIterator(1), w.Stats().CheckpointIterator) + + s.waitAndStop(s.ctx, w, 2) } func (s *WatcherTestSuite) TestSeekToOffset() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, @@ -318,7 +426,7 @@ func (s *WatcherTestSuite) TestSeekToOffset() { } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -329,25 +437,23 @@ func (s *WatcherTestSuite) TestSeekToOffset() { // last event is processed twice after seek s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(3)).Return(nil).Times(2), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler) + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) s.Require().NoError(err) - s.wait(ctx, w, 3) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(s.ctx, w, 3) // Seek to offset 2 - err = w.SeekToOffset(ctx, 2) - s.Require().NoError(err) + s.Require().NoError(w.SeekToOffset(s.ctx, 2)) // Verify the checkpoint - checkpoint, err := s.mockStore.GetCheckpoint(ctx, w.ID()) + checkpoint, err := s.mockStore.GetCheckpoint(s.ctx, w.ID()) s.Require().NoError(err) s.Equal(uint64(2), checkpoint) - s.waitAndStop(ctx, w, 3) + s.waitAndStop(s.ctx, w, 3) } func (s *WatcherTestSuite) TestCheckpointAndStartSeqNum() { - ctx := context.Background() - testCases := []struct { name string checkpoint uint64 @@ -408,13 +514,13 @@ func (s *WatcherTestSuite) TestCheckpointAndStartSeqNum() { } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } // Set up the checkpoint if provided if tc.checkpoint != 0 { - err := s.mockStore.StoreCheckpoint(ctx, "test-watcher", tc.checkpoint) + err := s.mockStore.StoreCheckpoint(s.ctx, "test-watcher", tc.checkpoint) s.Require().NoError(err) } @@ -424,25 +530,25 @@ func (s *WatcherTestSuite) TestCheckpointAndStartSeqNum() { } // Start the watcher - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithInitialEventIterator(watcher.AfterSequenceNumberIterator(tc.startSeqNum))) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) // Wait for processing and stop - s.waitAndStop(ctx, w, uint64(5)) + s.startWaitAndStop(s.ctx, w, uint64(5)) }) } } func (s *WatcherTestSuite) TestHandleEventErrorWithBlockStrategy() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -461,20 +567,20 @@ func (s *WatcherTestSuite) TestHandleEventErrorWithBlockStrategy() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithMaxRetries(3), // will be ignored with block strategy watcher.WithInitialBackoff(1*time.Nanosecond), watcher.WithMaxBackoff(1*time.Nanosecond), watcher.WithRetryStrategy(watcher.RetryStrategyBlock)) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) - s.waitAndStop(ctx, w, 2) + s.startWaitAndStop(s.ctx, w, 2) s.Equal(uint64(2), w.Stats().LastProcessedSeqNum) s.Equal(maxFails+1, failCount) // Verify that it failed 100 times before succeeding } func (s *WatcherTestSuite) TestDifferentIteratorTypes() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, @@ -483,7 +589,7 @@ func (s *WatcherTestSuite) TestDifferentIteratorTypes() { } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -521,17 +627,17 @@ func (s *WatcherTestSuite) TestDifferentIteratorTypes() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(seqNum)).Return(nil).Times(1) } - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithInitialEventIterator(tc.iterator)) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) - s.waitAndStop(ctx, w, 4) + s.startWaitAndStop(s.ctx, w, 4) }) } } func (s *WatcherTestSuite) TestEmptyEventStoreWithDifferentIterators() { - ctx := context.Background() testCases := []struct { name string @@ -570,10 +676,11 @@ func (s *WatcherTestSuite) TestEmptyEventStoreWithDifferentIterators() { s.mockHandler = watcher.NewMockEventHandler(s.ctrl) // Test with empty store - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithInitialEventIterator(tc.iterator)) s.Require().NoError(err) - s.waitAndStop(ctx, w, tc.expectedIterator.SequenceNumber) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startWaitAndStop(s.ctx, w, tc.expectedIterator.SequenceNumber) // Check if the next event iterator matches the expected iterator s.Equal(tc.expectedIterator, w.Stats().NextEventIterator) @@ -582,14 +689,13 @@ func (s *WatcherTestSuite) TestEmptyEventStoreWithDifferentIterators() { } func (s *WatcherTestSuite) TestHandleEventErrorWithSkipStrategy() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -598,18 +704,18 @@ func (s *WatcherTestSuite) TestHandleEventErrorWithSkipStrategy() { s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithMaxRetries(3), watcher.WithInitialBackoff(1*time.Nanosecond), watcher.WithRetryStrategy(watcher.RetryStrategySkip)) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) - s.waitAndStop(ctx, w, 2) + s.startWaitAndStop(s.ctx, w, 2) s.Equal(uint64(2), w.Stats().LastProcessedSeqNum) } func (s *WatcherTestSuite) TestBatchOptions() { - ctx := context.Background() testCases := []struct { name string @@ -629,7 +735,7 @@ func (s *WatcherTestSuite) TestBatchOptions() { s.SetupTest() for i := 0; i < tc.eventCount; i++ { - s.Require().NoError(s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + s.Require().NoError(s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: fmt.Sprintf("test%d", i+1), })) } @@ -640,21 +746,21 @@ func (s *WatcherTestSuite) TestBatchOptions() { return nil }) - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithBatchSize(tc.batchSize), ) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) s.mockHandler.EXPECT().HandleEvent(gomock.Any(), gomock.Any()).Return(nil).Times(tc.eventCount) - s.waitAndStop(ctx, w, uint64(tc.eventCount)) + s.startWaitAndStop(s.ctx, w, uint64(tc.eventCount)) s.Equal(tc.expectedCalls+1, getEventsCallCount) // one extra call longpolling for new events }) } } func (s *WatcherTestSuite) TestFilterEdgeCases() { - ctx := context.Background() testCases := []struct { name string @@ -702,32 +808,32 @@ func (s *WatcherTestSuite) TestFilterEdgeCases() { s.Run(tc.name, func() { for _, event := range tc.events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithFilter(tc.filter)) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) s.mockHandler.EXPECT().HandleEvent(gomock.Any(), gomock.Any()). Return(nil).Times(tc.expectedEvents) - s.waitAndStop(ctx, w, uint64(len(tc.events))) + s.startWaitAndStop(s.ctx, w, uint64(len(tc.events))) s.Equal(uint64(tc.expectedEvents), w.Stats().LastProcessedSeqNum) }) } } func (s *WatcherTestSuite) TestRestartBehavior() { - ctx := context.Background() events := []watcher.StoreEventRequest{ {Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test1"}, {Operation: watcher.OperationUpdate, ObjectType: "StringObject", Object: "test2"}, } for _, event := range events { - err := s.mockStore.StoreEvent(ctx, event) + err := s.mockStore.StoreEvent(s.ctx, event) s.Require().NoError(err) } @@ -738,19 +844,21 @@ func (s *WatcherTestSuite) TestRestartBehavior() { handler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(2)).Return(nil).Times(1), ) - w, err := s.registry.Watch(ctx, "test-watcher", handler) + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) s.Require().NoError(err) - s.waitAndStop(ctx, w, 2) + + s.Require().NoError(w.SetHandler(handler)) + s.startWaitAndStop(s.ctx, w, 2) } } func (s *WatcherTestSuite) TestListenAndStoreConcurrently() { - ctx := context.Background() - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithBatchSize(7)) s.Require().NoError(err) - s.wait(ctx, w, 0) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(s.ctx, w, 0) eventCount := 100 @@ -764,7 +872,7 @@ func (s *WatcherTestSuite) TestListenAndStoreConcurrently() { done := make(chan struct{}) go func() { for i := 0; i < eventCount; i++ { - s.Require().NoError(s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + s.Require().NoError(s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: fmt.Sprintf("test%d", i), @@ -776,25 +884,26 @@ func (s *WatcherTestSuite) TestListenAndStoreConcurrently() { <-done gomock.InOrder(expectations...) - s.waitAndStop(ctx, w, uint64(eventCount)) + s.waitAndStop(s.ctx, w, uint64(eventCount)) } func (s *WatcherTestSuite) TestEventStoreConsistency() { - ctx := context.Background() - w, err := s.registry.Watch(ctx, "test-watcher", s.mockHandler, + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore, watcher.WithBatchSize(2)) s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(s.ctx, w, 0) // Set up expectations gomock.InOrder( s.mockHandler.EXPECT().HandleEvent(gomock.Any(), watchertest.EventWithSeqNum(1)).DoAndReturn( func(ctx context.Context, event watcher.Event) error { // Store more events while processing - s.Require().NoError(s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + s.Require().NoError(s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test", })) - s.Require().NoError(s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + s.Require().NoError(s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test", })) return nil @@ -805,26 +914,115 @@ func (s *WatcherTestSuite) TestEventStoreConsistency() { ) // Store initial events - s.Require().NoError(s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + s.Require().NoError(s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test", })) - s.Require().NoError(s.mockStore.StoreEvent(ctx, watcher.StoreEventRequest{ + s.Require().NoError(s.mockStore.StoreEvent(s.ctx, watcher.StoreEventRequest{ Operation: watcher.OperationCreate, ObjectType: "StringObject", Object: "test", })) - s.waitAndStop(ctx, w, 4) + s.waitAndStop(s.ctx, w, 4) +} + +func (s *WatcherTestSuite) TestStopStates() { + + s.Run("stop running watcher", func() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(s.ctx, w, 0) + + w.Stop(s.ctx) + s.Equal(watcher.StateStopped, w.Stats().State) + }) + + s.Run("stop stopped watcher", func() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(s.ctx, w, 0) + + w.Stop(s.ctx) + s.Equal(watcher.StateStopped, w.Stats().State) + + // Stop again + w.Stop(s.ctx) + s.Equal(watcher.StateStopped, w.Stats().State) + }) + + s.Run("stop not-started watcher", func() { + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + s.Equal(watcher.StateIdle, w.Stats().State) + + w.Stop(s.ctx) + s.Equal(watcher.StateStopped, w.Stats().State) + }) + + s.Run("concurrent stop calls", func() { + // Create a channel to control GetEvents + getEventsCh := make(chan struct{}) + + // Set up the mockStore to block on GetEvents + s.mockStore.WithGetEventsInterceptor(func() error { + <-getEventsCh + return nil + }) + + w, err := watcher.New(s.ctx, "test-watcher", s.mockStore) + s.Require().NoError(err) + s.Require().NoError(w.SetHandler(s.mockHandler)) + s.startAndWait(s.ctx, w, 0) + + // Start first stop operation - will be blocked due to GetEvents + go w.Stop(s.ctx) + + // Wait for watcher to enter stopping state + s.Eventually(func() bool { + return w.Stats().State == watcher.StateStopping + }, 200*time.Millisecond, 10*time.Millisecond) + + // Try second stop while first one is still in progress + go w.Stop(s.ctx) + + // State should still be stopping + s.Eventually(func() bool { + return w.Stats().State == watcher.StateStopping + }, 200*time.Millisecond, 10*time.Millisecond) + + // Unblock GetEvents + close(getEventsCh) + + // Now watcher should transition to stopped + s.Eventually(func() bool { + return w.Stats().State == watcher.StateStopped + }, 200*time.Millisecond, 10*time.Millisecond) + }) } func (s *WatcherTestSuite) wait(ctx context.Context, w watcher.Watcher, continuationSeqNum uint64) { - // wait for the watcher to consume the events - s.Eventually(func() bool { return w.Stats().NextEventIterator.SequenceNumber == continuationSeqNum }, 1*time.Second, 10*time.Millisecond) - s.Equal(watcher.StateRunning, w.Stats().State) + s.Require().Eventually(func() bool { + return w.Stats().State == watcher.StateRunning && + w.Stats().NextEventIterator.SequenceNumber == continuationSeqNum + }, 1*time.Second, 10*time.Millisecond) +} + +// startWaitAndStop +func (s *WatcherTestSuite) startAndWait(ctx context.Context, w watcher.Watcher, continuationSeqNum uint64) { + s.Require().NoError(w.Start(s.ctx)) + s.wait(s.ctx, w, continuationSeqNum) +} + +// startWaitAndStop +func (s *WatcherTestSuite) startWaitAndStop(ctx context.Context, w watcher.Watcher, continuationSeqNum uint64) { + s.Require().NoError(w.Start(s.ctx)) + s.waitAndStop(s.ctx, w, continuationSeqNum) } func (s *WatcherTestSuite) waitAndStop(ctx context.Context, w watcher.Watcher, continuationSeqNum uint64) { - s.wait(ctx, w, continuationSeqNum) - w.Stop(ctx) - s.Eventually(func() bool { return w.Stats().State == watcher.StateStopped }, 1*time.Second, 10*time.Millisecond) + s.wait(s.ctx, w, continuationSeqNum) + w.Stop(s.ctx) + s.Require().Eventually(func() bool { return w.Stats().State == watcher.StateStopped }, 1*time.Second, 10*time.Millisecond) } // helper function to get pointer to uint64 diff --git a/pkg/node/compute.go b/pkg/node/compute.go index ad12f01590..a7879d8a63 100644 --- a/pkg/node/compute.go +++ b/pkg/node/compute.go @@ -44,7 +44,7 @@ type Compute struct { Storages storage.StorageProvider Publishers publisher.PublisherProvider Bidder compute.Bidder - Watchers watcher.Registry + Watchers watcher.Manager ManagementClient *compute.ManagementClient cleanupFunc func(ctx context.Context) nodeInfoDecorator models.NodeInfoDecorator @@ -385,12 +385,13 @@ func setupComputeWatchers( computeCallback compute.Callback, bufferRunner *compute.ExecutorBuffer, bidder compute.Bidder, -) (watcher.Registry, error) { - watcherRegistry := watcher.NewRegistry(executionStore.GetEventStore()) +) (watcher.Manager, error) { + watcherRegistry := watcher.NewManager(executionStore.GetEventStore()) // Set up execution logger watcher - _, err := watcherRegistry.Watch(ctx, computeExecutionLoggerWatcherID, - watchers.NewExecutionLogger(log.Logger), + _, err := watcherRegistry.Create(ctx, computeExecutionLoggerWatcherID, + watcher.WithHandler(watchers.NewExecutionLogger(log.Logger)), + watcher.WithAutoStart(), watcher.WithInitialEventIterator(watcher.LatestIterator())) if err != nil { return nil, fmt.Errorf("failed to setup execution logger watcher: %w", err) @@ -405,7 +406,9 @@ func setupComputeWatchers( return nil, err } - _, err = watcherRegistry.Watch(ctx, computeToOrchestratorDispatcherWatcherID, dispatcher, + _, err = watcherRegistry.Create(ctx, computeToOrchestratorDispatcherWatcherID, + watcher.WithHandler(dispatcher), + watcher.WithAutoStart(), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{compute.EventObjectExecutionUpsert}, }), @@ -418,7 +421,9 @@ func setupComputeWatchers( // Set up execution handler watcher executionHandler := watchers.NewExecutionUpsertHandler(bufferRunner, bidder) - _, err = watcherRegistry.Watch(ctx, computeExecutionHandlerWatcherID, executionHandler, + _, err = watcherRegistry.Create(ctx, computeExecutionHandlerWatcherID, + watcher.WithHandler(executionHandler), + watcher.WithAutoStart(), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{compute.EventObjectExecutionUpsert}, }), diff --git a/pkg/node/requester.go b/pkg/node/requester.go index 3d94020bae..1243b9e949 100644 --- a/pkg/node/requester.go +++ b/pkg/node/requester.go @@ -431,11 +431,13 @@ func setupOrchestratorWatchers( evalBroker orchestrator.EvaluationBroker, nodeManager routing.NodeInfoStore, computeProxy compute.Endpoint, -) (watcher.Registry, error) { - watcherRegistry := watcher.NewRegistry(jobStore.GetEventStore()) +) (watcher.Manager, error) { + watcherRegistry := watcher.NewManager(jobStore.GetEventStore()) // Start watching for evaluation events using latest iterator - _, err := watcherRegistry.Watch(ctx, orchestratorEvaluationWatcherID, evaluation.NewWatchHandler(evalBroker), + _, err := watcherRegistry.Create(ctx, orchestratorEvaluationWatcherID, + watcher.WithHandler(evaluation.NewWatchHandler(evalBroker)), + watcher.WithAutoStart(), watcher.WithInitialEventIterator(watcher.LatestIterator()), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{jobstore.EventObjectEvaluation}, @@ -447,7 +449,9 @@ func setupOrchestratorWatchers( } // Set up execution logger watcher - _, err = watcherRegistry.Watch(ctx, orchestratorExecutionLoggerWatcherID, watchers.NewExecutionLogger(log.Logger), + _, err = watcherRegistry.Create(ctx, orchestratorExecutionLoggerWatcherID, + watcher.WithHandler(watchers.NewExecutionLogger(log.Logger)), + watcher.WithAutoStart(), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{jobstore.EventObjectExecutionUpsert}, }), @@ -477,7 +481,9 @@ func setupOrchestratorWatchers( } // TODO: Add checkpointing or else events will be missed - _, err = watcherRegistry.Watch(ctx, orchestratorToComputeDispatcherWatcherID, dispatcher, + _, err = watcherRegistry.Create(ctx, orchestratorToComputeDispatcherWatcherID, + watcher.WithHandler(dispatcher), + watcher.WithAutoStart(), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{jobstore.EventObjectExecutionUpsert}, }), diff --git a/pkg/orchestrator/evaluation/watcher_test.go b/pkg/orchestrator/evaluation/watcher_test.go index 4703a0082c..e7c19579c9 100644 --- a/pkg/orchestrator/evaluation/watcher_test.go +++ b/pkg/orchestrator/evaluation/watcher_test.go @@ -22,7 +22,7 @@ type WatchHandlerTestSuite struct { suite.Suite store jobstore.Store broker *evaluation.InMemoryBroker - registry watcher.Registry + registry watcher.Manager watchHandler *evaluation.WatchHandler ctx context.Context } @@ -37,11 +37,13 @@ func (s *WatchHandlerTestSuite) SetupTest() { s.Require().NoError(err) s.broker.SetEnabled(true) - s.registry = watcher.NewRegistry(s.store.GetEventStore()) + s.registry = watcher.NewManager(s.store.GetEventStore()) s.watchHandler = evaluation.NewWatchHandler(s.broker) // Start watching for events - w, err := s.registry.Watch(s.ctx, "test-watcher", s.watchHandler, + w, err := s.registry.Create(s.ctx, "test-watcher", + watcher.WithHandler(s.watchHandler), + watcher.WithAutoStart(), watcher.WithInitialEventIterator(watcher.LatestIterator()), watcher.WithFilter(watcher.EventFilter{ ObjectTypes: []string{jobstore.EventObjectEvaluation},