diff --git a/server.go b/server.go index 0f11821..09739e1 100644 --- a/server.go +++ b/server.go @@ -39,6 +39,7 @@ type Server struct { mu sync.Mutex listeners map[net.Listener]struct{} conns map[*gossh.ServerConn]struct{} + connWg sync.WaitGroup doneChan chan struct{} } @@ -122,16 +123,17 @@ func (srv *Server) Shutdown(ctx context.Context) error { srv.closeDoneChanLocked() srv.mu.Unlock() - listenerWgChan := make(chan struct{}, 1) + finished := make(chan struct{}, 1) go func() { srv.listenerWg.Wait() - listenerWgChan <- struct{}{} + srv.connWg.Wait() + finished <- struct{}{} }() select { case <-ctx.Done(): return ctx.Err() - case <-listenerWgChan: + case <-finished: return lnerr } } @@ -319,7 +321,9 @@ func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { } if add { srv.conns[c] = struct{}{} + srv.connWg.Add(1) } else { delete(srv.conns, c) + srv.connWg.Done() } }