diff --git a/routing.go b/routing.go index 70e7087..f5ebad0 100644 --- a/routing.go +++ b/routing.go @@ -2,6 +2,7 @@ package discovery import ( "context" + "github.com/libp2p/go-libp2p-core/discovery" "time" cid "github.com/ipfs/go-cid" @@ -83,3 +84,29 @@ func nsToCid(ns string) (cid.Cid, error) { return cid.NewCidV1(cid.Raw, h), nil } + +func NewDiscoveryRouting(disc discovery.Discovery) *DiscoveryRouting { + return &DiscoveryRouting{disc} +} + +type DiscoveryRouting struct { + discovery.Discovery +} + +func (r *DiscoveryRouting) Provide(ctx context.Context, c cid.Cid, bcast bool) error { + if !bcast { + return nil + } + + _, err := r.Advertise(ctx, cidToNs(c)) + return err +} + +func (r *DiscoveryRouting) FindProvidersAsync(ctx context.Context, c cid.Cid, limit int) <-chan peer.AddrInfo { + ch, _ := r.FindPeers(ctx, cidToNs(c), discovery.Limit(limit)) + return ch +} + +func cidToNs(c cid.Cid) string { + return "/provider/" + c.String() +} diff --git a/routing_test.go b/routing_test.go index 3fa39bd..1534a5b 100644 --- a/routing_test.go +++ b/routing_test.go @@ -4,9 +4,11 @@ import ( "context" "sync" "testing" + "time" - cid "github.com/ipfs/go-cid" + "github.com/ipfs/go-cid" bhost "github.com/libp2p/go-libp2p-blankhost" + "github.com/libp2p/go-libp2p-core/discovery" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" swarmt "github.com/libp2p/go-libp2p-swarm/testing" @@ -69,6 +71,107 @@ func (m *mockRouting) FindProvidersAsync(ctx context.Context, cid cid.Cid, limit return ch } +type mockDiscoveryServer struct { + mx sync.Mutex + db map[string]map[peer.ID]*discoveryRegistration +} + +type discoveryRegistration struct { + info peer.AddrInfo + expiration time.Time +} + +func newDiscoveryServer() *mockDiscoveryServer { + return &mockDiscoveryServer{ + db: make(map[string]map[peer.ID]*discoveryRegistration), + } +} + +func (s *mockDiscoveryServer) Advertise(ns string, info peer.AddrInfo, ttl time.Duration) (time.Duration, error) { + s.mx.Lock() + defer s.mx.Unlock() + + peers, ok := s.db[ns] + if !ok { + peers = make(map[peer.ID]*discoveryRegistration) + s.db[ns] = peers + } + peers[info.ID] = &discoveryRegistration{info, time.Now().Add(ttl)} + return ttl, nil +} + +func (s *mockDiscoveryServer) FindPeers(ns string, limit int) (<-chan peer.AddrInfo, error) { + s.mx.Lock() + defer s.mx.Unlock() + + peers, ok := s.db[ns] + if !ok || len(peers) == 0 { + emptyCh := make(chan peer.AddrInfo) + close(emptyCh) + return emptyCh, nil + } + + count := len(peers) + if limit != 0 && count > limit { + count = limit + } + + iterTime := time.Now() + ch := make(chan peer.AddrInfo, count) + numSent := 0 + for p, reg := range peers { + if numSent == count { + break + } + if iterTime.After(reg.expiration) { + delete(peers, p) + continue + } + + numSent++ + ch <- reg.info + } + close(ch) + + return ch, nil +} + +func (s *mockDiscoveryServer) hasPeerRecord(ns string, pid peer.ID) bool { + s.mx.Lock() + defer s.mx.Unlock() + + if peers, ok := s.db[ns]; ok { + _, ok := peers[pid] + return ok + } + return false +} + +type mockDiscoveryClient struct { + host host.Host + server *mockDiscoveryServer +} + +func (d *mockDiscoveryClient) Advertise(ctx context.Context, ns string, opts ...discovery.Option) (time.Duration, error) { + var options discovery.Options + err := options.Apply(opts...) + if err != nil { + return 0, err + } + + return d.server.Advertise(ns, *host.InfoFromHost(d.host), options.Ttl) +} + +func (d *mockDiscoveryClient) FindPeers(ctx context.Context, ns string, opts ...discovery.Option) (<-chan peer.AddrInfo, error) { + var options discovery.Options + err := options.Apply(opts...) + if err != nil { + return nil, err + } + + return d.server.FindPeers(ns, options.Limit) +} + func TestRoutingDiscovery(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -102,3 +205,43 @@ func TestRoutingDiscovery(t *testing.T) { t.Fatalf("Unexpected peer: %s", pi.ID) } } + +func TestDiscoveryRouting(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + h2 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + + dserver := newDiscoveryServer() + d1 := &mockDiscoveryClient{h1, dserver} + d2 := &mockDiscoveryClient{h2, dserver} + + r1 := NewDiscoveryRouting(d1) + r2 := NewDiscoveryRouting(d2) + + c, err := nsToCid("/test") + if err != nil { + t.Fatal(err) + } + + if err := r1.Provide(ctx, c, true); err != nil { + t.Fatal(err) + } + + pch := r2.FindProvidersAsync(ctx, c, 20) + + var allAIs []peer.AddrInfo + for ai := range pch { + allAIs = append(allAIs, ai) + } + + if len(allAIs) != 1 { + t.Fatalf("Expected 1 peer, got %d", len(allAIs)) + } + + ai := allAIs[0] + if ai.ID != h1.ID() { + t.Fatalf("Unexpected peer: %s", ai.ID) + } +}