diff --git a/dht.go b/dht.go index 19e7f585c..f7bf19705 100644 --- a/dht.go +++ b/dht.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/libp2p/go-eventbus" + "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -84,6 +86,10 @@ type IpfsDHT struct { // "forked" DHTs (e.g., DHTs with custom protocols and/or private // networks). enableProviders, enableValues bool + + subscriptions struct { + evtPeerIdentification event.Subscription + } } // Assert that IPFS assumptions about interfaces aren't broken. These aren't a @@ -114,15 +120,22 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er dht.enableProviders = cfg.EnableProviders dht.enableValues = cfg.EnableValues + subnot := (*subscriberNotifee)(dht) + // register for network notifs. - dht.host.Network().Notify((*netNotifiee)(dht)) + dht.host.Network().Notify(subnot) dht.proc = goprocessctx.WithContextAndTeardown(ctx, func() error { // remove ourselves from network notifs. - dht.host.Network().StopNotify((*netNotifiee)(dht)) + dht.host.Network().StopNotify((*subscriberNotifee)(dht)) + + if dht.subscriptions.evtPeerIdentification != nil { + _ = dht.subscriptions.evtPeerIdentification.Close() + } return nil }) + dht.proc.AddChild(subnot.Process(ctx)) dht.proc.AddChild(dht.providers.Process()) dht.Validator = cfg.Validator @@ -190,6 +203,13 @@ func makeDHT(ctx context.Context, h host.Host, cfg opts.Options) *IpfsDHT { triggerRtRefresh: make(chan chan<- error), } + var err error + evts := []interface{}{&event.EvtPeerIdentificationCompleted{}, &event.EvtPeerIdentificationFailed{}} + dht.subscriptions.evtPeerIdentification, err = h.EventBus().Subscribe(evts, eventbus.BufSize(256)) + if err != nil { + logger.Errorf("dht not subscribed to peer identification events; things will fail; err: %s", err) + } + dht.ctx = dht.newContextWithLocalTags(ctx) return dht diff --git a/ext_test.go b/ext_test.go index 91d54d9af..8bf63ef37 100644 --- a/ext_test.go +++ b/ext_test.go @@ -8,8 +8,11 @@ import ( "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/routing" opts "github.com/libp2p/go-libp2p-kad-dht/opts" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" ggio "github.com/gogo/protobuf/io" u "github.com/ipfs/go-ipfs-util" @@ -24,25 +27,28 @@ func TestGetFailures(t *testing.T) { } ctx := context.Background() - mn, err := mocknet.FullMeshConnected(ctx, 2) - if err != nil { - t.Fatal(err) - } - hosts := mn.Hosts() - os := []opts.Option{opts.DisableAutoRefresh()} - d, err := New(ctx, hosts[0], os...) + host1 := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) + host2 := bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)) + + d, err := New(ctx, host1, opts.DisableAutoRefresh()) if err != nil { t.Fatal(err) } - d.Update(ctx, hosts[1].ID()) // Reply with failures to every message - hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) { + host2.SetStreamHandler(d.protocols[0], func(s network.Stream) { time.Sleep(400 * time.Millisecond) s.Close() }) + host1.Peerstore().AddAddrs(host2.ID(), host2.Addrs(), peerstore.ConnectedAddrTTL) + _, err = host1.Network().DialPeer(ctx, host2.ID()) + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + // This one should time out ctx1, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() @@ -61,7 +67,7 @@ func TestGetFailures(t *testing.T) { t.Log("Timeout test passed.") // Reply with failures to every message - hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) { + host2.SetStreamHandler(d.protocols[0], func(s network.Stream) { defer s.Close() pbr := ggio.NewDelimitedReader(s, network.MessageSizeMax) @@ -113,7 +119,7 @@ func TestGetFailures(t *testing.T) { Record: rec, } - s, err := hosts[1].NewStream(context.Background(), hosts[0].ID(), d.protocols[0]) + s, err := host2.NewStream(context.Background(), host1.ID(), d.protocols[0]) if err != nil { t.Fatal(err) } diff --git a/go.mod b/go.mod index 0481826de..89f092954 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,8 @@ require ( github.com/ipfs/go-ipfs-util v0.0.1 github.com/ipfs/go-log v0.0.1 github.com/jbenet/goprocess v0.1.3 - github.com/libp2p/go-libp2p v0.5.2 + github.com/libp2p/go-eventbus v0.1.0 + github.com/libp2p/go-libp2p v0.5.3-0.20200221174525-7ba322244e0a github.com/libp2p/go-libp2p-core v0.3.1 github.com/libp2p/go-libp2p-kbucket v0.2.3 github.com/libp2p/go-libp2p-peerstore v0.1.4 diff --git a/go.sum b/go.sum index 49b7f76f1..2a396cdf8 100644 --- a/go.sum +++ b/go.sum @@ -163,8 +163,8 @@ github.com/libp2p/go-flow-metrics v0.0.2 h1:U5TvqfoyR6GVRM+bC15Ux1ltar1kbj6Zw6xO github.com/libp2p/go-flow-metrics v0.0.2/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= github.com/libp2p/go-flow-metrics v0.0.3 h1:8tAs/hSdNvUiLgtlSy3mxwxWP4I9y/jlkPFT7epKdeM= github.com/libp2p/go-flow-metrics v0.0.3/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= -github.com/libp2p/go-libp2p v0.5.2 h1:fjQUTyB7x/4XgO31OEWkJ5uFeHRgpoExlf0rXz5BO8k= -github.com/libp2p/go-libp2p v0.5.2/go.mod h1:o2r6AcpNl1eNGoiWhRtPji03NYOvZumeQ6u+X6gSxnM= +github.com/libp2p/go-libp2p v0.5.3-0.20200221174525-7ba322244e0a h1:cxYryrTPI23R5InZb9Kc86dj819f7yVMapQPuj1Ti1s= +github.com/libp2p/go-libp2p v0.5.3-0.20200221174525-7ba322244e0a/go.mod h1:8UlWMmxcKNxyY0ocYX8Ft4IZ0mMfr7b89v1qZdXxwrk= github.com/libp2p/go-libp2p-autonat v0.1.1 h1:WLBZcIRsjZlWdAZj9CiBSvU2wQXoUOiS1Zk1tM7DTJI= github.com/libp2p/go-libp2p-autonat v0.1.1/go.mod h1:OXqkeGOY2xJVWKAGV2inNF5aKN/djNA3fdpCWloIudE= github.com/libp2p/go-libp2p-blankhost v0.1.1/go.mod h1:pf2fvdLJPsC1FsVrNP3DUUvMzUts2dsLLBEpo1vW1ro= diff --git a/notif.go b/notif.go deleted file mode 100644 index a7913a5f5..000000000 --- a/notif.go +++ /dev/null @@ -1,142 +0,0 @@ -package dht - -import ( - "github.com/libp2p/go-libp2p-core/helpers" - "github.com/libp2p/go-libp2p-core/network" - - ma "github.com/multiformats/go-multiaddr" - mstream "github.com/multiformats/go-multistream" -) - -// netNotifiee defines methods to be used with the IpfsDHT -type netNotifiee IpfsDHT - -func (nn *netNotifiee) DHT() *IpfsDHT { - return (*IpfsDHT)(nn) -} - -func (nn *netNotifiee) Connected(n network.Network, v network.Conn) { - dht := nn.DHT() - select { - case <-dht.Process().Closing(): - return - default: - } - - p := v.RemotePeer() - protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) - if err == nil && len(protos) != 0 { - // We lock here for consistency with the lock in testConnection. - // This probably isn't necessary because (dis)connect - // notifications are serialized but it's nice to be consistent. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - refresh := dht.routingTable.Size() <= minRTRefreshThreshold - dht.Update(dht.Context(), p) - if refresh && dht.autoRefresh { - select { - case dht.triggerRtRefresh <- nil: - default: - } - } - } - return - } - - // Note: Unfortunately, the peerstore may not yet know that this peer is - // a DHT server. So, if it didn't return a positive response above, test - // manually. - go nn.testConnection(v) -} - -func (nn *netNotifiee) testConnection(v network.Conn) { - dht := nn.DHT() - p := v.RemotePeer() - - // Forcibly use *this* connection. Otherwise, if we have two connections, we could: - // 1. Test it twice. - // 2. Have it closed from under us leaving the second (open) connection untested. - s, err := v.NewStream() - if err != nil { - // Connection error - return - } - defer helpers.FullClose(s) - - selected, err := mstream.SelectOneOf(dht.protocolStrs(), s) - if err != nil { - // Doesn't support the protocol - return - } - // Remember this choice (makes subsequent negotiations faster) - dht.peerstore.AddProtocols(p, selected) - - // We lock here as we race with disconnect. If we didn't lock, we could - // finish processing a connect after handling the associated disconnect - // event and add the peer to the routing table after removing it. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - refresh := dht.routingTable.Size() <= minRTRefreshThreshold - dht.Update(dht.Context(), p) - if refresh && dht.autoRefresh { - select { - case dht.triggerRtRefresh <- nil: - default: - } - } - } -} - -func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { - dht := nn.DHT() - select { - case <-dht.Process().Closing(): - return - default: - } - - p := v.RemotePeer() - - // Lock and check to see if we're still connected. We lock to make sure - // we don't concurrently process a connect event. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - // We're still connected. - return - } - - dht.routingTable.Remove(p) - if dht.routingTable.Size() < minRTRefreshThreshold { - // TODO: Actively bootstrap. For now, just try to add the currently connected peers. - for _, p := range dht.host.Network().Peers() { - // Don't bother probing, we do that on connect. - protos, err := dht.peerstore.SupportsProtocols(p, dht.protocolStrs()...) - if err == nil && len(protos) != 0 { - dht.Update(dht.Context(), p) - } - } - } - - dht.smlk.Lock() - defer dht.smlk.Unlock() - ms, ok := dht.strmap[p] - if !ok { - return - } - delete(dht.strmap, p) - - // Do this asynchronously as ms.lk can block for a while. - go func() { - ms.lk.Lock() - defer ms.lk.Unlock() - ms.invalidate() - }() -} - -func (nn *netNotifiee) OpenedStream(n network.Network, v network.Stream) {} -func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {} -func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} -func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {} diff --git a/notify_test.go b/notify_test.go index 3a15a8e82..4c1046b66 100644 --- a/notify_test.go +++ b/notify_test.go @@ -16,8 +16,8 @@ func TestNotifieeMultipleConn(t *testing.T) { d1 := setupDHT(ctx, t, false) d2 := setupDHT(ctx, t, false) - nn1 := (*netNotifiee)(d1) - nn2 := (*netNotifiee)(d2) + nn1 := (*subscriberNotifee)(d1) + nn2 := (*subscriberNotifee)(d2) connect(t, ctx, d1, d2) c12 := d1.host.Network().ConnsToPeer(d2.self)[0] diff --git a/subscriber_notifee.go b/subscriber_notifee.go new file mode 100644 index 000000000..d2bf136ed --- /dev/null +++ b/subscriber_notifee.go @@ -0,0 +1,90 @@ +package dht + +import ( + "context" + + "github.com/jbenet/goprocess" + goprocessctx "github.com/jbenet/goprocess/context" + "github.com/libp2p/go-libp2p-core/event" + "github.com/libp2p/go-libp2p-core/network" + ma "github.com/multiformats/go-multiaddr" +) + +// subscriberNotifee implements network.Notifee and also manages the subscriber to the event bus. We consume peer +// identification events to trigger inclusion in the routing table, and we consume Disconnected events to eject peers +// from it. +type subscriberNotifee IpfsDHT + +func (nn *subscriberNotifee) DHT() *IpfsDHT { + return (*IpfsDHT)(nn) +} + +func (nn *subscriberNotifee) Process(ctx context.Context) goprocess.Process { + proc := goprocessctx.WithContext(ctx) + proc.Go(nn.subscribe) + return proc +} + +func (nn *subscriberNotifee) subscribe(proc goprocess.Process) { + dht := nn.DHT() + for { + select { + case evt, more := <-dht.subscriptions.evtPeerIdentification.Out(): + if !more { + return + } + switch ev := evt.(type) { + case event.EvtPeerIdentificationCompleted: + protos, err := dht.peerstore.SupportsProtocols(ev.Peer, dht.protocolStrs()...) + if err == nil && len(protos) != 0 { + dht.Update(dht.ctx, ev.Peer) + } + } + case <-proc.Closing(): + return + } + } +} + +func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { + dht := nn.DHT() + select { + case <-dht.Process().Closing(): + return + default: + } + + p := v.RemotePeer() + + // Lock and check to see if we're still connected. We lock to make sure + // we don't concurrently process a connect event. + dht.plk.Lock() + defer dht.plk.Unlock() + if dht.host.Network().Connectedness(p) == network.Connected { + // We're still connected. + return + } + + dht.routingTable.Remove(p) + + dht.smlk.Lock() + defer dht.smlk.Unlock() + ms, ok := dht.strmap[p] + if !ok { + return + } + delete(dht.strmap, p) + + // Do this asynchronously as ms.lk can block for a while. + go func() { + ms.lk.Lock() + defer ms.lk.Unlock() + ms.invalidate() + }() +} + +func (nn *subscriberNotifee) Connected(n network.Network, v network.Conn) {} +func (nn *subscriberNotifee) OpenedStream(n network.Network, v network.Stream) {} +func (nn *subscriberNotifee) ClosedStream(n network.Network, v network.Stream) {} +func (nn *subscriberNotifee) Listen(n network.Network, a ma.Multiaddr) {} +func (nn *subscriberNotifee) ListenClose(n network.Network, a ma.Multiaddr) {}