From 559e78a9c045b25a6caefe3e67cd04726c9895af Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 11 Dec 2024 17:22:23 +0200 Subject: [PATCH] always use orchestrator seqNum during handshake (#4740) ## Summary by CodeRabbit - **New Features** - Enhanced node management with improved event storage capabilities. - New methods added to validate handshake sequence number logic. - Updated dispatcher setup process to refine event handling. - **Bug Fixes** - Improved error handling during node manager initialization and dispatcher setup. - **Tests** - Expanded test suite to cover edge cases in handshake sequence number logic and event storage. - Added tests for handshake sequence number logic and concurrent operations. - **Documentation** - Updated method signatures to reflect new parameters and functionalities. --- pkg/node/requester.go | 11 +- pkg/orchestrator/nodes/manager.go | 42 ++++- pkg/orchestrator/nodes/manager_test.go | 146 ++++++++++++++++++ .../bprotocol/orchestrator/heartbeat_test.go | 7 + .../nclprotocol/orchestrator/dataplane.go | 1 + 5 files changed, 196 insertions(+), 11 deletions(-) diff --git a/pkg/node/requester.go b/pkg/node/requester.go index cdea80e0a4..e6f6f7264d 100644 --- a/pkg/node/requester.go +++ b/pkg/node/requester.go @@ -65,18 +65,18 @@ func NewRequesterNode( transportLayer *nats_transport.NATSTransport, metadataStore MetadataStore, ) (*Requester, error) { - natsConn, err := transportLayer.CreateClient(ctx) + jobStore, err := createJobStore(ctx, cfg) if err != nil { return nil, err } - nodeID := cfg.NodeID - nodesManager, nodeStore, err := createNodeManager(ctx, cfg, natsConn) + natsConn, err := transportLayer.CreateClient(ctx) if err != nil { return nil, err } - jobStore, err := createJobStore(ctx, cfg) + nodeID := cfg.NodeID + nodesManager, nodeStore, err := createNodeManager(ctx, cfg, jobStore.GetEventStore(), natsConn) if err != nil { return nil, err } @@ -371,7 +371,7 @@ func createJobStore(ct context.Context, cfg NodeConfig) (jobstore.Store, error) return jobStore, nil } -func createNodeManager(ctx context.Context, cfg NodeConfig, natsConn *nats.Conn) ( +func createNodeManager(ctx context.Context, cfg NodeConfig, eventStore watcher.EventStore, natsConn *nats.Conn) ( nodes.Manager, nodes.Store, error) { nodeInfoStore, err := kvstore.NewNodeStore(ctx, kvstore.NodeStoreParams{ BucketName: kvstore.BucketNameCurrent, @@ -385,6 +385,7 @@ func createNodeManager(ctx context.Context, cfg NodeConfig, natsConn *nats.Conn) Store: nodeInfoStore, NodeDisconnectedAfter: cfg.BacalhauConfig.Orchestrator.NodeManager.DisconnectTimeout.AsTimeDuration(), ManualApproval: cfg.BacalhauConfig.Orchestrator.NodeManager.ManualApproval, + EventStore: eventStore, }) if err != nil { diff --git a/pkg/orchestrator/nodes/manager.go b/pkg/orchestrator/nodes/manager.go index b90cf091de..866183cb09 100644 --- a/pkg/orchestrator/nodes/manager.go +++ b/pkg/orchestrator/nodes/manager.go @@ -2,6 +2,8 @@ package nodes import ( "context" + "errors" + "fmt" "sync" "time" @@ -9,6 +11,8 @@ import ( "github.com/rs/zerolog/log" "github.com/bacalhau-project/bacalhau/pkg/bacerrors" + "github.com/bacalhau-project/bacalhau/pkg/lib/validate" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" ) @@ -42,8 +46,9 @@ const ( // state persistence with configurable intervals. type nodesManager struct { // Core dependencies - store Store // Persistent storage for node states - clock clock.Clock // Time source (can be mocked for testing) + store Store // Persistent storage for node states + eventstore watcher.EventStore // Store for events + clock clock.Clock // Time source (can be mocked for testing) // Configuration defaultApprovalState models.NodeMembershipState // Initial membership state for new nodes @@ -94,6 +99,10 @@ type ManagerParams struct { // ShutdownTimeout is the timeout for graceful shutdown (optional) ShutdownTimeout time.Duration + + // EventStore provides storage for events so that node manager can assign + // new nodes with latest sequence number in the store + EventStore watcher.EventStore } // trackedLiveState holds the runtime state for an active node. @@ -138,8 +147,16 @@ func NewManager(params ManagerParams) (Manager, error) { params.ShutdownTimeout = defaultShutdownTimeout } + if err := errors.Join( + validate.NotNil(params.Store, "store required"), + validate.NotNil(params.EventStore, "event store required"), + ); err != nil { + return nil, fmt.Errorf("node manager invalid params: %w", err) + } + return &nodesManager{ store: params.Store, + eventstore: params.EventStore, clock: params.Clock, liveState: &sync.Map{}, defaultApprovalState: defaultApprovalState, @@ -425,16 +442,29 @@ func (n *nodesManager) Handshake( Info: request.NodeInfo, Membership: n.defaultApprovalState, ConnectionState: models.ConnectionState{ - Status: models.NodeStates.CONNECTED, - ConnectedSince: n.clock.Now(), - LastHeartbeat: n.clock.Now(), - LastOrchestratorSeqNum: request.LastOrchestratorSeqNum, + Status: models.NodeStates.CONNECTED, + ConnectedSince: n.clock.Now(), + LastHeartbeat: n.clock.Now(), }, } + // If a node is reconnecting, we trust and preserve the sequence numbers from its previous state, + // rather than using the sequence numbers from the handshake request. For new nodes, + // we assign them the latest event sequence number from the event store. + // This prevents several edge cases: + // - Compute node losing its state. The handshake will ask to start from 0. + // - Orchestrator losing their state and compute nodes asking to start from a later seqNum that no longer exist. + // - New compute nodes joining. The handshake will also ask to start from 0. if isReconnect { state.Membership = existing.Membership state.ConnectionState.LastComputeSeqNum = existing.ConnectionState.LastComputeSeqNum + state.ConnectionState.LastOrchestratorSeqNum = existing.ConnectionState.LastOrchestratorSeqNum + } else { + // Assign the latest sequence number from the event store + state.ConnectionState.LastOrchestratorSeqNum, err = n.eventstore.GetLatestEventNum(ctx) + if err != nil { + return messages.HandshakeResponse{}, fmt.Errorf("failed to initialize node with latest event number: %w", err) + } } if err = n.store.Put(ctx, state); err != nil { diff --git a/pkg/orchestrator/nodes/manager_test.go b/pkg/orchestrator/nodes/manager_test.go index 2afc6a5a91..f3bf5a003c 100644 --- a/pkg/orchestrator/nodes/manager_test.go +++ b/pkg/orchestrator/nodes/manager_test.go @@ -15,10 +15,12 @@ import ( "github.com/stretchr/testify/suite" "github.com/bacalhau-project/bacalhau/pkg/bacerrors" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes/inmemory" + testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" ) type NodeManagerTestSuite struct { @@ -26,6 +28,7 @@ type NodeManagerTestSuite struct { ctx context.Context clock *clock.Mock store nodes.Store + eventStore watcher.EventStore manager nodes.Manager disconnected time.Duration } @@ -44,8 +47,11 @@ func (s *NodeManagerTestSuite) SetupTest() { TTL: time.Hour, }) + s.eventStore, _ = testutils.CreateStringEventStore(s.T()) + manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, HealthCheckFrequency: 1 * time.Second, @@ -62,6 +68,10 @@ func (s *NodeManagerTestSuite) SetupTest() { func (s *NodeManagerTestSuite) TearDownTest() { err := s.manager.Stop(s.ctx) s.Require().NoError(err) + + // Cleanup event store + err = s.eventStore.Close(s.ctx) + s.Require().NoError(err) } func (s *NodeManagerTestSuite) createNodeInfo(id string) models.NodeInfo { @@ -139,6 +149,134 @@ func (s *NodeManagerTestSuite) TestHeartbeatMaintainsConnection() { // Edge Cases and Error Scenarios +func (s *NodeManagerTestSuite) TestHandshakeSequenceNumberLogic() { + // Test initial handshake with new node + nodeInfo := s.createNodeInfo("new-node") + + // First add some events to the event store to have a non-zero latest sequence + ctx := context.Background() + for i := 0; i < 5; i++ { + err := s.eventStore.StoreEvent(ctx, watcher.StoreEventRequest{ + Operation: watcher.OperationCreate, + ObjectType: testutils.TypeString, + Object: fmt.Sprintf("test-event-%d", i), + }) + s.Require().NoError(err) + } + + // Get the latest sequence number for verification + latestSeqNum, err := s.eventStore.GetLatestEventNum(ctx) + s.Require().NoError(err) + + // Perform initial handshake + resp1, err := s.manager.Handshake(ctx, messages.HandshakeRequest{ + NodeInfo: nodeInfo, + LastOrchestratorSeqNum: 100, // Should be ignored for new nodes + }) + s.Require().NoError(err) + s.Require().True(resp1.Accepted) + + // Verify the node was assigned the latest sequence number + state, err := s.manager.Get(ctx, nodeInfo.ID()) + s.Require().NoError(err) + s.Assert().Equal(latestSeqNum, state.ConnectionState.LastOrchestratorSeqNum, + "New node should be assigned latest sequence number") + + // Update sequence numbers via heartbeat + updatedOrchSeqNum := uint64(200) + updatedComputeSeqNum := uint64(150) + _, err = s.manager.Heartbeat(ctx, nodes.ExtendedHeartbeatRequest{ + HeartbeatRequest: messages.HeartbeatRequest{ + NodeID: nodeInfo.ID(), + LastOrchestratorSeqNum: updatedOrchSeqNum, + }, + LastComputeSeqNum: updatedComputeSeqNum, + }) + s.Require().NoError(err) + + // Simulate disconnect + s.clock.Add(s.disconnected + time.Second) + s.Eventually(func() bool { + state, err := s.manager.Get(ctx, nodeInfo.ID()) + s.Require().NoError(err) + return state.ConnectionState.Status == models.NodeStates.DISCONNECTED + }, 500*time.Millisecond, 20*time.Millisecond) + + // Reconnect with different sequence number - should keep existing + resp2, err := s.manager.Handshake(ctx, messages.HandshakeRequest{ + NodeInfo: nodeInfo, + LastOrchestratorSeqNum: 300, // Should be ignored for reconnecting nodes + }) + s.Require().NoError(err) + s.Require().True(resp2.Accepted) + s.Assert().Contains(resp2.Reason, "reconnected") + + // Verify sequence numbers were preserved from previous state + state, err = s.manager.Get(ctx, nodeInfo.ID()) + s.Require().NoError(err) + s.Assert().Equal(updatedOrchSeqNum, state.ConnectionState.LastOrchestratorSeqNum, + "Reconnected node should preserve previous orchestrator sequence number") + s.Assert().Equal(updatedComputeSeqNum, state.ConnectionState.LastComputeSeqNum, + "Reconnected node should preserve previous compute sequence number") +} + +func (s *NodeManagerTestSuite) TestHandshakeSequenceNumberEdgeCases() { + ctx := context.Background() + + // Test zero sequence numbers in event store + nodeInfo1 := s.createNodeInfo("zero-seq-node") + resp1, err := s.manager.Handshake(ctx, messages.HandshakeRequest{ + NodeInfo: nodeInfo1, + }) + s.Require().NoError(err) + s.Require().True(resp1.Accepted) + + state1, err := s.manager.Get(ctx, nodeInfo1.ID()) + s.Require().NoError(err) + s.Assert().Equal(uint64(0), state1.ConnectionState.LastOrchestratorSeqNum, + "New node should get zero sequence when event store is empty") + + // Test concurrent handshakes with sequence numbers + var wg sync.WaitGroup + const numConcurrent = 10 + + // Add some events first + for i := 0; i < 5; i++ { + err = s.eventStore.StoreEvent(ctx, watcher.StoreEventRequest{ + Operation: watcher.OperationCreate, + ObjectType: testutils.TypeString, + Object: fmt.Sprintf("test-event-%d", i), + }) + s.Require().NoError(err) + } + + latestSeqNum, err := s.eventStore.GetLatestEventNum(ctx) + s.Require().NoError(err) + + for i := 0; i < numConcurrent; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + nodeInfo := s.createNodeInfo(fmt.Sprintf("concurrent-node-%d", id)) + resp, err := s.manager.Handshake(ctx, messages.HandshakeRequest{ + NodeInfo: nodeInfo, + LastOrchestratorSeqNum: 999, // Should be ignored + }) + s.Require().NoError(err) + s.Require().True(resp.Accepted) + + // Verify assigned sequence number + state, err := s.manager.Get(ctx, nodeInfo.ID()) + s.Require().NoError(err) + s.Assert().Equal(latestSeqNum, state.ConnectionState.LastOrchestratorSeqNum, + "Concurrent new nodes should all get latest sequence number") + }(i) + } + + wg.Wait() +} + func (s *NodeManagerTestSuite) TestHeartbeatWithoutHandshake() { _, err := s.manager.Heartbeat(s.ctx, nodes.ExtendedHeartbeatRequest{ HeartbeatRequest: messages.HeartbeatRequest{ @@ -394,6 +532,7 @@ func (s *NodeManagerTestSuite) TestConcurrentOperations() { manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: clock.New(), // Use real clock for this test NodeDisconnectedAfter: s.disconnected, HealthCheckFrequency: 1 * time.Second, @@ -552,6 +691,7 @@ func (s *NodeManagerTestSuite) TestStartStop() { // Create a new manager without starting it manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, HealthCheckFrequency: 1 * time.Second, @@ -578,6 +718,7 @@ func (s *NodeManagerTestSuite) TestStartAlreadyStarted() { // Create and start a manager manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, }) @@ -603,6 +744,7 @@ func (s *NodeManagerTestSuite) TestStartAlreadyStarted() { func (s *NodeManagerTestSuite) TestStartContextCancellation() { manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, HealthCheckFrequency: 1 * time.Second, @@ -630,6 +772,7 @@ func (s *NodeManagerTestSuite) TestStopAlreadyStopped() { // Create and start a manager manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, }) @@ -657,6 +800,7 @@ func (s *NodeManagerTestSuite) TestPeriodicStatePersistence() { persistInterval := 100 * time.Millisecond manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, PersistInterval: persistInterval, @@ -720,6 +864,7 @@ func (s *NodeManagerTestSuite) TestStatePersistenceOnStop() { // Create manager manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, PersistInterval: time.Hour, // Long interval to ensure persistence happens on stop @@ -764,6 +909,7 @@ func (s *NodeManagerTestSuite) TestPersistenceWithContextCancellation() { // Create manager with short persist interval manager, err := nodes.NewManager(nodes.ManagerParams{ Store: s.store, + EventStore: s.eventStore, Clock: s.clock, NodeDisconnectedAfter: s.disconnected, PersistInterval: 100 * time.Millisecond, diff --git a/pkg/transport/bprotocol/orchestrator/heartbeat_test.go b/pkg/transport/bprotocol/orchestrator/heartbeat_test.go index 56b4868cce..feb397bc28 100644 --- a/pkg/transport/bprotocol/orchestrator/heartbeat_test.go +++ b/pkg/transport/bprotocol/orchestrator/heartbeat_test.go @@ -18,6 +18,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" + "github.com/bacalhau-project/bacalhau/pkg/lib/watcher" "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" "github.com/bacalhau-project/bacalhau/pkg/models/messages/legacy" @@ -41,6 +42,7 @@ type HeartbeatTestSuite struct { messageSerDeRegistry *envelope.Registry heartbeatServer *orchestrator.Server nodeManager nodes.Manager + eventStore watcher.EventStore } func TestHeartbeatTestSuite(t *testing.T) { @@ -53,12 +55,14 @@ func (s *HeartbeatTestSuite) SetupTest() { // Setup NATS server and client s.natsServer, s.natsConn = testutils.StartNats(s.T()) + s.eventStore, _ = testutils.CreateStringEventStore(s.T()) // Setup real node manager s.nodeManager, err = nodes.NewManager(nodes.ManagerParams{ Clock: s.clock, Store: inmemory.NewNodeStore(inmemory.NodeStoreParams{TTL: 1 * time.Hour}), NodeDisconnectedAfter: 5 * time.Second, + EventStore: s.eventStore, }) s.Require().NoError(err) s.Require().NoError(s.nodeManager.Start(context.Background())) @@ -100,6 +104,9 @@ func (s *HeartbeatTestSuite) TearDownTest() { if s.nodeManager != nil { s.nodeManager.Stop(context.Background()) } + if s.eventStore != nil { + s.eventStore.Close(context.Background()) + } } func (s *HeartbeatTestSuite) TestUpdateNodeInfo() { diff --git a/pkg/transport/nclprotocol/orchestrator/dataplane.go b/pkg/transport/nclprotocol/orchestrator/dataplane.go index 07d35340ad..d254646826 100644 --- a/pkg/transport/nclprotocol/orchestrator/dataplane.go +++ b/pkg/transport/nclprotocol/orchestrator/dataplane.go @@ -171,6 +171,7 @@ func (dp *DataPlane) setupDispatcher(ctx context.Context) error { dispatcherWatcher, err := watcher.New(ctx, fmt.Sprintf("orchestrator-dispatcher-%s", dp.config.NodeID), dp.config.EventStore, + watcher.WithEphemeral(), watcher.WithRetryStrategy(watcher.RetryStrategyBlock), watcher.WithInitialEventIterator(watcher.AfterSequenceNumberIterator(dp.config.StartSeqNum)), watcher.WithFilter(watcher.EventFilter{