Skip to content

Commit

Permalink
better exchange of starting seqNum during handshakes (#4766)
Browse files Browse the repository at this point in the history
## Problem
When nodes restart or lose state, the current sequence number
synchronization can lead to message gaps or duplicates. This occurs
because nodes unconditionally trust each other's sequence numbers during
handshake, without considering local state recovery scenarios.

## Solution
This PR implements a "trust your own state" approach for sequence number
synchronization during handshakes. Each node relies on its local state
to determine its starting point, while using the handshake to inform the
other party of its position.

### Changes in Handshake Flow
- **Orchestrator Behavior**
  - Uses its local knowledge of the compute node's last received sequence
    number
  - Starts streaming from 0 if no prior state exists, regardless of
    compute node's reported position
  - Tracks compute node's progress through heartbeat updates

- **Compute Node Behavior**
  - Starts from its local checkpoint, preserved across restarts
  - Ignores orchestrator's suggested sequence position if local state
    exists
  - Continues reporting processed sequence numbers via heartbeats


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced new methods for managing node sequence numbers and state
    handling.
  - Added functionality to ensure message publishing starts correctly
    after node restarts.
  - Enhanced dispatcher state management with a new structured format.

- **Bug Fixes**
  - Improved error handling for sequence number resolution and
    checkpointing processes.

- **Tests**
  - Added new tests to verify the behavior of the data plane and
    dispatcher during various scenarios.

- **Documentation**
  - Updated comments and documentation to reflect changes in methods and
    logic.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
wdbaruni authored Dec 16, 2024
1 parent d71b758 commit 86e8a96
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 52 deletions.
6 changes: 3 additions & 3 deletions pkg/node/requester.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Requester struct {
Endpoint *orchestrator.BaseEndpoint
JobStore jobstore.Store
// We need a reference to the node info store until libp2p is removed
NodeInfoStore nodes.Store
NodeInfoStore nodes.Lookup
cleanupFunc func(ctx context.Context)
debugInfoProviders []models.DebugInfoProvider
}
Expand All @@ -77,7 +77,7 @@ func NewRequesterNode(
}

nodeID := cfg.NodeID
nodesManager, nodeStore, err := createNodeManager(ctx, cfg, jobStore.GetEventStore(), nodeInfoProvider, natsConn)
nodesManager, _, err := createNodeManager(ctx, cfg, jobStore.GetEventStore(), nodeInfoProvider, natsConn)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -327,7 +327,7 @@ func NewRequesterNode(

return &Requester{
Endpoint: endpointV2,
NodeInfoStore: nodeStore,
NodeInfoStore: nodesManager,
JobStore: jobStore,
cleanupFunc: cleanupFunc,
debugInfoProviders: debugInfoProviders,
Expand Down
77 changes: 55 additions & 22 deletions pkg/orchestrator/nodes/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func (n *nodesManager) Handshake(
var existing models.NodeState

// Check if node is already registered, and if so, if it was rejected
existing, err := n.store.Get(ctx, request.NodeInfo.ID())
existing, err := n.Get(ctx, request.NodeInfo.ID())
if err == nil {
if existing.Membership == models.NodeMembership.REJECTED {
return messages.HandshakeResponse{
Expand Down Expand Up @@ -459,23 +459,15 @@ func (n *nodesManager) Handshake(
},
}

// If a node is reconnecting, we trust and preserve the sequence numbers from its previous state,
// rather than using the sequence numbers from the handshake request. For new nodes,
// we assign them the latest event sequence number from the event store.
// This prevents several edge cases:
// - Compute node losing its state. The handshake will ask to start from 0.
// - Orchestrator losing their state and compute nodes asking to start from a later seqNum that no longer exist.
// - New compute nodes joining. The handshake will also ask to start from 0.
if isReconnect {
state.Membership = existing.Membership
state.ConnectionState.LastComputeSeqNum = existing.ConnectionState.LastComputeSeqNum
state.ConnectionState.LastOrchestratorSeqNum = existing.ConnectionState.LastOrchestratorSeqNum
} else {
// Assign the latest sequence number from the event store
state.ConnectionState.LastOrchestratorSeqNum, err = n.eventstore.GetLatestEventNum(ctx)
if err != nil {
return messages.HandshakeResponse{}, fmt.Errorf("failed to initialize node with latest event number: %w", err)
}
}

// Resolve where the node should start receiving messages from
state.ConnectionState.LastOrchestratorSeqNum, err = n.resolveStartingOrchestratorSeqNum(ctx, isReconnect, existing)
if err != nil {
return messages.HandshakeResponse{}, fmt.Errorf("failed to resolve starting sequence number: %w", err)
}

if err = n.store.Put(ctx, state); err != nil {
Expand Down Expand Up @@ -572,12 +564,7 @@ func (n *nodesManager) Heartbeat(
// updated connection state
updated := existing.connectionState
updated.LastHeartbeat = n.clock.Now().UTC()
if request.LastOrchestratorSeqNum > 0 {
updated.LastOrchestratorSeqNum = request.LastOrchestratorSeqNum
}
if request.LastComputeSeqNum > 0 {
updated.LastComputeSeqNum = request.LastComputeSeqNum
}
n.updateSequenceNumbers(&updated, request.LastOrchestratorSeqNum, request.LastComputeSeqNum)

// Store updated state back if no concurrent modification
if !n.liveState.CompareAndSwap(request.NodeID, existing, &trackedLiveState{
Expand Down Expand Up @@ -764,6 +751,52 @@ func (n *nodesManager) List(ctx context.Context, filters ...NodeStateFilter) ([]
return states, nil
}

// resolveStartingOrchestratorSeqNum picks the orchestrator sequence number that is
// recorded for a node during a handshake, i.e. the position its message stream
// starts from.
//
//   - Reconnecting node (isReconnect is true): the value already held in existing is
//     returned unchanged, so the orchestrator's stored record wins over whatever the
//     compute node reports. This guards against a compute node restarting with the
//     same ID but fresh state, where it would ask to start from 0.
//   - Any other node: the latest event number from the event store is returned, so
//     the node is not overwhelmed with historical events.
//
// A non-nil error is returned only when the event store lookup fails; the returned
// sequence number is 0 in that case.
//
// TODO: Add support for snapshots to allow nodes to catch up on missed state without
// replaying all historical events.
func (n *nodesManager) resolveStartingOrchestratorSeqNum(
	ctx context.Context, isReconnect bool, existing models.NodeState) (uint64, error) {
	if !isReconnect {
		// No usable stored position: begin at the event store's latest event number.
		seqNum, err := n.eventstore.GetLatestEventNum(ctx)
		if err != nil {
			return 0, fmt.Errorf("failed to get latest event number: %w", err)
		}
		return seqNum, nil
	}

	// Reconnect: trust the sequence number we stored for this node.
	return existing.ConnectionState.LastOrchestratorSeqNum, nil
}

// updateSequenceNumbers overwrites both message-tracking sequence numbers on state
// with the supplied values.
//   - orchestratorSeq is stored as LastOrchestratorSeqNum: what the compute node says it
//     has processed from the orchestrator.
//   - computeSeq is stored as LastComputeSeqNum: what the orchestrator has processed from
//     this node. This is populated locally by the orchestrator's data plane.
//
// Both values are written as given, without comparing them to what state already holds.
// A smaller (or zero) input therefore moves the stored sequence number backwards.
//
// TODO: Add smarter logic that compares current state with the observed values so that
// sequence numbers only ever advance forward.
func (n *nodesManager) updateSequenceNumbers(state *models.ConnectionState, orchestratorSeq, computeSeq uint64) {
	state.LastOrchestratorSeqNum, state.LastComputeSeqNum = orchestratorSeq, computeSeq
}

// enrichState adds live tracking data to a node state.
// For connected nodes, it adds:
// - Current connection state
Expand Down Expand Up @@ -800,7 +833,7 @@ func (n *nodesManager) selfRegister(ctx context.Context) error {
nodeInfo := n.nodeInfoProvider.GetNodeInfo(ctx)

// get node info from the store if it exists
state, err := n.store.Get(ctx, nodeInfo.ID())
state, err := n.Get(ctx, nodeInfo.ID())
if err != nil {
if !bacerrors.IsErrorWithCode(err, bacerrors.NotFoundError) {
return bacerrors.New("failed to self-register node: %v", err).
Expand Down
3 changes: 3 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ func (cp *ControlPlane) Stop(ctx context.Context) error {

select {
case <-done:
if err := cp.checkpointProgress(ctx); err != nil {
log.Error().Err(err).Msg("Failed to checkpoint progress before stopping")
}
return nil
case <-ctx.Done():
return ctx.Err()
Expand Down
21 changes: 20 additions & 1 deletion pkg/transport/nclprotocol/compute/dataplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (dp *DataPlane) Start(ctx context.Context) error {
var dispatcherWatcher watcher.Watcher
dispatcherWatcher, err = watcher.New(ctx, watcherID, dp.config.EventStore,
watcher.WithRetryStrategy(watcher.RetryStrategyBlock),
watcher.WithInitialEventIterator(watcher.AfterSequenceNumberIterator(dp.lastReceivedSeqNum)),
watcher.WithInitialEventIterator(dp.resolveStartingIterator(dp.lastReceivedSeqNum)),
watcher.WithFilter(watcher.EventFilter{
ObjectTypes: []string{compute.EventObjectExecutionUpsert},
}),
Expand Down Expand Up @@ -174,6 +174,25 @@ func (dp *DataPlane) IsRunning() bool {
return dp.running
}

// resolveStartingIterator returns the event iterator that message publishing
// starts from when there is no checkpoint to resume from.
//
// It currently always returns watcher.TrimHorizonIterator() (start from the
// beginning), even if the orchestrator provides lastReceivedSeqNum. This ensures
// no messages are lost when a compute node restarts with the same ID but fresh
// state.
//
// Note that this is only used when starting fresh - if there is an existing
// checkpoint, the watcher is expected to resume from the last checkpointed
// position instead.
//
// Parameters:
//   - lastReceivedSeqNum: the orchestrator's last known position. It is not read
//     by this function today; it is kept in the signature for future use cases
//     where replay may be optimized by honoring the orchestrator's position when
//     starting fresh.
//
// Returns the iterator to pass to the watcher as its initial position.
func (dp *DataPlane) resolveStartingIterator(lastReceivedSeqNum uint64) watcher.EventIterator {
// lastReceivedSeqNum is deliberately ignored: always replay from the beginning.
return watcher.TrimHorizonIterator()
}

// cleanup handles the orderly shutdown of data plane components.
// It ensures resources are released in the correct order and collects any errors.
func (dp *DataPlane) cleanup(ctx context.Context) error {
Expand Down
56 changes: 54 additions & 2 deletions pkg/transport/nclprotocol/compute/dataplane_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ func (s *DataPlaneTestSuite) TestStartupFailureCleanup() {
s.cancel()
select {
case <-s.ctx.Done():
// Context cancellation has propagated
// Context cancellation has propagated
case <-time.After(100 * time.Millisecond):
s.Require().Fail("Timeout waiting for context cancellation")
s.Require().Fail("Timeout waiting for context cancellation")
}
},
expectError: "context canceled",
Expand Down Expand Up @@ -245,3 +245,55 @@ func (s *DataPlaneTestSuite) TestMessageHandling() {
})
}
}

// TestStartingPosition checks that a data plane created with a non-zero
// LastReceivedSeqNum still publishes the events that were stored before it
// started, i.e. that the supplied sequence number does not skip them.
func (s *DataPlaneTestSuite) TestStartingPosition() {
	// Seed the event store before the data plane exists.
	seeded := []models.ExecutionUpsert{
		{Current: &models.Execution{ID: "test-job-1", NodeID: "test-node"}},
		{Current: &models.Execution{ID: "test-job-2", NodeID: "test-node"}},
	}
	for i := range seeded {
		storeErr := s.config.EventStore.StoreEvent(s.ctx, watcher.StoreEventRequest{
			Operation:  watcher.OperationCreate,
			ObjectType: compute.EventObjectExecutionUpsert,
			Object:     seeded[i],
		})
		s.Require().NoError(storeErr)
	}

	// Build the data plane with a LastReceivedSeqNum well past the seeded events.
	params := nclprotocolcompute.DataPlaneParams{
		Config:             s.config,
		Client:             s.natsConn,
		LastReceivedSeqNum: 100, // expected to have no effect on where publishing starts
	}
	dp, err := nclprotocolcompute.NewDataPlane(params)
	s.Require().NoError(err)
	s.Require().NoError(dp.Start(s.ctx))
	defer dp.Stop(context.Background())

	// Every seeded event must still arrive, all within one shared deadline.
	deadline := time.After(time.Second)
	for got, want := 0, len(seeded); got < want; {
		select {
		case msg := <-s.msgChan:
			s.Require().Equal(messages.BidResultMessageType, msg.Metadata.Get(envelope.KeyMessageType))
			got++
		case <-deadline:
			// Require().Failf aborts the test, so the loop does not continue past here.
			s.Require().Failf("Timeout waiting for messages",
				"Only received %d of %d expected messages",
				got, want)
		}
	}
}
5 changes: 5 additions & 0 deletions pkg/transport/nclprotocol/compute/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ func (cm *ConnectionManager) performHandshake(
return messages.HandshakeResponse{}, fmt.Errorf(
"handshake rejected by orchestrator due to %s", handshakeResponse.Reason)
}

// Always trust the orchestrator's starting sequence number as it may have been reset
// or decided to start from a different point
cm.incomingSeqTracker.UpdateLastSeqNum(handshakeResponse.StartingOrchestratorSeqNum)

return handshakeResponse, nil
}

Expand Down
25 changes: 15 additions & 10 deletions pkg/transport/nclprotocol/compute/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,16 @@ func (s *ConnectionManagerTestSuite) TearDownTest() {
}

func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() {
// Setup initial checkpoint
lastOrchestratorSeqNum := uint64(124)
s.checkpointer.SetCheckpoint("incoming-test-node", lastOrchestratorSeqNum)
// Setup initial checkpoint with one sequence number
initialSeqNum := uint64(124)
s.checkpointer.SetCheckpoint("incoming-test-node", initialSeqNum)

// Configure handshake response to return a different sequence number
handshakeSeqNum := uint64(100)
s.mockResponder.Behaviour().HandshakeResponse.Response = messages.HandshakeResponse{
Accepted: true,
StartingOrchestratorSeqNum: handshakeSeqNum,
}

err := s.manager.Start(s.ctx)
s.Require().NoError(err)
Expand All @@ -122,13 +129,12 @@ func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() {
return len(s.mockResponder.GetHandshakes()) > 0
}, time.Second, 10*time.Millisecond, "handshake not received")

// Verify handshake request
// verify only one handshake, and verify the request
handshakes := s.mockResponder.GetHandshakes()
s.Require().Len(handshakes, 1)
s.Require().Equal(s.config.NodeID, handshakes[0].NodeInfo.ID())
s.Require().Equal(lastOrchestratorSeqNum, handshakes[0].LastOrchestratorSeqNum)
s.Require().Equal(initialSeqNum, handshakes[0].LastOrchestratorSeqNum)

// Verify connection established
s.Require().Eventually(func() bool {
health := s.manager.GetHealth()
return health.CurrentState == nclprotocol.Connected
Expand All @@ -146,15 +152,15 @@ func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() {
return len(s.mockResponder.GetHeartbeats()) > 0
}, time.Second, 10*time.Millisecond, "manager did not send heartbeats")

// Verify heartbeat content
// verify heartbeat content
nodeInfo := s.nodeInfoProvider.GetNodeInfo(s.ctx)
heartbeats := s.mockResponder.GetHeartbeats()
s.Require().Len(heartbeats, 1)
s.Require().Equal(messages.HeartbeatRequest{
NodeID: nodeInfo.NodeID,
AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity,
QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity,
LastOrchestratorSeqNum: lastOrchestratorSeqNum,
LastOrchestratorSeqNum: handshakeSeqNum, // Should use sequence number from handshake response
}, heartbeats[0])

// verify state
Expand All @@ -173,10 +179,9 @@ func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() {
NodeID: nodeInfo.NodeID,
AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity,
QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity,
LastOrchestratorSeqNum: lastOrchestratorSeqNum,
LastOrchestratorSeqNum: handshakeSeqNum, // Should continue using sequence number from handshake
})
}, time.Second, 10*time.Millisecond, "manager did not send heartbeats")

}

func (s *ConnectionManagerTestSuite) TestRejectedHandshake() {
Expand Down
37 changes: 26 additions & 11 deletions pkg/transport/nclprotocol/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func New(publisher ncl.OrderedPublisher,
return d, nil
}

// State returns the dispatcher's current State, as reported by its internal
// state tracker's GetState method.
func (d *Dispatcher) State() State {
return d.state.GetState()
}

// Start begins processing events and managing async publish results.
// It launches background goroutines for processing publish results,
// checking for stalled messages, and checkpointing progress.
Expand Down Expand Up @@ -226,19 +230,30 @@ func (d *Dispatcher) checkpointLoop(ctx context.Context) {
case <-ctx.Done():
return
case <-d.stopCh:
d.doCheckpoint(ctx) // Final checkpoint on shutdown
return
case <-ticker.C:
checkpointTarget := d.state.getCheckpointSeqNum()
// Only checkpoint if we have something new to save
if checkpointTarget > 0 {
checkpointCtx, cancel := context.WithTimeout(ctx, d.config.CheckpointTimeout)
if err := d.watcher.Checkpoint(checkpointCtx, checkpointTarget); err != nil {
log.Error().Err(err).Msg("Failed to checkpoint watcher")
} else {
d.state.updateLastCheckpoint(checkpointTarget)
}
cancel()
}
d.doCheckpoint(ctx) // Periodic checkpoint
}
}
}

// doCheckpoint performs a single checkpoint attempt.
//
// It asks the dispatcher state for the sequence number to checkpoint; a value of 0
// means there is nothing new and the call is a no-op. Otherwise the watcher is
// checkpointed at that sequence number under a context derived from ctx and bounded
// by the configured CheckpointTimeout. On success the state's last checkpoint is
// advanced to the same value; on failure the error is logged and the state is left
// untouched.
func (d *Dispatcher) doCheckpoint(ctx context.Context) {
	seqNum := d.state.getCheckpointSeqNum()
	if seqNum == 0 {
		// Nothing new to checkpoint.
		return
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, d.config.CheckpointTimeout)
	defer cancel()

	err := d.watcher.Checkpoint(timeoutCtx, seqNum)
	if err == nil {
		d.state.updateLastCheckpoint(seqNum)
		return
	}

	log.Error().Err(err).
		Uint64("target", seqNum).
		Msg("Failed to checkpoint watcher")
}
Loading

0 comments on commit 86e8a96

Please sign in to comment.