diff --git a/floodsub_test.go b/floodsub_test.go
index 4e227d5d..6e845e70 100644
--- a/floodsub_test.go
+++ b/floodsub_test.go
@@ -1063,3 +1063,338 @@ func TestImproperlySignedMessageRejected(t *testing.T) {
 		)
 	}
 }
+
+func TestSubscriptionJoinNotification(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const numLateSubscribers = 10
+	const numHosts = 20
+	hosts := getNetHosts(t, ctx, numHosts)
+
+	psubs := getPubsubs(ctx, hosts)
+
+	msgs := make([]*Subscription, numHosts)
+	subPeersFound := make([]map[peer.ID]struct{}, numHosts)
+
+	// Have some peers subscribe earlier than other peers.
+	// This exercises whether we get subscription notifications from
+	// existing peers.
+	for i, ps := range psubs[numLateSubscribers:] {
+		subch, err := ps.Subscribe("foobar")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs[i] = subch
+	}
+
+	connectAll(t, hosts)
+
+	time.Sleep(time.Millisecond * 100)
+
+	// Have the rest subscribe
+	for i, ps := range psubs[:numLateSubscribers] {
+		subch, err := ps.Subscribe("foobar")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs[i+numLateSubscribers] = subch
+	}
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < numHosts; i++ {
+		peersFound := make(map[peer.ID]struct{})
+		subPeersFound[i] = peersFound
+		sub := msgs[i]
+		wg.Add(1)
+		go func(peersFound map[peer.ID]struct{}) {
+			defer wg.Done()
+			for len(peersFound) < numHosts-1 {
+				event, err := sub.NextPeerEvent(ctx)
+				if err != nil {
+					// t.Fatal must not be called from a non-test goroutine
+					t.Error(err)
+					return
+				}
+				if event.Type == PeerJoin {
+					peersFound[event.Peer] = struct{}{}
+				}
+			}
+		}(peersFound)
+	}
+
+	wg.Wait()
+	for _, peersFound := range subPeersFound {
+		if len(peersFound) != numHosts-1 {
+			t.Fatal("incorrect number of peers found")
+		}
+	}
+}
+
+func TestSubscriptionLeaveNotification(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const numHosts = 20
+	hosts := getNetHosts(t, ctx, numHosts)
+
+	psubs := getPubsubs(ctx, hosts)
+
+	msgs := make([]*Subscription, numHosts)
+	subPeersFound := make([]map[peer.ID]struct{}, numHosts)
+
+	// Subscribe all peers and wait until they've all been found
+	for i, ps := range psubs {
+		subch, err := ps.Subscribe("foobar")
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs[i] = subch
+	}
+
+	connectAll(t, hosts)
+
+	time.Sleep(time.Millisecond * 100)
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < numHosts; i++ {
+		peersFound := make(map[peer.ID]struct{})
+		subPeersFound[i] = peersFound
+		sub := msgs[i]
+		wg.Add(1)
+		go func(peersFound map[peer.ID]struct{}) {
+			defer wg.Done()
+			for len(peersFound) < numHosts-1 {
+				event, err := sub.NextPeerEvent(ctx)
+				if err != nil {
+					// t.Fatal must not be called from a non-test goroutine
+					t.Error(err)
+					return
+				}
+				if event.Type == PeerJoin {
+					peersFound[event.Peer] = struct{}{}
+				}
+			}
+		}(peersFound)
+	}
+
+	wg.Wait()
+	for _, peersFound := range subPeersFound {
+		if len(peersFound) != numHosts-1 {
+			t.Fatal("incorrect number of peers found")
+		}
+	}
+
+	// Remove peers in three different ways and verify that each causes a leave event
+	msgs[1].Cancel()
+	hosts[2].Close()
+	psubs[0].BlacklistPeer(hosts[3].ID())
+
+	leavingPeers := make(map[peer.ID]struct{})
+	for len(leavingPeers) < 3 {
+		event, err := msgs[0].NextPeerEvent(ctx)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if event.Type == PeerLeave {
+			leavingPeers[event.Peer] = struct{}{}
+		}
+	}
+
+	if _, ok := leavingPeers[hosts[1].ID()]; !ok {
+		t.Fatal("canceling subscription did not cause a leave event")
+	}
+	if _, ok := leavingPeers[hosts[2].ID()]; !ok {
+		t.Fatal("closing host did not cause a leave event")
+	}
+	if _, ok := leavingPeers[hosts[3].ID()]; !ok {
+		t.Fatal("blacklisting peer did not cause a leave event")
+	}
+}
+
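+// TestSubscriptionManyNotifications checks that a late subscriber receives
+// queued PeerJoin events for every peer already in the topic, and that a
+// mass unsubscribe produces a matching set of PeerLeave events.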
+func TestSubscriptionManyNotifications(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const topic = "foobar"
+
+	const numHosts = 35
+	hosts := getNetHosts(t, ctx, numHosts)
+
+	psubs := getPubsubs(ctx, hosts)
+
+	msgs := make([]*Subscription, numHosts)
+	subPeersFound := make([]map[peer.ID]struct{}, numHosts)
+
+	// Subscribe all peers except one and wait until they've all been found
+	for i := 1; i < numHosts; i++ {
+		subch, err := psubs[i].Subscribe(topic)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		msgs[i] = subch
+	}
+
+	connectAll(t, hosts)
+
+	time.Sleep(time.Millisecond * 100)
+
+	wg := sync.WaitGroup{}
+	for i := 1; i < numHosts; i++ {
+		peersFound := make(map[peer.ID]struct{})
+		subPeersFound[i] = peersFound
+		sub := msgs[i]
+		wg.Add(1)
+		go func(peersFound map[peer.ID]struct{}) {
+			defer wg.Done()
+			for len(peersFound) < numHosts-2 {
+				event, err := sub.NextPeerEvent(ctx)
+				if err != nil {
+					// t.Fatal must not be called from a non-test goroutine
+					t.Error(err)
+					return
+				}
+				if event.Type == PeerJoin {
+					peersFound[event.Peer] = struct{}{}
+				}
+			}
+		}(peersFound)
+	}
+
+	wg.Wait()
+	for _, peersFound := range subPeersFound[1:] {
+		if len(peersFound) != numHosts-2 {
+			t.Fatalf("found %d peers, expected %d", len(peersFound), numHosts-2)
+		}
+	}
+
+	// Wait for the remaining peer to find the other peers
+	for len(psubs[0].ListPeers(topic)) < numHosts-1 {
+		time.Sleep(time.Millisecond * 100)
+	}
+
+	// Subscribe the remaining peer and check that all the join events came through
+	sub, err := psubs[0].Subscribe(topic)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	msgs[0] = sub
+
+	peerState := readAllQueuedEvents(ctx, t, sub)
+
+	if len(peerState) != numHosts-1 {
+		t.Fatal("incorrect number of peers found")
+	}
+
+	for _, e := range peerState {
+		if e != PeerJoin {
+			t.Fatal("non Join event occurred")
+		}
+	}
+
+	// Unsubscribe all peers except one and check that all the leave events came through
+	for i := 1; i < numHosts; i++ {
+		msgs[i].Cancel()
+	}
+
+	// Wait for the remaining peer to disconnect from the other peers
+	for len(psubs[0].ListPeers(topic)) != 0 {
+		time.Sleep(time.Millisecond * 100)
+	}
+
+	peerState = readAllQueuedEvents(ctx, t, sub)
+
+	if len(peerState) != numHosts-1 {
+		t.Fatal("incorrect number of peers found")
+	}
+
+	for _, e := range peerState {
+		if e != PeerLeave {
+			t.Fatal("non Leave event occurred")
+		}
+	}
+}
+
+func TestSubscriptionNotificationSubUnSub(t *testing.T) {
+	// Subscribe and unsubscribe peers and check that the notification state
+	// stays consistent
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const topic = "foobar"
+
+	const numHosts = 35
+	hosts := getNetHosts(t, ctx, numHosts)
+	psubs := getPubsubs(ctx, hosts)
+
+	for i := 1; i < numHosts; i++ {
+		connect(t, hosts[0], hosts[i])
+	}
+	time.Sleep(time.Millisecond * 100)
+
+	notifSubThenUnSub(ctx, t, topic, psubs)
+}
+
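+// notifSubThenUnSub subscribes every peer to topic, waits until the primary
+// peer (psubs[0]) sees all of them, cancels every other subscription, and
+// verifies that the primary's event log drains to empty: each queued
+// PeerJoin is cancelled out by the matching PeerLeave.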
+func notifSubThenUnSub(ctx context.Context, t *testing.T, topic string,
+	psubs []*PubSub) {
+
+	ps := psubs[0]
+	msgs := make([]*Subscription, len(psubs))
+	checkSize := len(psubs) - 1
+
+	// Subscribe all peers to the topic
+	var err error
+	for i, ps := range psubs {
+		msgs[i], err = ps.Subscribe(topic)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	sub := msgs[0]
+
+	// Wait for the primary peer to be connected to the other peers
+	for len(ps.ListPeers(topic)) < checkSize {
+		time.Sleep(time.Millisecond * 100)
+	}
+
+	// Unsubscribe all peers except the primary
+	for i := 1; i < checkSize+1; i++ {
+		msgs[i].Cancel()
+	}
+
+	// Wait for the unsubscribe messages to reach the primary peer
+	for len(ps.ListPeers(topic)) != 0 {
+		time.Sleep(time.Millisecond * 100)
+	}
+
+	// Read all available events and verify that there are none left to
+	// process: every peer that joined also left.
+	peerState := readAllQueuedEvents(ctx, t, sub)
+
+	if len(peerState) != 0 {
+		for p, s := range peerState {
+			fmt.Println(p, s)
+		}
+		t.Fatalf("received incorrect events: %d extra events", len(peerState))
+	}
+}
+
+// readAllQueuedEvents drains the subscription's event log, coalescing events
+// per peer: a Join followed by a Leave (or vice versa) removes the peer from
+// the returned map.
+func readAllQueuedEvents(ctx context.Context, t *testing.T, sub *Subscription) map[peer.ID]EventType {
+	peerState := make(map[peer.ID]EventType)
+	for {
+		evtCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
+		event, err := sub.NextPeerEvent(evtCtx)
+		cancel()
+		if err == context.DeadlineExceeded {
+			break
+		} else if err != nil {
+			t.Fatal(err)
+		}
+
+		e, ok := peerState[event.Peer]
+		if !ok {
+			peerState[event.Peer] = event.Type
+		} else if e != event.Type {
+			delete(peerState, event.Peer)
+		}
+	}
+	return peerState
+}
diff --git a/pubsub.go b/pubsub.go
index 6df169d8..5902c00a 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -333,8 +333,11 @@ func (p *PubSub) processLoop(ctx context.Context) {
 			}
 			delete(p.peers, pid)
 
-			for _, t := range p.topics {
-				delete(t, pid)
+			for t, tmap := range p.topics {
+				if _, ok := tmap[pid]; ok {
+					delete(tmap, pid)
+					p.notifyLeave(t, pid)
+				}
 			}
 
 			p.rt.RemovePeer(pid)
@@ -392,8 +395,11 @@ func (p *PubSub) processLoop(ctx context.Context) {
 			if ok {
 				close(ch)
 				delete(p.peers, pid)
-				for _, t := range p.topics {
-					delete(t, pid)
+				for t, tmap := range p.topics {
+					if _, ok := tmap[pid]; ok {
+						delete(tmap, pid)
+						p.notifyLeave(t, pid)
+					}
 				}
 				p.rt.RemovePeer(pid)
 			}
@@ -417,7 +423,7 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
 	}
 
 	sub.err = fmt.Errorf("subscription cancelled by calling sub.Cancel()")
-	close(sub.ch)
+	sub.close()
 	delete(subs, sub)
 
 	if len(subs) == 0 {
@@ -447,7 +453,11 @@ func (p *PubSub) handleAddSubscription(req *addSubReq) {
 		subs = p.myTopics[sub.topic]
 	}
 
-	sub.ch = make(chan *Message, 32)
+	tmap := p.topics[sub.topic]
+
+	// Seed the new subscription's event log with the peers already in the topic
+	for pid := range tmap {
+		sub.evtLog[pid] = PeerJoin
+	}
 	sub.cancelCh = p.cancelCh
 
 	p.myTopics[sub.topic][sub] = struct{}{}
@@ -560,6 +570,14 @@ func (p *PubSub) subscribedToMsg(msg *pb.Message) bool {
 	return false
 }
 
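+// notifyLeave delivers a PeerLeave event for pid to every local subscription
+// on the given topic. It must be called from the event loop.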
+func (p *PubSub) notifyLeave(topic string, pid peer.ID) {
+	if subs, ok := p.myTopics[topic]; ok {
+		for s := range subs {
+			s.sendNotification(PeerEvent{PeerLeave, pid})
+		}
+	}
+}
+
 func (p *PubSub) handleIncomingRPC(rpc *RPC) {
 	for _, subopt := range rpc.GetSubscriptions() {
 		t := subopt.GetTopicid()
@@ -570,13 +588,25 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) {
 				p.topics[t] = tmap
 			}
 
-			tmap[rpc.from] = struct{}{}
+			if _, ok = tmap[rpc.from]; !ok {
+				tmap[rpc.from] = struct{}{}
+				if subs, ok := p.myTopics[t]; ok {
+					pid := rpc.from
+					for s := range subs {
+						s.sendNotification(PeerEvent{PeerJoin, pid})
+					}
+				}
+			}
 		} else {
 			tmap, ok := p.topics[t]
 			if !ok {
 				continue
 			}
-			delete(tmap, rpc.from)
+
+			if _, ok := tmap[rpc.from]; ok {
+				delete(tmap, rpc.from)
+				p.notifyLeave(t, rpc.from)
+			}
 		}
 	}
@@ -666,6 +696,11 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) {
 	sub := &Subscription{
 		topic: td.GetName(),
+
+		ch:        make(chan *Message, 32),
+		peerEvtCh: make(chan PeerEvent, 1),
+		evtLog:    make(map[peer.ID]EventType),
+		evtLogCh:  make(chan struct{}, 1),
 	}
 
 	for _, opt := range opts {
diff --git a/subscription.go b/subscription.go
index 66a9e513..45d957ec 100644
--- a/subscription.go
+++ b/subscription.go
@@ -2,6 +2,15 @@ package pubsub
 
 import (
 	"context"
+	"sync"
+
+	"github.com/libp2p/go-libp2p-core/peer"
+)
+
+type EventType int
+
+const (
+	PeerJoin EventType = iota
+	PeerLeave
 )
 
 type Subscription struct {
@@ -9,12 +18,23 @@ type Subscription struct {
 	ch       chan *Message
 	cancelCh chan<- *Subscription
 	err      error
+
+	peerEvtCh chan PeerEvent
+	evtLogMx  sync.Mutex
+	evtLog    map[peer.ID]EventType
+	evtLogCh  chan struct{}
+}
+
+type PeerEvent struct {
+	Type EventType
+	Peer peer.ID
 }
 
 func (sub *Subscription) Topic() string {
 	return sub.topic
 }
 
+// Next returns the next message in our subscription
 func (sub *Subscription) Next(ctx context.Context) (*Message, error) {
 	select {
 	case msg, ok := <-sub.ch:
@@ -31,3 +51,69 @@ func (sub *Subscription) Next(ctx context.Context) (*Message, error) {
 func (sub *Subscription) Cancel() {
 	sub.cancelCh <- sub
 }
+
+func (sub *Subscription) close() {
+	close(sub.ch)
+}
+
+func (sub *Subscription) sendNotification(evt PeerEvent) {
+	sub.evtLogMx.Lock()
+	defer sub.evtLogMx.Unlock()
+
+	sub.addToEventLog(evt)
+}
+
+// addToEventLog assumes a lock has been taken to protect the event log
+func (sub *Subscription) addToEventLog(evt PeerEvent) {
+	e, ok := sub.evtLog[evt.Peer]
+	if !ok {
+		sub.evtLog[evt.Peer] = evt.Type
+		// send signal that an event has been added to the event log
+		select {
+		case sub.evtLogCh <- struct{}{}:
+		default:
+		}
+	} else if e != evt.Type {
+		// the new event cancels out the queued opposite event
+		delete(sub.evtLog, evt.Peer)
+	}
+}
+
+// pullFromEventLog assumes a lock has been taken to protect the event log.
+// It returns an arbitrary queued event (map iteration order is unspecified).
+func (sub *Subscription) pullFromEventLog() (PeerEvent, bool) {
+	for k, v := range sub.evtLog {
+		evt := PeerEvent{Peer: k, Type: v}
+		delete(sub.evtLog, k)
+		return evt, true
+	}
+	return PeerEvent{}, false
+}
+
+// NextPeerEvent returns the next event regarding subscribed peers.
+// Guarantees: Peer Join and Peer Leave events for a given peer will fire in order.
+// Unless a peer both joins and leaves before NextPeerEvent emits either event,
+// all events will eventually be received via NextPeerEvent.
+func (sub *Subscription) NextPeerEvent(ctx context.Context) (PeerEvent, error) {
+	for {
+		sub.evtLogMx.Lock()
+		evt, ok := sub.pullFromEventLog()
+		if ok {
+			// make sure an event log signal is available if there are events in the event log
+			if len(sub.evtLog) > 0 {
+				select {
+				case sub.evtLogCh <- struct{}{}:
+				default:
+				}
+			}
+			sub.evtLogMx.Unlock()
+			return evt, nil
+		}
+		sub.evtLogMx.Unlock()
+
+		select {
+		case <-sub.evtLogCh:
+			continue
+		case <-ctx.Done():
+			return PeerEvent{}, ctx.Err()
+		}
+	}
+}
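
A minimal sketch of the intended consumption pattern (`ctx` and an existing
*PubSub `ps` are assumed; the topic name and loop shape are illustrative, not
part of this change):

	sub, err := ps.Subscribe("foobar")
	if err != nil {
		return err
	}
	// Track the current set of subscribed peers from the event stream.
	peers := make(map[peer.ID]struct{})
	for {
		ev, err := sub.NextPeerEvent(ctx) // blocks until an event or ctx is done
		if err != nil {
			return err // ctx.Err() once the context is cancelled
		}
		switch ev.Type {
		case PeerJoin:
			peers[ev.Peer] = struct{}{}
		case PeerLeave:
			delete(peers, ev.Peer)
		}
	}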