From e42172a53fd3e30cd4200d18c0d35a5d604f7a0f Mon Sep 17 00:00:00 2001 From: Ian Davis <18375+iand@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:41:19 +0100 Subject: [PATCH 1/3] fix: avoid deadlocks in query behaviour --- internal/coord/behaviour.go | 35 ++++ internal/coord/query.go | 32 +++- internal/coord/query_test.go | 298 +++++++++++++++++++++++++++++++++++ internal/kadtest/chan.go | 36 +++++ routing_test.go | 86 ++++------ 5 files changed, 427 insertions(+), 60 deletions(-) create mode 100644 internal/kadtest/chan.go diff --git a/internal/coord/behaviour.go b/internal/coord/behaviour.go index 7460994..21de6b0 100644 --- a/internal/coord/behaviour.go +++ b/internal/coord/behaviour.go @@ -146,3 +146,38 @@ func (w *Waiter[E]) Close() { func (w *Waiter[E]) Chan() <-chan WaiterEvent[E] { return w.pending } + +// NotifyCloserHook implements the [NotifyCloser] interface and provides hooks +// into the Notify and Close calls by wrapping another [NotifyCloser]. This is +// intended to be used in testing. 
+type NotifyCloserHook[E BehaviourEvent] struct { + nc NotifyCloser[E] + BeforeNotify func(context.Context, E) + AfterNotify func(context.Context, E) + BeforeClose func() + AfterClose func() +} + +var _ NotifyCloser[BehaviourEvent] = (*NotifyCloserHook[BehaviourEvent])(nil) + +func NewNotifyCloserHook[E BehaviourEvent](nc NotifyCloser[E]) *NotifyCloserHook[E] { + return &NotifyCloserHook[E]{ + nc: nc, + BeforeNotify: func(ctx context.Context, e E) {}, + AfterNotify: func(ctx context.Context, e E) {}, + BeforeClose: func() {}, + AfterClose: func() {}, + } +} + +func (n *NotifyCloserHook[E]) Notify(ctx context.Context, ev E) { + n.BeforeNotify(ctx, ev) + n.nc.Notify(ctx, ev) + n.AfterNotify(ctx, ev) +} + +func (n *NotifyCloserHook[E]) Close() { + n.BeforeClose() + n.nc.Close() + n.AfterClose() +} diff --git a/internal/coord/query.go b/internal/coord/query.go index f0678f0..32269a9 100644 --- a/internal/coord/query.go +++ b/internal/coord/query.go @@ -107,16 +107,29 @@ func DefaultPooledQueryConfig() *PooledQueryConfig { } } +// PooledQueryBehaviour holds the behaviour and state for managing a pool of queries. type PooledQueryBehaviour struct { - cfg PooledQueryConfig - pool *query.Pool[kadt.Key, kadt.PeerID, *pb.Message] + // cfg is a copy of the optional configuration supplied to the behaviour. + cfg PooledQueryConfig + + // pool is the query pool state machine used for managing individual queries. + pool *query.Pool[kadt.Key, kadt.PeerID, *pb.Message] + + // waiters is a map that keeps track of event notifications for each running query. waiters map[coordt.QueryID]NotifyCloser[BehaviourEvent] + // pendingMu guards access to pending pendingMu sync.Mutex - pending []BehaviourEvent - ready chan struct{} + + // pending is a queue of pending events that need to be processed. + pending []BehaviourEvent + + // ready is a channel signaling that events are ready to be processed. 
+ ready chan struct{} } +// NewPooledQueryBehaviour initialises a new PooledQueryBehaviour, setting up the query +// pool and other internal state. func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQueryBehaviour, error) { if cfg == nil { cfg = DefaultPooledQueryConfig() @@ -145,6 +158,9 @@ func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQ return h, err } +// Notify receives a behaviour event and takes appropriate actions such as starting, +// stopping, or updating queries. It also queues events for later processing and +// triggers the advancement of the query pool if applicable. func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Notify") defer span.End() @@ -259,10 +275,15 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { } } +// Ready returns a channel that signals when the pooled query behaviour is ready to +// perform work. func (p *PooledQueryBehaviour) Ready() <-chan struct{} { return p.ready } +// Perform executes the next available task from the queue of pending events or advances +// the query pool. Returns the executed event and a boolean indicating whether work was +// performed. func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Perform") defer span.End() @@ -298,6 +319,9 @@ func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, boo } } +// advancePool advances the query pool state machine and returns an outbound event if +// there is work to be performed. Also notifies waiters of query completion or +// progress. 
func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEvent) (out BehaviourEvent, term bool) { ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.advancePool", trace.WithAttributes(tele.AttrInEvent(ev))) defer func() { diff --git a/internal/coord/query_test.go b/internal/coord/query_test.go index 74222a7..354f40c 100644 --- a/internal/coord/query_test.go +++ b/internal/coord/query_test.go @@ -1,9 +1,19 @@ package coord import ( + "context" + "sync" "testing" + "github.com/benbjohnson/clock" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/plprobelab/zikade/internal/coord/coordt" + "github.com/plprobelab/zikade/internal/kadtest" + "github.com/plprobelab/zikade/internal/nettest" + "github.com/plprobelab/zikade/kadt" + "github.com/plprobelab/zikade/pb" ) func TestPooledQueryConfigValidate(t *testing.T) { @@ -68,3 +78,291 @@ func TestPooledQueryConfigValidate(t *testing.T) { require.Error(t, cfg.Validate()) }) } + +func TestQueryBehaviourBase(t *testing.T) { + suite.Run(t, new(QueryBehaviourBaseTestSuite)) +} + +type QueryBehaviourBaseTestSuite struct { + suite.Suite + + cfg *PooledQueryConfig + top *nettest.Topology + nodes []*nettest.Peer +} + +func (ts *QueryBehaviourBaseTestSuite) SetupTest() { + clk := clock.NewMock() + top, nodes, err := nettest.LinearTopology(4, clk) + ts.Require().NoError(err) + + ts.top = top + ts.nodes = nodes + + ts.cfg = DefaultPooledQueryConfig() + ts.cfg.Clock = clk +} + +func (ts *QueryBehaviourBaseTestSuite) TestNotifiesNoProgress() { + t := ts.T() + ctx := kadtest.CtxShort(t) + + target := ts.nodes[3].NodeID.Key() + rt := ts.nodes[0].RoutingTable + seeds := rt.NearestNodes(target, 5) + + b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) + ts.Require().NoError(err) + + waiter := NewWaiter[BehaviourEvent]() + cmd := &EventStartFindCloserQuery{ + QueryID: "test", + Target: target, + KnownClosestNodes: seeds, + Notify: waiter, + NumResults: 10, + } + + // 
queue the start of the query + b.Notify(ctx, cmd) + + // behaviour should emit EventOutboundGetCloserNodes to start the query + bev, ok := b.Perform(ctx) + ts.Require().True(ok) + ts.Require().IsType(&EventOutboundGetCloserNodes{}, bev) + + egc := bev.(*EventOutboundGetCloserNodes) + ts.Require().True(egc.To.Equal(ts.nodes[1].NodeID)) + + // notify failure + b.Notify(ctx, &EventGetCloserNodesFailure{ + QueryID: "test", + To: egc.To, + Target: target, + }) + + // ensure that the waiter received query finished event + wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) + ts.Require().IsType(&EventQueryFinished{}, wev.Event) +} + +func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryProgressed() { + t := ts.T() + ctx := kadtest.CtxShort(t) + + target := ts.nodes[3].NodeID.Key() + rt := ts.nodes[0].RoutingTable + seeds := rt.NearestNodes(target, 5) + + b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) + ts.Require().NoError(err) + + waiter := NewWaiter[BehaviourEvent]() + cmd := &EventStartFindCloserQuery{ + QueryID: "test", + Target: target, + KnownClosestNodes: seeds, + Notify: waiter, + NumResults: 10, + } + + // queue the start of the query + b.Notify(ctx, cmd) + + // behaviour should emit EventOutboundGetCloserNodes to start the query + bev, ok := b.Perform(ctx) + ts.Require().True(ok) + ts.Require().IsType(&EventOutboundGetCloserNodes{}, bev) + + egc := bev.(*EventOutboundGetCloserNodes) + ts.Require().True(egc.To.Equal(ts.nodes[1].NodeID)) + + // notify success + b.Notify(ctx, &EventGetCloserNodesSuccess{ + QueryID: "test", + To: egc.To, + Target: target, + CloserNodes: ts.nodes[1].RoutingTable.NearestNodes(target, 5), + }) + + // ensure that the waiter received query progressed event + wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) + ts.Require().IsType(&EventQueryProgressed{}, wev.Event) +} + +func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() { + t := ts.T() + ctx := 
kadtest.CtxShort(t) + + target := ts.nodes[3].NodeID.Key() + rt := ts.nodes[0].RoutingTable + seeds := rt.NearestNodes(target, 5) + + b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) + ts.Require().NoError(err) + + waiter := NewWaiter[BehaviourEvent]() + cmd := &EventStartFindCloserQuery{ + QueryID: "test", + Target: target, + KnownClosestNodes: seeds, + Notify: waiter, + NumResults: 10, + } + + // queue the start of the query + b.Notify(ctx, cmd) + + // behaviour should emit EventOutboundGetCloserNodes to start the query + bev, ok := b.Perform(ctx) + ts.Require().True(ok) + ts.Require().IsType(&EventOutboundGetCloserNodes{}, bev) + + egc := bev.(*EventOutboundGetCloserNodes) + ts.Require().True(egc.To.Equal(ts.nodes[1].NodeID)) + + // notify success + b.Notify(ctx, &EventGetCloserNodesSuccess{ + QueryID: "test", + To: egc.To, + Target: target, + CloserNodes: ts.nodes[1].RoutingTable.NearestNodes(target, 5), + }) + + // ensure that the waiter received query progressed event + wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) + ts.Require().IsType(&EventQueryProgressed{}, wev.Event) + + // skip events until next EventOutboundGetCloserNodes is reached + for { + bev, ok = b.Perform(ctx) + ts.Require().True(ok) + + egc, ok = bev.(*EventOutboundGetCloserNodes) + if ok { + break + } + } + + ts.Require().True(egc.To.Equal(ts.nodes[2].NodeID)) + // notify success but no further nodes + b.Notify(ctx, &EventGetCloserNodesSuccess{ + QueryID: "test", + To: egc.To, + Target: target, + }) + + // ensure that the waiter received query progressed event + wev = kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) + ts.Require().IsType(&EventQueryProgressed{}, wev.Event) +} + +func TestPooledQuery_deadlock_regression(t *testing.T) { + t.Skip() + ctx := kadtest.CtxShort(t) + msg := &pb.Message{} + queryID := coordt.QueryID("test") + + _, nodes, err := nettest.LinearTopology(3, clock.New()) + require.NoError(t, err) + + // it would 
be better to just work with the queryBehaviour in this test. + // However, we want to test as many parts as possible and waitForQuery + // is defined on the coordinator. Therefore, we instantiate a coordinator + // and close it immediately to manually control state machine progression. + c, err := NewCoordinator(nodes[0].NodeID, nodes[0].Router, nodes[0].RoutingTable, nil) + require.NoError(t, err) + require.NoError(t, c.Close()) // close immediately so that we control the state machine progression + + // define a function that produces success messages + successMsg := func(to kadt.PeerID, closer ...kadt.PeerID) *EventSendMessageSuccess { + return &EventSendMessageSuccess{ + QueryID: queryID, + Request: msg, + To: to, + Response: nil, + CloserNodes: closer, + } + } + + // start query + waiter := NewWaiter[BehaviourEvent]() + wrappedWaiter := NewNotifyCloserHook[BehaviourEvent](waiter) + + waiterDone := make(chan struct{}) + waiterMsg := make(chan struct{}) + go func() { + defer close(waiterDone) + defer close(waiterMsg) + _, _, err = c.waitForQuery(ctx, queryID, waiter, func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats coordt.QueryStats) error { + waiterMsg <- struct{}{} + return coordt.ErrSkipRemaining + }) + }() + + // start the message query + c.queryBehaviour.Notify(ctx, &EventStartMessageQuery{ + QueryID: queryID, + Target: msg.Target(), + Message: msg, + KnownClosestNodes: []kadt.PeerID{nodes[1].NodeID}, + Notify: wrappedWaiter, + NumResults: 0, + }) + + // advance state machines and assert that the state machine + // wants to send an outbound message to another peer + ev, _ := c.queryBehaviour.Perform(ctx) + require.IsType(t, &EventOutboundSendMessage{}, ev) + + // simulate a successful response from another node that returns one new node + // This should result in a message for the waiter + c.queryBehaviour.Notify(ctx, successMsg(nodes[1].NodeID, nodes[2].NodeID)) + + // Because we're blocking on the waiterMsg channel in the waitForQuery + //
method above, we simulate a slow receiving waiter. + + // Advance the query pool state machine. Because we returned a new node + // above, the query pool state machine wants to send another outbound query + ev, _ = c.queryBehaviour.Perform(ctx) + require.IsType(t, &EventAddNode{}, ev) // event to notify the routing table + ev, _ = c.queryBehaviour.Perform(ctx) + require.IsType(t, &EventOutboundSendMessage{}, ev) + + hasLock := make(chan struct{}) + var once sync.Once + wrappedWaiter.BeforeNotify = func(ctx context.Context, event BehaviourEvent) { + once.Do(func() { + require.IsType(t, &EventQueryProgressed{}, event) // verify test invariant + close(hasLock) + }) + } + + // Simulate a successful response from the new node. This node didn't return + // any new nodes to contact. This means the query pool behaviour will notify + // the waiter about a query progression and afterward about a finished + // query. Because (at the time of writing) the waiter has a channel buffer + // of 1, the channel cannot hold both events. At the same time, the waiter + // doesn't consume the messages because it's busy processing the previous + // query event (because we haven't released the blocking waiterMsg call above). + go c.queryBehaviour.Notify(ctx, successMsg(nodes[2].NodeID)) + + // wait until the above Notify call was handled by waiting until the hasLock + // channel was closed in the above BeforeNotify hook. If that hook is called + // we can be sure that the above Notify call has acquired the pooled query + // behaviour's pendingMu lock. + kadtest.AssertClosed(t, ctx, hasLock) + + // Since we know that the pooled query behaviour holds the lock we can + // release the slow waiter by reading an item from the waiterMsg channel. + kadtest.ReadItem(t, ctx, waiterMsg) + + // At this point, the waitForQuery QueryFunc callback returned a + // coordt.ErrSkipRemaining. This instructs the waitForQuery method to notify + // the query behaviour with an EventStopQuery event.
However, because the + // query behaviour is busy sending a message to the waiter it is holding the + // lock on the pending events to process. Therefore, this notify call will + // also block. At the same time, the waiter cannot read the new messages + // from the query behaviour because it tries to notify it. + kadtest.AssertClosed(t, ctx, waiterDone) +} diff --git a/internal/kadtest/chan.go b/internal/kadtest/chan.go new file mode 100644 index 0000000..e4f030e --- /dev/null +++ b/internal/kadtest/chan.go @@ -0,0 +1,36 @@ +package kadtest + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" +) + +func ReadItem[T any](t testing.TB, ctx context.Context, c <-chan T) T { + t.Helper() + + select { + case val, more := <-c: + require.True(t, more, "channel closed unexpectedly") + return val + case <-ctx.Done(): + t.Fatal("timeout reading item") + return *new(T) + } +} + +// AssertClosed triggers a test failure if the given channel was not closed but +// carried more values or a timeout occurs (given by the context). 
+func AssertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) { + t.Helper() + + select { + case _, more := <-c: + assert.False(t, more) + case <-ctx.Done(): + t.Fatal("timeout closing channel") + } +} diff --git a/routing_test.go b/routing_test.go index e6807e9..d261f22 100644 --- a/routing_test.go +++ b/routing_test.go @@ -289,7 +289,7 @@ func TestDHT_FindProvidersAsync_empty_routing_table(t *testing.T) { c := newRandomContent(t) out := d.FindProvidersAsync(ctx, c, 1) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) { @@ -300,7 +300,7 @@ func TestDHT_FindProvidersAsync_dht_does_not_support_providers(t *testing.T) { delete(d.backends, namespaceProviders) out := d.FindProvidersAsync(ctx, newRandomContent(t), 1) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) { @@ -315,10 +315,10 @@ func TestDHT_FindProvidersAsync_providers_stored_locally(t *testing.T) { out := d.FindProvidersAsync(ctx, c, 1) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, provider.ID, val.ID) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_returns_only_count_from_local_store(t *testing.T) { @@ -376,10 +376,10 @@ func TestDHT_FindProvidersAsync_queries_other_peers(t *testing.T) { out := d1.FindProvidersAsync(ctx, c, 1) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, provider.ID, val.ID) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_FindProvidersAsync_respects_cancelled_context_for_local_query(t *testing.T) { @@ -493,7 +493,7 @@ func TestDHT_FindProvidersAsync_datastore_error(t *testing.T) { be.datastore = dstore out := d.FindProvidersAsync(ctx, newRandomContent(t), 0) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func 
TestDHT_FindProvidersAsync_invalid_key(t *testing.T) { @@ -501,7 +501,7 @@ func TestDHT_FindProvidersAsync_invalid_key(t *testing.T) { d := newTestDHT(t) out := d.FindProvidersAsync(ctx, cid.Cid{}, 0) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_GetValue_happy_path(t *testing.T) { @@ -563,32 +563,6 @@ func TestDHT_GetValue_returns_not_found_error(t *testing.T) { assert.Nil(t, valueChan) } -// assertClosed triggers a test failure if the given channel was not closed but -// carried more values or a timeout occurs (given by the context). -func assertClosed[T any](t testing.TB, ctx context.Context, c <-chan T) { - t.Helper() - - select { - case _, more := <-c: - assert.False(t, more) - case <-ctx.Done(): - t.Fatal("timeout closing channel") - } -} - -func readItem[T any](t testing.TB, ctx context.Context, c <-chan T) T { - t.Helper() - - select { - case val, more := <-c: - require.True(t, more, "channel closed unexpectedly") - return val - case <-ctx.Done(): - t.Fatal("timeout reading item") - return *new(T) - } -} - func TestDHT_SearchValue_simple(t *testing.T) { // Test setup: // There is just one other server that returns a valid value. 
@@ -608,10 +582,10 @@ func TestDHT_SearchValue_simple(t *testing.T) { valChan, err := d1.SearchValue(ctx, key) require.NoError(t, err) - val := readItem(t, ctx, valChan) + val := kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, v, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } func TestDHT_SearchValue_returns_best_values(t *testing.T) { @@ -657,13 +631,13 @@ func TestDHT_SearchValue_returns_best_values(t *testing.T) { valChan, err := d1.SearchValue(ctx, key) require.NoError(t, err) - val := readItem(t, ctx, valChan) + val := kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, validValue, val) - val = readItem(t, ctx, valChan) + val = kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, betterValue, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } // In order for 'go test' to run this suite, we need to create @@ -756,10 +730,10 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumReachedPrematurely() { out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(3)) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func (suite *SearchValueQuorumTestSuite) TestQuorumReachedAfterDiscoveryOfBetter() { @@ -768,13 +742,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumReachedAfterDiscoveryOfBetter out, err := suite.d.SearchValue(ctx, suite.key, RoutingQuorum(5)) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - val = readItem(t, ctx, out) + val = kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.betterValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func (suite *SearchValueQuorumTestSuite) TestQuorumZero() { @@ -785,13 +759,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumZero() { out, err := suite.d.SearchValue(ctx, suite.key, 
RoutingQuorum(0)) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - val = readItem(t, ctx, out) + val = kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.betterValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() { @@ -802,13 +776,13 @@ func (suite *SearchValueQuorumTestSuite) TestQuorumUnspecified() { out, err := suite.d.SearchValue(ctx, suite.key) require.NoError(t, err) - val := readItem(t, ctx, out) + val := kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.validValue, val) - val = readItem(t, ctx, out) + val = kadtest.ReadItem(t, ctx, out) assert.Equal(t, suite.betterValue, val) - assertClosed(t, ctx, out) + kadtest.AssertClosed(t, ctx, out) } func TestDHT_SearchValue_routing_option_returns_error(t *testing.T) { @@ -864,7 +838,7 @@ func TestDHT_SearchValue_stops_with_cancelled_context(t *testing.T) { valueChan, err := d1.SearchValue(cancelledCtx, "/"+namespaceIPNS+"/some-key") assert.NoError(t, err) - assertClosed(t, ctx, valueChan) + kadtest.AssertClosed(t, ctx, valueChan) } func TestDHT_SearchValue_has_record_locally(t *testing.T) { @@ -892,13 +866,13 @@ func TestDHT_SearchValue_has_record_locally(t *testing.T) { valChan, err := d1.SearchValue(ctx, key) require.NoError(t, err) - val := readItem(t, ctx, valChan) // from local store + val := kadtest.ReadItem(t, ctx, valChan) // from local store assert.Equal(t, validValue, val) - val = readItem(t, ctx, valChan) + val = kadtest.ReadItem(t, ctx, valChan) assert.Equal(t, betterValue, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } func TestDHT_SearchValue_offline(t *testing.T) { @@ -914,10 +888,10 @@ func TestDHT_SearchValue_offline(t *testing.T) { valChan, err := d.SearchValue(ctx, key, routing.Offline) require.NoError(t, err) - val := readItem(t, ctx, valChan) + val := kadtest.ReadItem(t, ctx, valChan) 
assert.Equal(t, v, val) - assertClosed(t, ctx, valChan) + kadtest.AssertClosed(t, ctx, valChan) } func TestDHT_SearchValue_offline_not_found_locally(t *testing.T) { From bf35bb2a28c236c9fd792833a793edfa6537ee5a Mon Sep 17 00:00:00 2001 From: Ian Davis <18375+iand@users.noreply.github.com> Date: Thu, 12 Oct 2023 12:41:47 +0100 Subject: [PATCH 2/3] Add inbound event queue to query behaviour --- internal/coord/query.go | 200 ++++++++++++++++++++++------------- internal/coord/query_test.go | 21 +++- 2 files changed, 145 insertions(+), 76 deletions(-) diff --git a/internal/coord/query.go b/internal/coord/query.go index 32269a9..1185079 100644 --- a/internal/coord/query.go +++ b/internal/coord/query.go @@ -112,19 +112,28 @@ type PooledQueryBehaviour struct { // cfg is a copy of the optional configuration supplied to the behaviour. cfg PooledQueryConfig + // performMu is held while Perform is executing to ensure sequential execution of work. + performMu sync.Mutex + // pool is the query pool state machine used for managing individual queries. + // it must only be accessed while performMu is held pool *query.Pool[kadt.Key, kadt.PeerID, *pb.Message] // waiters is a map that keeps track of event notifications for each running query. + // it must only be accessed while performMu is held waiters map[coordt.QueryID]NotifyCloser[BehaviourEvent] - // pendingMu guards access to pending - pendingMu sync.Mutex + // pendingOutbound is a queue of outbound events. + // it must only be accessed while performMu is held + pendingOutbound []BehaviourEvent + + // pendingInboundMu guards access to pendingInbound + pendingInboundMu sync.Mutex - // pending is a queue of pending events that need to be processed. - pending []BehaviourEvent + // pendingInbound is a queue of inbound events that are awaiting processing + pendingInbound []pendingEvent[BehaviourEvent] - // ready is a channel signaling that events are ready to be processed. 
+ // ready is a channel signaling that the behaviour has work to perform. ready chan struct{} } @@ -165,11 +174,88 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Notify") defer span.End() - p.pendingMu.Lock() - defer p.pendingMu.Unlock() + p.pendingInboundMu.Lock() + defer p.pendingInboundMu.Unlock() + + p.pendingInbound = append(p.pendingInbound, pendingEvent[BehaviourEvent]{Ctx: ctx, Event: ev}) + + select { + case p.ready <- struct{}{}: + default: + } +} + +// Ready returns a channel that signals when the pooled query behaviour is ready to +// perform work. +func (p *PooledQueryBehaviour) Ready() <-chan struct{} { + return p.ready +} + +// Perform executes the next available task from the queue of pending events or advances +// the query pool. Returns an event containing the result of the work performed and a +// true value, or nil and a false value if no event was generated. +func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { + p.performMu.Lock() + defer p.performMu.Unlock() + + ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Perform") + defer span.End() + + defer p.updateReadyStatus() + + // drain queued outbound events first. + ev, ok := p.nextPendingOutbound() + if ok { + return ev, true + } + + // perform pending inbound work. 
+ ev, ok = p.perfomNextInbound(ctx) + if ok { + return ev, true + } + + // poll the query pool to trigger any timeouts and other scheduled work + ev, ok = p.advancePool(ctx, &query.EventPoolPoll{}) + if ok { + return ev, true + } + + // return any queued outbound work that may have been generated + return p.nextPendingOutbound() +} + +func (p *PooledQueryBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { + if len(p.pendingOutbound) == 0 { + return nil, false + } + var ev BehaviourEvent + ev, p.pendingOutbound = p.pendingOutbound[0], p.pendingOutbound[1:] + return ev, true +} + +func (p *PooledQueryBehaviour) nextPendingInbound() (pendingEvent[BehaviourEvent], bool) { + p.pendingInboundMu.Lock() + defer p.pendingInboundMu.Unlock() + if len(p.pendingInbound) == 0 { + return pendingEvent[BehaviourEvent]{}, false + } + var pev pendingEvent[BehaviourEvent] + pev, p.pendingInbound = p.pendingInbound[0], p.pendingInbound[1:] + return pev, true +} + +func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (BehaviourEvent, bool) { + ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.perfomNextInbound") + defer span.End() + pev, ok := p.nextPendingInbound() + if !ok { + return nil, false + } + + var cmd query.PoolEvent = &query.EventPoolPoll{} - var cmd query.PoolEvent - switch ev := ev.(type) { + switch ev := pev.Event.(type) { case *EventStartFindCloserQuery: cmd = &query.EventPoolAddFindCloserQuery[kadt.Key, kadt.PeerID]{ QueryID: ev.QueryID, @@ -194,12 +280,7 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { QueryID: ev.QueryID, } case *EventGetCloserNodesSuccess: - for _, info := range ev.CloserNodes { - // TODO: do this after advancing pool - p.pending = append(p.pending, &EventAddNode{ - NodeID: info, - }) - } + p.queueAddNodeEvents(ev.CloserNodes) waiter, ok := p.waiters[ev.QueryID] if ok { waiter.Notify(ctx, &EventQueryProgressed{ @@ -217,9 +298,7 @@ func (p *PooledQueryBehaviour) Notify(ctx 
context.Context, ev BehaviourEvent) { case *EventGetCloserNodesFailure: // queue an event that will notify the routing behaviour of a failed node p.cfg.Logger.Debug("peer has no connectivity", tele.LogAttrPeerID(ev.To), "source", "query") - p.pending = append(p.pending, &EventNotifyNonConnectivity{ - ev.To, - }) + p.queueNonConnectivityEvent(ev.To) cmd = &query.EventPoolNodeFailure[kadt.Key, kadt.PeerID]{ NodeID: ev.To, @@ -227,12 +306,7 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { Error: ev.Err, } case *EventSendMessageSuccess: - for _, info := range ev.CloserNodes { - // TODO: do this after advancing pool - p.pending = append(p.pending, &EventAddNode{ - NodeID: info, - }) - } + p.queueAddNodeEvents(ev.CloserNodes) waiter, ok := p.waiters[ev.QueryID] if ok { waiter.Notify(ctx, &EventQueryProgressed{ @@ -249,9 +323,7 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { case *EventSendMessageFailure: // queue an event that will notify the routing behaviour of a failed node p.cfg.Logger.Debug("peer has no connectivity", tele.LogAttrPeerID(ev.To), "source", "query") - p.pending = append(p.pending, &EventNotifyNonConnectivity{ - ev.To, - }) + p.queueNonConnectivityEvent(ev.To) cmd = &query.EventPoolNodeFailure[kadt.Key, kadt.PeerID]{ NodeID: ev.To, @@ -263,59 +335,28 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { } // attempt to advance the query pool - ev, ok := p.advancePool(ctx, cmd) - if ok { - p.pending = append(p.pending, ev) - } - if len(p.pending) > 0 { + return p.advancePool(pev.Ctx, cmd) +} + +func (p *PooledQueryBehaviour) updateReadyStatus() { + if len(p.pendingOutbound) != 0 { select { case p.ready <- struct{}{}: default: } + return } -} - -// Ready returns a channel that signals when the pooled query behaviour is ready to -// perform work. 
-func (p *PooledQueryBehaviour) Ready() <-chan struct{} { - return p.ready -} - -// Perform executes the next available task from the queue of pending events or advances -// the query pool. Returns the executed event and a boolean indicating whether work was -// performed. -func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { - ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Perform") - defer span.End() - - // No inbound work can be done until Perform is complete - p.pendingMu.Lock() - defer p.pendingMu.Unlock() - - for { - // drain queued events first. - if len(p.pending) > 0 { - var ev BehaviourEvent - ev, p.pending = p.pending[0], p.pending[1:] - - if len(p.pending) > 0 { - select { - case p.ready <- struct{}{}: - default: - } - } - return ev, true - } - // attempt to advance the query pool - ev, ok := p.advancePool(ctx, &query.EventPoolPoll{}) - if ok { - return ev, true - } + p.pendingInboundMu.Lock() + hasPendingInbound := len(p.pendingInbound) != 0 + p.pendingInboundMu.Unlock() - if len(p.pending) == 0 { - return nil, false + if hasPendingInbound { + select { + case p.ready <- struct{}{}: + default: } + return } } @@ -369,3 +410,18 @@ func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEve return nil, false } + +func (p *PooledQueryBehaviour) queueAddNodeEvents(nodes []kadt.PeerID) { + for _, info := range nodes { + // TODO: do this after advancing pool + p.pendingOutbound = append(p.pendingOutbound, &EventAddNode{ + NodeID: info, + }) + } +} + +func (p *PooledQueryBehaviour) queueNonConnectivityEvent(nid kadt.PeerID) { + p.pendingOutbound = append(p.pendingOutbound, &EventNotifyNonConnectivity{ + NodeID: nid, + }) +} diff --git a/internal/coord/query_test.go b/internal/coord/query_test.go index 354f40c..ef4a5f4 100644 --- a/internal/coord/query_test.go +++ b/internal/coord/query_test.go @@ -141,6 +141,11 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesNoProgress() { Target: target, }) 
+	// query will process the response and notify that node 1 is non-connective
+	bev, ok = b.Perform(ctx)
+	ts.Require().True(ok)
+	ts.Require().IsType(&EventNotifyNonConnectivity{}, bev)
+
 	// ensure that the waiter received query finished event
 	wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan())
 	ts.Require().IsType(&EventQueryFinished{}, wev.Event)
@@ -185,6 +190,11 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryProgressed() {
 		CloserNodes: ts.nodes[1].RoutingTable.NearestNodes(target, 5),
 	})
 
+	// query will process the response and ask node 1 for closer nodes
+	bev, ok = b.Perform(ctx)
+	ts.Require().True(ok)
+	ts.Require().IsType(&EventOutboundGetCloserNodes{}, bev)
+
 	// ensure that the waiter received query progressed event
 	wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan())
 	ts.Require().IsType(&EventQueryProgressed{}, wev.Event)
@@ -229,10 +239,6 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() {
 		CloserNodes: ts.nodes[1].RoutingTable.NearestNodes(target, 5),
 	})
 
-	// ensure that the waiter received query progressed event
-	wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan())
-	ts.Require().IsType(&EventQueryProgressed{}, wev.Event)
-
 	// skip events until next EventOutboundGetCloserNodes is reached
 	for {
 		bev, ok = b.Perform(ctx)
@@ -244,6 +250,10 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() {
 		}
 	}
 
+	// ensure that the waiter received query progressed event
+	wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan())
+	ts.Require().IsType(&EventQueryProgressed{}, wev.Event)
+
 	ts.Require().True(egc.To.Equal(ts.nodes[2].NodeID))
 	// notify success but no further nodes
 	b.Notify(ctx, &EventGetCloserNodesSuccess{
@@ -252,6 +262,9 @@
 		Target:      target,
 	})
 
+	bev, ok = b.Perform(ctx)
+	ts.Require().True(ok)
+
 	// ensure that the waiter received query progressed event
wev = kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) ts.Require().IsType(&EventQueryProgressed{}, wev.Event) From bdbee9c88ca6a417b5af1e2f5282221da3579c33 Mon Sep 17 00:00:00 2001 From: Ian Davis <18375+iand@users.noreply.github.com> Date: Thu, 12 Oct 2023 15:25:20 +0100 Subject: [PATCH 3/3] Refactor perform logic in query and broadcast behaviours --- internal/coord/behaviour.go | 70 +++++---- internal/coord/brdcst.go | 242 ++++++++++++++++++++++---------- internal/coord/brdcst_events.go | 5 +- internal/coord/coordinator.go | 116 +++++++++------ internal/coord/event.go | 13 +- internal/coord/network.go | 11 -- internal/coord/query.go | 100 +++++++++---- internal/coord/query_test.go | 41 +++--- 8 files changed, 393 insertions(+), 205 deletions(-) diff --git a/internal/coord/behaviour.go b/internal/coord/behaviour.go index 21de6b0..28819aa 100644 --- a/internal/coord/behaviour.go +++ b/internal/coord/behaviour.go @@ -44,7 +44,7 @@ type WorkQueueFunc[E BehaviourEvent] func(context.Context, E) bool // WorkQueueFunc for each work item, passing the original context // and event. type WorkQueue[E BehaviourEvent] struct { - pending chan pendingEvent[E] + pending chan CtxEvent[E] fn WorkQueueFunc[E] done atomic.Bool once sync.Once @@ -52,13 +52,15 @@ type WorkQueue[E BehaviourEvent] struct { func NewWorkQueue[E BehaviourEvent](fn WorkQueueFunc[E]) *WorkQueue[E] { w := &WorkQueue[E]{ - pending: make(chan pendingEvent[E], 1), + pending: make(chan CtxEvent[E], 1), fn: fn, } return w } -type pendingEvent[E any] struct { +// CtxEvent holds and event with an associated context which may carry deadlines or +// tracing information pertinent to the event. 
+type CtxEvent[E any] struct {
 	Ctx   context.Context
 	Event E
 }
@@ -89,7 +91,7 @@ func (w *WorkQueue[E]) Enqueue(ctx context.Context, cmd E) error {
 	select {
 	case <-ctx.Done(): // this is the context for the work item
 		return ctx.Err()
-	case w.pending <- pendingEvent[E]{
+	case w.pending <- CtxEvent[E]{
 		Ctx:   ctx,
 		Event: cmd,
 	}:
@@ -147,37 +149,47 @@ func (w *Waiter[E]) Chan() <-chan WaiterEvent[E] {
 	return w.pending
 }
 
-// NotifyCloserHook implements the [NotifyCloser] interface and provides hooks
-// into the Notify and Close calls by wrapping another [NotifyCloser]. This is
-// intended to be used in testing.
-type NotifyCloserHook[E BehaviourEvent] struct {
-	nc           NotifyCloser[E]
-	BeforeNotify func(context.Context, E)
-	AfterNotify  func(context.Context, E)
-	BeforeClose  func()
-	AfterClose   func()
+// A QueryMonitor receives event notifications on the progress of a query
+type QueryMonitor[E TerminalQueryEvent] interface {
+	// NotifyProgressed returns a channel that can be used to send notification that a
+	// query has made progress. If the notification cannot be sent then it will be
+	// queued and retried at a later time. If the query completes before the progress
+	// notification can be sent the notification will be discarded.
+	NotifyProgressed() chan<- CtxEvent[*EventQueryProgressed]
+
+	// NotifyFinished returns a channel that can be used to send the notification that a
+	// query has completed. It is up to the implementation to ensure that the channel has enough
+	// capacity to receive the single notification.
+	// The sender must close all other QueryNotifier channels before sending on the NotifyFinished channel.
+	// The sender may attempt to drain any pending notifications before closing the other channels.
+	// The NotifyFinished channel will be closed once the sender has attempted to send the Finished notification.
+ NotifyFinished() chan<- CtxEvent[E] +} + +// QueryMonitorHook wraps a [QueryMonitor] interface and provides hooks +// that are invoked before calls to the QueryMonitor methods are forwarded. +type QueryMonitorHook[E TerminalQueryEvent] struct { + qm QueryMonitor[E] + BeforeProgressed func() + BeforeFinished func() } -var _ NotifyCloser[BehaviourEvent] = (*NotifyCloserHook[BehaviourEvent])(nil) +var _ QueryMonitor[*EventQueryFinished] = (*QueryMonitorHook[*EventQueryFinished])(nil) -func NewNotifyCloserHook[E BehaviourEvent](nc NotifyCloser[E]) *NotifyCloserHook[E] { - return &NotifyCloserHook[E]{ - nc: nc, - BeforeNotify: func(ctx context.Context, e E) {}, - AfterNotify: func(ctx context.Context, e E) {}, - BeforeClose: func() {}, - AfterClose: func() {}, +func NewQueryMonitorHook[E TerminalQueryEvent](qm QueryMonitor[E]) *QueryMonitorHook[E] { + return &QueryMonitorHook[E]{ + qm: qm, + BeforeProgressed: func() {}, + BeforeFinished: func() {}, } } -func (n *NotifyCloserHook[E]) Notify(ctx context.Context, ev E) { - n.BeforeNotify(ctx, ev) - n.nc.Notify(ctx, ev) - n.AfterNotify(ctx, ev) +func (n *QueryMonitorHook[E]) NotifyProgressed() chan<- CtxEvent[*EventQueryProgressed] { + n.BeforeProgressed() + return n.qm.NotifyProgressed() } -func (n *NotifyCloserHook[E]) Close() { - n.BeforeClose() - n.nc.Close() - n.AfterClose() +func (n *QueryMonitorHook[E]) NotifyFinished() chan<- CtxEvent[E] { + n.BeforeFinished() + return n.qm.NotifyFinished() } diff --git a/internal/coord/brdcst.go b/internal/coord/brdcst.go index 4251bf4..331ea41 100644 --- a/internal/coord/brdcst.go +++ b/internal/coord/brdcst.go @@ -15,26 +15,42 @@ import ( ) type PooledBroadcastBehaviour struct { - pool coordt.StateMachine[brdcst.PoolEvent, brdcst.PoolState] - waiters map[coordt.QueryID]NotifyCloser[BehaviourEvent] - - pendingMu sync.Mutex - pending []BehaviourEvent - ready chan struct{} - logger *slog.Logger tracer trace.Tracer + + // performMu is held while Perform is executing to ensure 
sequential execution of work. + performMu sync.Mutex + + // pool is the broadcast pool state machine used for managing individual broadcasts. + // it must only be accessed while performMu is held + pool coordt.StateMachine[brdcst.PoolEvent, brdcst.PoolState] + + // pendingOutbound is a queue of outbound events. + // it must only be accessed while performMu is held + pendingOutbound []BehaviourEvent + + // notifiers is a map that keeps track of event notifications for each running broadcast. + // it must only be accessed while performMu is held + notifiers map[coordt.QueryID]*queryNotifier[*EventBroadcastFinished] + + // pendingInboundMu guards access to pendingInbound + pendingInboundMu sync.Mutex + + // pendingInbound is a queue of inbound events that are awaiting processing + pendingInbound []CtxEvent[BehaviourEvent] + + ready chan struct{} } var _ Behaviour[BehaviourEvent, BehaviourEvent] = (*PooledBroadcastBehaviour)(nil) func NewPooledBroadcastBehaviour(brdcstPool *brdcst.Pool[kadt.Key, kadt.PeerID, *pb.Message], logger *slog.Logger, tracer trace.Tracer) *PooledBroadcastBehaviour { b := &PooledBroadcastBehaviour{ - pool: brdcstPool, - waiters: make(map[coordt.QueryID]NotifyCloser[BehaviourEvent]), - ready: make(chan struct{}, 1), - logger: logger.With("behaviour", "pooledBroadcast"), - tracer: tracer, + pool: brdcstPool, + notifiers: make(map[coordt.QueryID]*queryNotifier[*EventBroadcastFinished]), + ready: make(chan struct{}, 1), + logger: logger.With("behaviour", "pooledBroadcast"), + tracer: tracer, } return b } @@ -44,14 +60,108 @@ func (b *PooledBroadcastBehaviour) Ready() <-chan struct{} { } func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { + b.pendingInboundMu.Lock() + defer b.pendingInboundMu.Unlock() + ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.Notify") defer span.End() - b.pendingMu.Lock() - defer b.pendingMu.Unlock() + b.pendingInbound = append(b.pendingInbound, CtxEvent[BehaviourEvent]{Ctx: ctx, 
Event: ev}) + + select { + case b.ready <- struct{}{}: + default: + } +} + +func (b *PooledBroadcastBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { + b.performMu.Lock() + defer b.performMu.Unlock() + + ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.Perform") + defer span.End() + + defer b.updateReadyStatus() + + // first send any pending query notifications + for _, w := range b.notifiers { + w.DrainPending() + } + + // drain queued outbound events before starting new work. + ev, ok := b.nextPendingOutbound() + if ok { + return ev, true + } + + // perform one piece of pending inbound work. + ev, ok = b.perfomNextInbound(ctx) + if ok { + return ev, true + } + + // poll the broadcast pool to trigger any timeouts and other scheduled work + ev, ok = b.advancePool(ctx, &brdcst.EventPoolPoll{}) + if ok { + return ev, true + } + + // return any queued outbound work that may have been generated + return b.nextPendingOutbound() +} + +func (b *PooledBroadcastBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { + if len(b.pendingOutbound) == 0 { + return nil, false + } + var ev BehaviourEvent + ev, b.pendingOutbound = b.pendingOutbound[0], b.pendingOutbound[1:] + return ev, true +} + +func (b *PooledBroadcastBehaviour) nextPendingInbound() (CtxEvent[BehaviourEvent], bool) { + b.pendingInboundMu.Lock() + defer b.pendingInboundMu.Unlock() + if len(b.pendingInbound) == 0 { + return CtxEvent[BehaviourEvent]{}, false + } + var pev CtxEvent[BehaviourEvent] + pev, b.pendingInbound = b.pendingInbound[0], b.pendingInbound[1:] + return pev, true +} + +func (b *PooledBroadcastBehaviour) updateReadyStatus() { + if len(b.pendingOutbound) != 0 { + select { + case b.ready <- struct{}{}: + default: + } + return + } + + b.pendingInboundMu.Lock() + hasPendingInbound := len(b.pendingInbound) != 0 + b.pendingInboundMu.Unlock() + + if hasPendingInbound { + select { + case b.ready <- struct{}{}: + default: + } + return + } +} + +func (b *PooledBroadcastBehaviour) 
perfomNextInbound(ctx context.Context) (BehaviourEvent, bool) { + ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.perfomNextInbound") + defer span.End() + pev, ok := b.nextPendingInbound() + if !ok { + return nil, false + } var cmd brdcst.PoolEvent - switch ev := ev.(type) { + switch ev := pev.Event.(type) { case *EventStartBroadcast: cmd = &brdcst.EventPoolStartBroadcast[kadt.Key, kadt.PeerID, *pb.Message]{ QueryID: ev.QueryID, @@ -61,19 +171,19 @@ func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent Config: ev.Config, } if ev.Notify != nil { - b.waiters[ev.QueryID] = ev.Notify + b.notifiers[ev.QueryID] = &queryNotifier[*EventBroadcastFinished]{monitor: ev.Notify} } case *EventGetCloserNodesSuccess: for _, info := range ev.CloserNodes { - b.pending = append(b.pending, &EventAddNode{ + b.pendingOutbound = append(b.pendingOutbound, &EventAddNode{ NodeID: info, }) } - waiter, ok := b.waiters[ev.QueryID] + waiter, ok := b.notifiers[ev.QueryID] if ok { - waiter.Notify(ctx, &EventQueryProgressed{ + waiter.TryNotifyProgressed(ctx, &EventQueryProgressed{ NodeID: ev.To, QueryID: ev.QueryID, }) @@ -88,7 +198,7 @@ func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent case *EventGetCloserNodesFailure: // queue an event that will notify the routing behaviour of a failed node - b.pending = append(b.pending, &EventNotifyNonConnectivity{ + b.pendingOutbound = append(b.pendingOutbound, &EventNotifyNonConnectivity{ ev.To, }) @@ -101,13 +211,13 @@ func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent case *EventSendMessageSuccess: for _, info := range ev.CloserNodes { - b.pending = append(b.pending, &EventAddNode{ + b.pendingOutbound = append(b.pendingOutbound, &EventAddNode{ NodeID: info, }) } - waiter, ok := b.waiters[ev.QueryID] + waiter, ok := b.notifiers[ev.QueryID] if ok { - waiter.Notify(ctx, &EventQueryProgressed{ + waiter.TryNotifyProgressed(ctx, &EventQueryProgressed{ NodeID: 
ev.To, QueryID: ev.QueryID, Response: ev.Response, @@ -123,7 +233,7 @@ func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent case *EventSendMessageFailure: // queue an event that will notify the routing behaviour of a failed node - b.pending = append(b.pending, &EventNotifyNonConnectivity{ + b.pendingOutbound = append(b.pendingOutbound, &EventNotifyNonConnectivity{ ev.To, }) @@ -142,51 +252,7 @@ func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent } // attempt to advance the broadcast pool - ev, ok := b.advancePool(ctx, cmd) - if ok { - b.pending = append(b.pending, ev) - } - if len(b.pending) > 0 { - select { - case b.ready <- struct{}{}: - default: - } - } -} - -func (b *PooledBroadcastBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { - ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.Perform") - defer span.End() - - // No inbound work can be done until Perform is complete - b.pendingMu.Lock() - defer b.pendingMu.Unlock() - - for { - // drain queued events first. 
- if len(b.pending) > 0 { - var ev BehaviourEvent - ev, b.pending = b.pending[0], b.pending[1:] - - if len(b.pending) > 0 { - select { - case b.ready <- struct{}{}: - default: - } - } - return ev, true - } - - ev, ok := b.advancePool(ctx, &brdcst.EventPoolPoll{}) - if ok { - return ev, true - } - - // finally check if any pending events were accumulated in the meantime - if len(b.pending) == 0 { - return nil, false - } - } + return b.advancePool(ctx, cmd) } func (b *PooledBroadcastBehaviour) advancePool(ctx context.Context, ev brdcst.PoolEvent) (out BehaviourEvent, term bool) { @@ -215,16 +281,48 @@ func (b *PooledBroadcastBehaviour) advancePool(ctx context.Context, ev brdcst.Po Notify: b, }, true case *brdcst.StatePoolBroadcastFinished[kadt.Key, kadt.PeerID]: - waiter, ok := b.waiters[st.QueryID] + waiter, ok := b.notifiers[st.QueryID] if ok { - waiter.Notify(ctx, &EventBroadcastFinished{ + waiter.NotifyFinished(ctx, &EventBroadcastFinished{ QueryID: st.QueryID, Contacted: st.Contacted, Errors: st.Errors, }) - waiter.Close() + delete(b.notifiers, st.QueryID) } } return nil, false } + +// A BroadcastWaiter implements [QueryMonitor] for broadcasts +type BroadcastWaiter struct { + progressed chan CtxEvent[*EventQueryProgressed] + finished chan CtxEvent[*EventBroadcastFinished] +} + +var _ QueryMonitor[*EventBroadcastFinished] = (*BroadcastWaiter)(nil) + +func NewBroadcastWaiter(n int) *BroadcastWaiter { + w := &BroadcastWaiter{ + progressed: make(chan CtxEvent[*EventQueryProgressed], n), + finished: make(chan CtxEvent[*EventBroadcastFinished], 1), + } + return w +} + +func (w *BroadcastWaiter) Progressed() <-chan CtxEvent[*EventQueryProgressed] { + return w.progressed +} + +func (w *BroadcastWaiter) Finished() <-chan CtxEvent[*EventBroadcastFinished] { + return w.finished +} + +func (w *BroadcastWaiter) NotifyProgressed() chan<- CtxEvent[*EventQueryProgressed] { + return w.progressed +} + +func (w *BroadcastWaiter) NotifyFinished() chan<- 
CtxEvent[*EventBroadcastFinished] { + return w.finished +} diff --git a/internal/coord/brdcst_events.go b/internal/coord/brdcst_events.go index 9158939..60b44f8 100644 --- a/internal/coord/brdcst_events.go +++ b/internal/coord/brdcst_events.go @@ -14,7 +14,7 @@ type EventStartBroadcast struct { Message *pb.Message Seed []kadt.PeerID Config brdcst.Config - Notify NotifyCloser[BehaviourEvent] + Notify QueryMonitor[*EventBroadcastFinished] } func (*EventStartBroadcast) behaviourEvent() {} @@ -31,4 +31,5 @@ type EventBroadcastFinished struct { } } -func (*EventBroadcastFinished) behaviourEvent() {} +func (*EventBroadcastFinished) behaviourEvent() {} +func (*EventBroadcastFinished) terminalQueryEvent() {} diff --git a/internal/coord/coordinator.go b/internal/coord/coordinator.go index 3becf56..83335fc 100644 --- a/internal/coord/coordinator.go +++ b/internal/coord/coordinator.go @@ -314,7 +314,7 @@ func (c *Coordinator) QueryClosest(ctx context.Context, target kadt.Key, fn coor return nil, coordt.QueryStats{}, err } - waiter := NewWaiter[BehaviourEvent]() + waiter := NewQueryWaiter(numResults) queryID := c.newOperationID() cmd := &EventStartFindCloserQuery{ @@ -362,7 +362,7 @@ func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn coor return nil, coordt.QueryStats{}, err } - waiter := NewWaiter[BehaviourEvent]() + waiter := NewQueryWaiter(numResults) queryID := c.newOperationID() cmd := &EventStartMessageQuery{ @@ -412,7 +412,7 @@ func (c *Coordinator) broadcast(ctx context.Context, msg *pb.Message, seeds []ka ctx, cancel := context.WithCancel(ctx) defer cancel() - waiter := NewWaiter[BehaviourEvent]() + waiter := NewBroadcastWaiter(0) // zero capacity since waitForBroadcast ignores progress events queryID := c.newOperationID() cmd := &EventStartBroadcast{ @@ -441,19 +441,41 @@ func (c *Coordinator) broadcast(ctx context.Context, msg *pb.Message, seeds []ka return nil } -func (c *Coordinator) waitForQuery(ctx context.Context, queryID 
coordt.QueryID, waiter *Waiter[BehaviourEvent], fn coordt.QueryFunc) ([]kadt.PeerID, coordt.QueryStats, error) { +func (c *Coordinator) waitForQuery(ctx context.Context, queryID coordt.QueryID, waiter *QueryWaiter, fn coordt.QueryFunc) ([]kadt.PeerID, coordt.QueryStats, error) { var lastStats coordt.QueryStats for { select { case <-ctx.Done(): return nil, lastStats, ctx.Err() - case wev, more := <-waiter.Chan(): + + case wev, more := <-waiter.Progressed(): if !more { return nil, lastStats, ctx.Err() } ctx, ev := wev.Ctx, wev.Event - switch ev := ev.(type) { - case *EventQueryProgressed: + c.cfg.Logger.Debug("query made progress", "query_id", queryID, tele.LogAttrPeerID(ev.NodeID), slog.Duration("elapsed", c.cfg.Clock.Since(ev.Stats.Start)), slog.Int("requests", ev.Stats.Requests), slog.Int("failures", ev.Stats.Failure)) + lastStats = coordt.QueryStats{ + Start: ev.Stats.Start, + Requests: ev.Stats.Requests, + Success: ev.Stats.Success, + Failure: ev.Stats.Failure, + } + err := fn(ctx, ev.NodeID, ev.Response, lastStats) + if errors.Is(err, coordt.ErrSkipRemaining) { + // done + c.cfg.Logger.Debug("query done", "query_id", queryID) + c.queryBehaviour.Notify(ctx, &EventStopQuery{QueryID: queryID}) + return nil, lastStats, nil + } + if err != nil { + // user defined error that terminates the query + c.queryBehaviour.Notify(ctx, &EventStopQuery{QueryID: queryID}) + return nil, lastStats, err + } + case wev, more := <-waiter.Finished(): + // drain the progress notification channel + for pev := range waiter.Progressed() { + ctx, ev := pev.Ctx, pev.Event c.cfg.Logger.Debug("query made progress", "query_id", queryID, tele.LogAttrPeerID(ev.NodeID), slog.Duration("elapsed", c.cfg.Clock.Since(ev.Stats.Start)), slog.Int("requests", ev.Stats.Requests), slog.Int("failures", ev.Stats.Failure)) lastStats = coordt.QueryStats{ Start: ev.Stats.Start, @@ -461,40 +483,24 @@ func (c *Coordinator) waitForQuery(ctx context.Context, queryID coordt.QueryID, Success: ev.Stats.Success, 
Failure: ev.Stats.Failure, } - nh, err := c.networkBehaviour.getNodeHandler(ctx, ev.NodeID) - if err != nil { - // ignore unknown node - c.cfg.Logger.Debug("node handler not found", "query_id", queryID, tele.LogAttrError, err) - break - } - - err = fn(ctx, nh.ID(), ev.Response, lastStats) - if errors.Is(err, coordt.ErrSkipRemaining) { - // done - c.cfg.Logger.Debug("query done", "query_id", queryID) - c.queryBehaviour.Notify(ctx, &EventStopQuery{QueryID: queryID}) - return nil, lastStats, nil - } - if err != nil { - // user defined error that terminates the query - c.queryBehaviour.Notify(ctx, &EventStopQuery{QueryID: queryID}) + if err := fn(ctx, ev.NodeID, ev.Response, lastStats); err != nil { return nil, lastStats, err } + } + if !more { + return nil, lastStats, ctx.Err() + } - case *EventQueryFinished: - // query is done - lastStats.Exhausted = true - c.cfg.Logger.Debug("query ran to exhaustion", "query_id", queryID, slog.Duration("elapsed", ev.Stats.End.Sub(ev.Stats.Start)), slog.Int("requests", ev.Stats.Requests), slog.Int("failures", ev.Stats.Failure)) - return ev.ClosestNodes, lastStats, nil + // query is done + lastStats.Exhausted = true + c.cfg.Logger.Debug("query ran to exhaustion", "query_id", queryID, slog.Duration("elapsed", wev.Event.Stats.End.Sub(wev.Event.Stats.Start)), slog.Int("requests", wev.Event.Stats.Requests), slog.Int("failures", wev.Event.Stats.Failure)) + return wev.Event.ClosestNodes, lastStats, nil - default: - panic(fmt.Sprintf("unexpected event: %T", ev)) - } } } } -func (c *Coordinator) waitForBroadcast(ctx context.Context, waiter *Waiter[BehaviourEvent]) ([]kadt.PeerID, map[string]struct { +func (c *Coordinator) waitForBroadcast(ctx context.Context, waiter *BroadcastWaiter) ([]kadt.PeerID, map[string]struct { Node kadt.PeerID Err error }, error, @@ -503,19 +509,11 @@ func (c *Coordinator) waitForBroadcast(ctx context.Context, waiter *Waiter[Behav select { case <-ctx.Done(): return nil, nil, ctx.Err() - case wev, more := 
<-waiter.Chan(): + case wev, more := <-waiter.Finished(): if !more { return nil, nil, ctx.Err() } - - switch ev := wev.Event.(type) { - case *EventQueryProgressed: - case *EventBroadcastFinished: - return ev.Contacted, ev.Errors, nil - - default: - panic(fmt.Sprintf("unexpected event: %T", ev)) - } + return wev.Event.Contacted, wev.Event.Errors, nil } } } @@ -684,3 +682,35 @@ func (w *BufferedRoutingNotifier) ExpectRoutingRemoved(ctx context.Context, id k type nullRoutingNotifier struct{} func (nullRoutingNotifier) Notify(context.Context, RoutingNotification) {} + +// A QueryWaiter implements [QueryMonitor] for general queries +type QueryWaiter struct { + progressed chan CtxEvent[*EventQueryProgressed] + finished chan CtxEvent[*EventQueryFinished] +} + +var _ QueryMonitor[*EventQueryFinished] = (*QueryWaiter)(nil) + +func NewQueryWaiter(n int) *QueryWaiter { + w := &QueryWaiter{ + progressed: make(chan CtxEvent[*EventQueryProgressed], n), + finished: make(chan CtxEvent[*EventQueryFinished], 1), + } + return w +} + +func (w *QueryWaiter) Progressed() <-chan CtxEvent[*EventQueryProgressed] { + return w.progressed +} + +func (w *QueryWaiter) Finished() <-chan CtxEvent[*EventQueryFinished] { + return w.finished +} + +func (w *QueryWaiter) NotifyProgressed() chan<- CtxEvent[*EventQueryProgressed] { + return w.progressed +} + +func (w *QueryWaiter) NotifyFinished() chan<- CtxEvent[*EventQueryFinished] { + return w.finished +} diff --git a/internal/coord/event.go b/internal/coord/event.go index 04019c2..6766ea7 100644 --- a/internal/coord/event.go +++ b/internal/coord/event.go @@ -50,6 +50,12 @@ type RoutingNotification interface { routingNotification() } +// TerminalQueryEvent is a type of [BehaviourEvent] that indicates a query has completed. 
+type TerminalQueryEvent interface { + BehaviourEvent + terminalQueryEvent() +} + type EventStartBootstrap struct { SeedNodes []kadt.PeerID } @@ -84,7 +90,7 @@ type EventStartMessageQuery struct { Target kadt.Key Message *pb.Message KnownClosestNodes []kadt.PeerID - Notify NotifyCloser[BehaviourEvent] + Notify QueryMonitor[*EventQueryFinished] NumResults int // the minimum number of nodes to successfully contact before considering iteration complete } @@ -95,7 +101,7 @@ type EventStartFindCloserQuery struct { QueryID coordt.QueryID Target kadt.Key KnownClosestNodes []kadt.PeerID - Notify NotifyCloser[BehaviourEvent] + Notify QueryMonitor[*EventQueryFinished] NumResults int // the minimum number of nodes to successfully contact before considering iteration complete } @@ -186,7 +192,8 @@ type EventQueryFinished struct { ClosestNodes []kadt.PeerID } -func (*EventQueryFinished) behaviourEvent() {} +func (*EventQueryFinished) behaviourEvent() {} +func (*EventQueryFinished) terminalQueryEvent() {} // EventRoutingUpdated is emitted by the coordinator when a new node has been verified and added to the routing table. 
type EventRoutingUpdated struct { diff --git a/internal/coord/network.go b/internal/coord/network.go index 4459d94..4dc81a7 100644 --- a/internal/coord/network.go +++ b/internal/coord/network.go @@ -106,17 +106,6 @@ func (b *NetworkBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { return nil, false } -func (b *NetworkBehaviour) getNodeHandler(ctx context.Context, id kadt.PeerID) (*NodeHandler, error) { - b.nodeHandlersMu.Lock() - nh, ok := b.nodeHandlers[id] - if !ok { - nh = NewNodeHandler(id, b.rtr, b.logger, b.tracer) - b.nodeHandlers[id] = nh - } - b.nodeHandlersMu.Unlock() - return nh, nil -} - type NodeHandler struct { self kadt.PeerID rtr coordt.Router[kadt.Key, kadt.PeerID, *pb.Message] diff --git a/internal/coord/query.go b/internal/coord/query.go index 1185079..05f23cb 100644 --- a/internal/coord/query.go +++ b/internal/coord/query.go @@ -119,9 +119,9 @@ type PooledQueryBehaviour struct { // it must only be accessed while performMu is held pool *query.Pool[kadt.Key, kadt.PeerID, *pb.Message] - // waiters is a map that keeps track of event notifications for each running query. + // notifiers is a map that keeps track of event notifications for each running query. // it must only be accessed while performMu is held - waiters map[coordt.QueryID]NotifyCloser[BehaviourEvent] + notifiers map[coordt.QueryID]*queryNotifier[*EventQueryFinished] // pendingOutbound is a queue of outbound events. // it must only be accessed while performMu is held @@ -131,7 +131,7 @@ type PooledQueryBehaviour struct { pendingInboundMu sync.Mutex // pendingInbound is a queue of inbound events that are awaiting processing - pendingInbound []pendingEvent[BehaviourEvent] + pendingInbound []CtxEvent[BehaviourEvent] // ready is a channel signaling that the behaviour has work to perform. 
ready chan struct{} @@ -159,10 +159,10 @@ func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQ } h := &PooledQueryBehaviour{ - cfg: *cfg, - pool: pool, - waiters: make(map[coordt.QueryID]NotifyCloser[BehaviourEvent]), - ready: make(chan struct{}, 1), + cfg: *cfg, + pool: pool, + notifiers: make(map[coordt.QueryID]*queryNotifier[*EventQueryFinished]), + ready: make(chan struct{}, 1), } return h, err } @@ -171,13 +171,13 @@ func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQ // stopping, or updating queries. It also queues events for later processing and // triggers the advancement of the query pool if applicable. func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { - ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Notify") - defer span.End() - p.pendingInboundMu.Lock() defer p.pendingInboundMu.Unlock() - p.pendingInbound = append(p.pendingInbound, pendingEvent[BehaviourEvent]{Ctx: ctx, Event: ev}) + ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.Notify") + defer span.End() + + p.pendingInbound = append(p.pendingInbound, CtxEvent[BehaviourEvent]{Ctx: ctx, Event: ev}) select { case p.ready <- struct{}{}: @@ -203,13 +203,18 @@ func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, boo defer p.updateReadyStatus() - // drain queued outbound events first. + // first send any pending query notifications + for _, w := range p.notifiers { + w.DrainPending() + } + + // drain queued outbound events before starting new work. ev, ok := p.nextPendingOutbound() if ok { return ev, true } - // perform pending inbound work. + // perform one piece of pending inbound work. 
ev, ok = p.perfomNextInbound(ctx) if ok { return ev, true @@ -234,13 +239,13 @@ func (p *PooledQueryBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { return ev, true } -func (p *PooledQueryBehaviour) nextPendingInbound() (pendingEvent[BehaviourEvent], bool) { +func (p *PooledQueryBehaviour) nextPendingInbound() (CtxEvent[BehaviourEvent], bool) { p.pendingInboundMu.Lock() defer p.pendingInboundMu.Unlock() if len(p.pendingInbound) == 0 { - return pendingEvent[BehaviourEvent]{}, false + return CtxEvent[BehaviourEvent]{}, false } - var pev pendingEvent[BehaviourEvent] + var pev CtxEvent[BehaviourEvent] pev, p.pendingInbound = p.pendingInbound[0], p.pendingInbound[1:] return pev, true } @@ -263,7 +268,7 @@ func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (Behaviour Seed: ev.KnownClosestNodes, } if ev.Notify != nil { - p.waiters[ev.QueryID] = ev.Notify + p.notifiers[ev.QueryID] = &queryNotifier[*EventQueryFinished]{monitor: ev.Notify} } case *EventStartMessageQuery: cmd = &query.EventPoolAddQuery[kadt.Key, kadt.PeerID, *pb.Message]{ @@ -273,7 +278,7 @@ func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (Behaviour Seed: ev.KnownClosestNodes, } if ev.Notify != nil { - p.waiters[ev.QueryID] = ev.Notify + p.notifiers[ev.QueryID] = &queryNotifier[*EventQueryFinished]{monitor: ev.Notify} } case *EventStopQuery: cmd = &query.EventPoolStopQuery{ @@ -281,9 +286,9 @@ func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (Behaviour } case *EventGetCloserNodesSuccess: p.queueAddNodeEvents(ev.CloserNodes) - waiter, ok := p.waiters[ev.QueryID] + waiter, ok := p.notifiers[ev.QueryID] if ok { - waiter.Notify(ctx, &EventQueryProgressed{ + waiter.TryNotifyProgressed(ctx, &EventQueryProgressed{ NodeID: ev.To, QueryID: ev.QueryID, // CloserNodes: CloserNodeIDs(ev.CloserNodes), @@ -307,9 +312,9 @@ func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (Behaviour } case *EventSendMessageSuccess: 
p.queueAddNodeEvents(ev.CloserNodes) - waiter, ok := p.waiters[ev.QueryID] + waiter, ok := p.notifiers[ev.QueryID] if ok { - waiter.Notify(ctx, &EventQueryProgressed{ + waiter.TryNotifyProgressed(ctx, &EventQueryProgressed{ NodeID: ev.To, QueryID: ev.QueryID, Response: ev.Response, @@ -391,14 +396,14 @@ func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEve case *query.StatePoolWaitingWithCapacity: // nothing to do except wait for message response or timeout case *query.StatePoolQueryFinished[kadt.Key, kadt.PeerID]: - waiter, ok := p.waiters[st.QueryID] + waiter, ok := p.notifiers[st.QueryID] if ok { - waiter.Notify(ctx, &EventQueryFinished{ + waiter.NotifyFinished(ctx, &EventQueryFinished{ QueryID: st.QueryID, Stats: st.Stats, ClosestNodes: st.ClosestNodes, }) - waiter.Close() + delete(p.notifiers, st.QueryID) } case *query.StatePoolQueryTimeout: // TODO @@ -413,7 +418,6 @@ func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEve func (p *PooledQueryBehaviour) queueAddNodeEvents(nodes []kadt.PeerID) { for _, info := range nodes { - // TODO: do this after advancing pool p.pendingOutbound = append(p.pendingOutbound, &EventAddNode{ NodeID: info, }) @@ -425,3 +429,47 @@ func (p *PooledQueryBehaviour) queueNonConnectivityEvent(nid kadt.PeerID) { NodeID: nid, }) } + +type queryNotifier[E TerminalQueryEvent] struct { + monitor QueryMonitor[E] + pending []CtxEvent[*EventQueryProgressed] + stopping bool +} + +func (w *queryNotifier[E]) TryNotifyProgressed(ctx context.Context, ev *EventQueryProgressed) bool { + if w.stopping { + return false + } + ce := CtxEvent[*EventQueryProgressed]{Ctx: ctx, Event: ev} + select { + case w.monitor.NotifyProgressed() <- ce: + return true + default: + w.pending = append(w.pending, ce) + return false + } +} + +// DrainPending attempts to drain as many pending progress events as possible +func (w *queryNotifier[E]) DrainPending() { + for i, ce := range w.pending { + select { + case 
w.monitor.NotifyProgressed() <- ce: + default: + w.pending = w.pending[i:] + return + } + } +} + +func (w *queryNotifier[E]) NotifyFinished(ctx context.Context, ev E) { + w.stopping = true + w.DrainPending() + close(w.monitor.NotifyProgressed()) + + select { + case w.monitor.NotifyFinished() <- CtxEvent[E]{Ctx: ctx, Event: ev}: + default: + } + close(w.monitor.NotifyFinished()) +} diff --git a/internal/coord/query_test.go b/internal/coord/query_test.go index ef4a5f4..40d285b 100644 --- a/internal/coord/query_test.go +++ b/internal/coord/query_test.go @@ -114,7 +114,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesNoProgress() { b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) ts.Require().NoError(err) - waiter := NewWaiter[BehaviourEvent]() + waiter := NewQueryWaiter(5) cmd := &EventStartFindCloserQuery{ QueryID: "test", Target: target, @@ -147,8 +147,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesNoProgress() { ts.Require().IsType(&EventNotifyNonConnectivity{}, bev) // ensure that the waiter received query finished event - wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) - ts.Require().IsType(&EventQueryFinished{}, wev.Event) + kadtest.ReadItem[CtxEvent[*EventQueryFinished]](t, ctx, waiter.Finished()) } func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryProgressed() { @@ -162,7 +161,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryProgressed() { b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) ts.Require().NoError(err) - waiter := NewWaiter[BehaviourEvent]() + waiter := NewQueryWaiter(5) cmd := &EventStartFindCloserQuery{ QueryID: "test", Target: target, @@ -196,8 +195,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryProgressed() { ts.Require().IsType(&EventOutboundGetCloserNodes{}, bev) // ensure that the waiter received query progressed event - wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) - ts.Require().IsType(&EventQueryProgressed{}, 
wev.Event) + kadtest.ReadItem[CtxEvent[*EventQueryProgressed]](t, ctx, waiter.Progressed()) } func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() { @@ -211,7 +209,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() { b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) ts.Require().NoError(err) - waiter := NewWaiter[BehaviourEvent]() + waiter := NewQueryWaiter(5) cmd := &EventStartFindCloserQuery{ QueryID: "test", Target: target, @@ -251,23 +249,29 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() { } // ensure that the waiter received query progressed event - wev := kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) - ts.Require().IsType(&EventQueryProgressed{}, wev.Event) + wev := kadtest.ReadItem[CtxEvent[*EventQueryProgressed]](t, ctx, waiter.Progressed()) + ts.Require().True(wev.Event.NodeID.Equal(ts.nodes[1].NodeID)) - ts.Require().True(egc.To.Equal(ts.nodes[2].NodeID)) - // notify success but no further nodes + // notify success for last seen EventOutboundGetCloserNodes but supply no further nodes b.Notify(ctx, &EventGetCloserNodesSuccess{ QueryID: "test", To: egc.To, Target: target, }) - bev, ok = b.Perform(ctx) - ts.Require().True(ok) + // skip events until behaviour runs out of work + for { + _, ok = b.Perform(ctx) + if !ok { + break + } + } // ensure that the waiter received query progressed event - wev = kadtest.ReadItem[WaiterEvent[BehaviourEvent]](t, ctx, waiter.Chan()) - ts.Require().IsType(&EventQueryProgressed{}, wev.Event) + kadtest.ReadItem[CtxEvent[*EventQueryProgressed]](t, ctx, waiter.Progressed()) + + // ensure that the waiter received query event + kadtest.ReadItem[CtxEvent[*EventQueryFinished]](t, ctx, waiter.Finished()) } func TestPooledQuery_deadlock_regression(t *testing.T) { @@ -299,8 +303,8 @@ func TestPooledQuery_deadlock_regression(t *testing.T) { } // start query - waiter := NewWaiter[BehaviourEvent]() - wrappedWaiter := 
NewNotifyCloserHook[BehaviourEvent](waiter) + waiter := NewQueryWaiter(5) + wrappedWaiter := NewQueryMonitorHook[*EventQueryFinished](waiter) waiterDone := make(chan struct{}) waiterMsg := make(chan struct{}) @@ -344,9 +348,8 @@ func TestPooledQuery_deadlock_regression(t *testing.T) { hasLock := make(chan struct{}) var once sync.Once - wrappedWaiter.BeforeNotify = func(ctx context.Context, event BehaviourEvent) { + wrappedWaiter.BeforeProgressed = func() { once.Do(func() { - require.IsType(t, &EventQueryProgressed{}, event) // verify test invariant close(hasLock) }) }