Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiNode Adaptor Managed Subscriptions #960

Open
wants to merge 12 commits into
base: BCFR-1071-Generic-MultiNodeClient
Choose a base branch
from
Open
48 changes: 33 additions & 15 deletions pkg/solana/client/multinode/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ func (m *MultiNodeAdapter[RPC, HEAD]) LenSubs() int {
}

// registerSub adds the sub to the rpcMultiNodeAdapter list
func (m *MultiNodeAdapter[RPC, HEAD]) registerSub(sub Subscription, stopInFLightCh chan struct{}) error {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
func (m *MultiNodeAdapter[RPC, HEAD]) registerSub(sub *ManagedSubscription, stopInFLightCh chan struct{}) error {
// ensure that the `sub` belongs to current life cycle of the `rpcMultiNodeAdapter` and it should not be killed due to
// previous `DisconnectAll` call.
select {
Expand All @@ -73,11 +71,18 @@ func (m *MultiNodeAdapter[RPC, HEAD]) registerSub(sub Subscription, stopInFLight
return fmt.Errorf("failed to register subscription - all in-flight requests were canceled")
default:
}
// TODO: BCI-3358 - delete sub when caller unsubscribes.
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
m.subs[sub] = struct{}{}
return nil
}

func (m *MultiNodeAdapter[RPC, HEAD]) removeSub(sub Subscription) {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()
delete(m.subs, sub)
}

func (m *MultiNodeAdapter[RPC, HEAD]) LatestBlock(ctx context.Context) (HEAD, error) {
// capture chStopInFlight to ensure we are not updating chainInfo with observations related to previous life cycle
ctx, cancel, chStopInFlight, rpc := m.AcquireQueryCtx(ctx, m.ctxTimeout)
Expand Down Expand Up @@ -109,7 +114,7 @@ func (m *MultiNodeAdapter[RPC, HEAD]) LatestFinalizedBlock(ctx context.Context)
return head, errors.New("invalid head")
}

m.OnNewFinalizedHead(ctx, chStopInFlight, head)
m.onNewFinalizedHead(ctx, chStopInFlight, head)
return head, nil
}

Expand All @@ -134,13 +139,18 @@ func (m *MultiNodeAdapter[RPC, HEAD]) SubscribeToHeads(ctx context.Context) (<-c
return nil, nil, err
}

err := m.registerSub(&poller, chStopInFlight)
sub := &ManagedSubscription{
Subscription: &poller,
onUnsubscribe: m.removeSub,
}

err := m.registerSub(sub, chStopInFlight)
if err != nil {
poller.Unsubscribe()
sub.Unsubscribe()
return nil, nil, err
}

return channel, &poller, nil
return channel, sub, nil
}

func (m *MultiNodeAdapter[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Context) (<-chan HEAD, Subscription, error) {
Expand All @@ -162,13 +172,18 @@ func (m *MultiNodeAdapter[RPC, HEAD]) SubscribeToFinalizedHeads(ctx context.Cont
return nil, nil, err
}

err := m.registerSub(&poller, chStopInFlight)
sub := &ManagedSubscription{
Subscription: &poller,
onUnsubscribe: m.removeSub,
}

err := m.registerSub(sub, chStopInFlight)
if err != nil {
poller.Unsubscribe()
return nil, nil, err
}

return channel, &poller, nil
return channel, sub, nil
}

func (m *MultiNodeAdapter[RPC, HEAD]) onNewHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
Expand All @@ -189,7 +204,7 @@ func (m *MultiNodeAdapter[RPC, HEAD]) onNewHead(ctx context.Context, requestCh <
}
}

func (m *MultiNodeAdapter[RPC, HEAD]) OnNewFinalizedHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
func (m *MultiNodeAdapter[RPC, HEAD]) onNewFinalizedHead(ctx context.Context, requestCh <-chan struct{}, head HEAD) {
if !head.IsValid() {
return
}
Expand Down Expand Up @@ -236,19 +251,22 @@ func (m *MultiNodeAdapter[RPC, HEAD]) AcquireQueryCtx(parentCtx context.Context,

func (m *MultiNodeAdapter[RPC, HEAD]) UnsubscribeAllExcept(subs ...Subscription) {
m.subsSliceMu.Lock()
defer m.subsSliceMu.Unlock()

keepSubs := map[Subscription]struct{}{}
for _, sub := range subs {
keepSubs[sub] = struct{}{}
}

var unsubs []Subscription
for sub := range m.subs {
if _, keep := keepSubs[sub]; !keep {
sub.Unsubscribe()
delete(m.subs, sub)
unsubs = append(unsubs, sub)
}
}
m.subsSliceMu.Unlock()

for _, sub := range unsubs {
sub.Unsubscribe()
}
}

// cancelInflightRequests closes and replaces the chStopInFlight
Expand Down
62 changes: 54 additions & 8 deletions pkg/solana/client/multinode/adaptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,33 @@ func TestMultiNodeClient_HeadSubscriptions(t *testing.T) {
t.Fatal("failed to receive finalized head: ", ctx.Err())
}
})

t.Run("Remove Subscription on Unsubscribe", func(t *testing.T) {
c := newTestClient(t)
_, sub1, err := c.SubscribeToHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
_, sub2, err := c.SubscribeToFinalizedHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())

sub1.Unsubscribe()
require.Equal(t, 1, c.LenSubs())
sub2.Unsubscribe()
require.Equal(t, 0, c.LenSubs())
})

t.Run("Ensure no deadlock on UnsubscribeAll", func(t *testing.T) {
c := newTestClient(t)
_, _, err := c.SubscribeToHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
_, _, err = c.SubscribeToFinalizedHeads(tests.Context(t))
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())
c.UnsubscribeAllExcept()
require.Equal(t, 0, c.LenSubs())
})
}

type mockSub struct {
Expand All @@ -134,41 +161,60 @@ func (s *mockSub) Err() <-chan error {
func TestMultiNodeClient_RegisterSubs(t *testing.T) {
t.Run("registerSub", func(t *testing.T) {
c := newTestClient(t)
sub := newMockSub()
mockSub := newMockSub()
sub := &ManagedSubscription{
Subscription: mockSub,
onUnsubscribe: c.removeSub,
}
err := c.registerSub(sub, make(chan struct{}))
require.NoError(t, err)
require.Equal(t, 1, c.LenSubs())
c.UnsubscribeAllExcept()
require.Equal(t, 0, c.LenSubs())
require.Equal(t, true, mockSub.unsubscribed)
})

t.Run("chStopInFlight returns error and unsubscribes", func(t *testing.T) {
c := newTestClient(t)
chStopInFlight := make(chan struct{})
close(chStopInFlight)
sub := newMockSub()
mockSub := newMockSub()
sub := &ManagedSubscription{
Subscription: mockSub,
onUnsubscribe: c.removeSub,
}
err := c.registerSub(sub, chStopInFlight)
require.Error(t, err)
require.Equal(t, true, sub.unsubscribed)
require.Equal(t, true, mockSub.unsubscribed)
})

t.Run("UnsubscribeAllExcept", func(t *testing.T) {
c := newTestClient(t)
chStopInFlight := make(chan struct{})
sub1 := newMockSub()
sub2 := newMockSub()
mockSub1 := newMockSub()
sub1 := &ManagedSubscription{
Subscription: mockSub1,
onUnsubscribe: c.removeSub,
}
mockSub2 := newMockSub()
sub2 := &ManagedSubscription{
Subscription: mockSub2,
onUnsubscribe: c.removeSub,
}
err := c.registerSub(sub1, chStopInFlight)
require.NoError(t, err)
err = c.registerSub(sub2, chStopInFlight)
require.NoError(t, err)
require.Equal(t, 2, c.LenSubs())

// Ensure passed sub is not removed
c.UnsubscribeAllExcept(sub1)
require.Equal(t, 1, c.LenSubs())
require.Equal(t, true, sub2.unsubscribed)
require.Equal(t, false, sub1.unsubscribed)
require.Equal(t, true, mockSub2.unsubscribed)
require.Equal(t, false, mockSub1.unsubscribed)

c.UnsubscribeAllExcept()
require.Equal(t, 0, c.LenSubs())
require.Equal(t, true, sub1.unsubscribed)
require.Equal(t, true, mockSub1.unsubscribed)
})
}
13 changes: 13 additions & 0 deletions pkg/solana/client/multinode/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ type Subscription interface {
Err() <-chan error
}

// ManagedSubscription is a Subscription which contains an onUnsubscribe callback
type ManagedSubscription struct {
Subscription
onUnsubscribe func(sub Subscription)
}

func (w *ManagedSubscription) Unsubscribe() {
w.Subscription.Unsubscribe()
if w.onUnsubscribe != nil {
w.onUnsubscribe(w)
}
}

// RPCClient includes all the necessary generalized RPC methods used by Node to perform health checks
type RPCClient[
CHAIN_ID ID,
Expand Down
Loading