Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
feat: close transports that implement io.Closer
Browse files Browse the repository at this point in the history
This way, transports with shared resources (e.g., reused sockets) can clean them
up.

fixes libp2p/go-libp2p#999
  • Loading branch information
Stebalien committed Aug 28, 2020
1 parent 58c167a commit 945d870
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
22 changes: 22 additions & 0 deletions swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -176,6 +177,27 @@ func (s *Swarm) teardown() error {
// Wait for everything to finish.
s.refs.Wait()

// Now close out any transports (if necessary). Do this after closing
// all connections/listeners.
s.transports.Lock()
transports := s.transports.m
s.transports.m = nil
s.transports.Unlock()

var wg sync.WaitGroup
for _, t := range transports {
if closer, ok := t.(io.Closer); ok {
wg.Add(1)
go func(c io.Closer) {
defer wg.Done()
if err := closer.Close(); err != nil {
log.Errorf("error when closing down transport %T: %s", c, err)
}
}(closer)
}
}
wg.Wait()

return nil
}

Expand Down
7 changes: 6 additions & 1 deletion swarm_listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error {
func (s *Swarm) AddListenAddr(a ma.Multiaddr) error {
tpt := s.TransportForListening(a)
if tpt == nil {
return ErrNoTransport
select {
case <-s.proc.Closing():
return ErrSwarmClosed
default:
return ErrNoTransport
}
}

list, err := tpt.Listen(a)
Expand Down
13 changes: 11 additions & 2 deletions swarm_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport {
s.transports.RLock()
defer s.transports.RUnlock()
if len(s.transports.m) == 0 {
log.Error("you have no transports configured")
// make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil
}

Expand Down Expand Up @@ -48,7 +51,10 @@ func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport {
s.transports.RLock()
defer s.transports.RUnlock()
if len(s.transports.m) == 0 {
log.Error("you have no transports configured")
// make sure we're not just shutting down.
if s.transports.m != nil {
log.Error("you have no transports configured")
}
return nil
}

Expand Down Expand Up @@ -77,6 +83,9 @@ func (s *Swarm) AddTransport(t transport.Transport) error {

s.transports.Lock()
defer s.transports.Unlock()
if s.transports.m == nil {
return ErrSwarmClosed
}
var registered []string
for _, p := range protocols {
if _, ok := s.transports.m[p]; ok {
Expand Down
37 changes: 35 additions & 2 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"testing"

swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"

"github.com/libp2p/go-libp2p-core/peer"
Expand All @@ -14,6 +15,7 @@ import (
type dummyTransport struct {
protocols []int
proxy bool
closed bool
}

func (dt *dummyTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
Expand All @@ -35,13 +37,44 @@ func (dt *dummyTransport) Proxy() bool {
func (dt *dummyTransport) Protocols() []int {
return dt.protocols
}
func (dt *dummyTransport) Close() error {
dt.closed = true
return nil
}

func TestUselessTransport(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
swarm := swarmt.GenSwarm(t, ctx)
err := swarm.AddTransport(new(dummyTransport))
s := swarmt.GenSwarm(t, ctx)
err := s.AddTransport(new(dummyTransport))
if err == nil {
t.Fatal("adding a transport that supports no protocols should have failed")
}
}

func TestTransportClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := swarmt.GenSwarm(t, ctx)
tpt := &dummyTransport{protocols: []int{1}}
if err := s.AddTransport(tpt); err != nil {
t.Fatal(err)
}
_ = s.Close()
if !tpt.closed {
t.Fatal("expected transport to be closed")
}

}

func TestTransportAfterClose(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := swarmt.GenSwarm(t, ctx)
s.Close()

tpt := &dummyTransport{protocols: []int{1}}
if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed {
t.Fatal("expected swarm closed error, got: ", err)
}
}

0 comments on commit 945d870

Please sign in to comment.