diff --git a/libp2p/connection.nim b/libp2p/connection.nim index 7f8113ee5d..4ad20e0ae5 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -35,21 +35,28 @@ proc newInvalidVarintException*(): ref InvalidVarintException = proc newInvalidVarintSizeException*(): ref InvalidVarintSizeException = newException(InvalidVarintSizeException, "Wrong varint size") -proc init*[T: Connection](self: var T, stream: LPStream) = +proc bindStreamClose(conn: Connection) {.async.} = + # bind stream's close event to connection's close + # to ensure correct close propagation + if not isNil(conn.stream.closeEvent): + await conn.stream.closeEvent.wait() + trace "wrapped stream closed, about to close conn", closed = this.isClosed, + peer = if not isNil(this.peerInfo): + this.peerInfo.id else: "" + if not conn.isClosed: + trace "wrapped stream closed, closing conn", closed = this.isClosed, + peer = if not isNil(this.peerInfo): + this.peerInfo.id else: "" + asyncCheck conn.close() + +proc init*[T: Connection](self: var T, stream: LPStream): T = ## create a new Connection for the specified async reader/writer new self self.stream = stream self.closeEvent = newAsyncEvent() + asyncCheck self.bindStreamClose() - # bind stream's close event to connection's close - # to ensure correct close propagation - let this = self - if not isNil(self.stream.closeEvent): - self.stream.closeEvent.wait(). - addCallback do (udata: pointer): - if not this.closed: - trace "wrapped stream closed, closing conn" - asyncCheck this.close() + return self proc newConnection*(stream: LPStream): Connection = ## create a new Connection for the specified async reader/writer @@ -108,13 +115,23 @@ method closed*(s: Connection): bool = result = s.stream.closed method close*(s: Connection) {.async, gcsafe.} = - trace "closing connection" + trace "about to close connection", closed = s.closed, + peer = if not isNil(s.peerInfo): + s.peerInfo.id else: "" + if not s.closed: if not isNil(s.stream) and not s.stream.closed: + trace "closing child stream", closed = s.closed, + peer = if not isNil(s.peerInfo): + s.peerInfo.id else: "" await s.stream.close() + s.closeEvent.fire() s.isClosed = true - trace "connection closed", closed = s.closed + + trace "connection closed", closed = s.closed, + peer = if not isNil(s.peerInfo): + s.peerInfo.id else: "" proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} = ## read lenght prefixed msg diff --git a/libp2p/crypto/chacha20poly1305.nim b/libp2p/crypto/chacha20poly1305.nim index 8aa4d80a2e..df2d44b0c2 100644 --- a/libp2p/crypto/chacha20poly1305.nim +++ b/libp2p/crypto/chacha20poly1305.nim @@ -28,7 +28,7 @@ const ChaChaPolyKeySize = 32 ChaChaPolyNonceSize = 12 ChaChaPolyTagSize = 16 - + type ChaChaPoly* = object ChaChaPolyKey* = array[ChaChaPolyKeySize, byte] @@ -46,7 +46,7 @@ proc intoChaChaPolyNonce*(s: openarray[byte]): ChaChaPolyNonce = proc intoChaChaPolyTag*(s: openarray[byte]): ChaChaPolyTag = assert s.len == ChaChaPolyTagSize copyMem(addr result[0], unsafeaddr s[0], ChaChaPolyTagSize) - + # bearssl allows us to use optimized versions # this is reconciled at runtime # we do this in the global scope / module init @@ -85,7 +85,7 @@ proc decrypt*(_: type[ChaChaPoly], unsafeaddr aad[0] else: nil - + ourPoly1305CtmulRun( unsafeaddr key[0], unsafeaddr nonce[0], diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 54ba181d76..6b47458276 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -126,8 +126,12 @@ method handle*(m: Mplex) {.async, gcsafe.} = trace "Exception occurred", exception = exc.msg finally: trace "stopping mplex main loop" - if not m.connection.closed(): - await m.connection.close() + await m.close() + +proc internalCleanup(m: Mplex, conn: Connection) {.async.} = + await conn.closeEvent.wait() + trace "connection closed, cleaning up mplex" + await m.close() proc newMplex*(conn: Connection, maxChanns: uint = MaxChannels): Mplex = @@ -137,11 +141,7 @@ proc newMplex*(conn: Connection, result.remote = initTable[uint64, LPChannel]() result.local = initTable[uint64, LPChannel]() - let m = result - conn.closeEvent.wait() - .addCallback do (udata: pointer): - trace "connection closed, cleaning up mplex" - asyncCheck m.close() + asyncCheck result.internalCleanup(conn) method newStream*(m: Mplex, name: string = "", @@ -154,5 +154,10 @@ method newStream*(m: Mplex, method close*(m: Mplex) {.async, gcsafe.} = trace "closing mplex muxer" + if not m.connection.closed(): + await m.connection.close() + await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())), allFutures(toSeq(m.local.values).mapIt(it.reset()))]) + m.remote.clear() + m.local.clear() diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 9285ae93bf..636322497f 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -138,6 +138,14 @@ method handleConn*(p: PubSub, trace "pubsub peer handler ended, cleaning up" await p.cleanUpHelper(peer) +proc internalClenaup(p: PubSub, conn: Connection) {.async.} = + # handle connection close + var peer = p.getPeer(conn.peerInfo, p.codec) + await conn.closeEvent.wait() + trace "connection closed, cleaning up peer", peer = conn.peerInfo.id + + await p.cleanUpHelper(peer) + method subscribeToPeer*(p: PubSub, conn: Connection) {.base, async.} = var peer = p.getPeer(conn.peerInfo, p.codec) @@ -145,13 +153,7 @@ method subscribeToPeer*(p: PubSub, if not peer.isConnected: peer.conn = conn - # handle connection close - conn.closeEvent.wait() - .addCallback do (udata: pointer = nil): - trace "connection closed, cleaning up peer", - peer = conn.peerInfo.id - - asyncCheck p.cleanUpHelper(peer) + asyncCheck p.internalClenaup(conn) method unsubscribe*(p: PubSub, topics: seq[TopicPair]) {.base, async.} = diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 83017bb417..7d98d42a15 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -16,6 +16,7 @@ import ../../peer import ../../peerinfo import ../../protobuf/minprotobuf import ../../utility +import ../../stream/lpstream import secure, ../../crypto/[crypto, chacha20poly1305, curve25519, hkdf], ../../stream/bufferstream @@ -26,7 +27,7 @@ logScope: const # https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants NoiseCodec* = "/noise" - + PayloadString = "noise-libp2p-static-key:" ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256" @@ -41,7 +42,7 @@ type KeyPair = object privateKey: Curve25519Key publicKey: Curve25519Key - + # https://noiseprotocol.org/noise.html#the-cipherstate-object CipherState = object k: ChaChaPolyKey @@ -66,7 +67,7 @@ type cs2: CipherState remoteP2psecret: seq[byte] rs: Curve25519Key - + Noise* = ref object of Secure localPrivateKey: PrivateKey localPublicKey: PublicKey @@ -89,7 +90,7 @@ type proc genKeyPair(): KeyPair = result.privateKey = Curve25519Key.random() result.publicKey = result.privateKey.public() - + proc hashProtocol(name: string): MDigest[256] = # If protocol_name is less than or equal to HASHLEN bytes in length, # sets h equal to protocol_name with zero bytes appended to make HASHLEN bytes. @@ -195,7 +196,7 @@ proc split(ss: var SymmetricState): tuple[cs1, cs2: CipherState] = proc init(_: type[HandshakeState]): HandshakeState = result.ss = SymmetricState.init() - + template write_e: untyped = trace "noise write e" # Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key). @@ -302,7 +303,7 @@ proc packNoisePayload(payload: openarray[byte]): seq[byte] = if result.len > uint16.high.int: raise newException(NoiseOversizedPayloadError, "Trying to send an unsupported oversized payload over Noise") - + trace "packed noise payload", inSize = payload.len, outSize = result.len proc unpackNoisePayload(payload: var seq[byte]) = @@ -312,7 +313,7 @@ proc unpackNoisePayload(payload: var seq[byte]) = if size > (payload.len - 2): raise newException(NoiseOversizedPayloadError, "Received a wrong payload size") - + payload = payload[2..^((payload.len - size) - 1)] trace "unpacked noise payload", size = payload.len @@ -362,7 +363,7 @@ proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Fut msg &= hs.ss.encryptAndHash(packed) await conn.sendHSMessage(msg) - + let (cs1, cs2) = hs.ss.split() return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) @@ -426,9 +427,9 @@ method readMessage(sconn: NoiseConnection): Future[seq[byte]] {.async.} = var plain = sconn.readCs.decryptWithAd([], cipher) unpackNoisePayload(plain) return plain - except AsyncStreamIncompleteError: + except LPStreamIncompleteError: trace "Connection dropped while reading" - except AsyncStreamReadError: + except LPStreamReadError: trace "Error reading from connection" method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} = @@ -460,7 +461,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S # https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages let signedPayload = p.localPrivateKey.sign(PayloadString.toBytes & p.noisePublicKey.getBytes) - + var libp2pProof = initProtoBuffer() @@ -489,7 +490,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") else: trace "Remote signature verified" - + if initiator and not isNil(conn.peerInfo): let pid = PeerID.init(remotePubKey) if not conn.peerInfo.peerId.validate(): @@ -508,10 +509,10 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S secure.readCs = handshakeRes.cs1 secure.writeCs = handshakeRes.cs2 - debug "Noise handshake completed!" + trace "Noise handshake completed!" return secure - + method init*(p: Noise) {.gcsafe.} = procCall Secure(p).init() p.codec = NoiseCodec @@ -523,7 +524,7 @@ method secure*(p: Noise, conn: Connection): Future[Connection] {.async, gcsafe.} warn "securing connection failed", msg = exc.msg if not conn.closed(): await conn.close() - + proc newNoise*(privateKey: PrivateKey; outgoing: bool = true; commonPrologue: seq[byte] = @[]): Noise = new result result.outgoing = outgoing diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index 0452feff43..95a99ffb30 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -281,11 +281,11 @@ proc transactMessage(conn: Connection, else: trace "Received size of message exceed limits", conn = $conn, length = length - except AsyncStreamIncompleteError: + except LPStreamIncompleteError: trace "Connection dropped while reading", conn = $conn - except AsyncStreamReadError: + except LPStreamReadError: trace "Error reading from connection", conn = $conn - except AsyncStreamWriteError: + except LPStreamWriteError: trace "Could not write to connection", conn = $conn method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} = diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 5c6815849a..7860b39d3e 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -19,6 +19,7 @@ import ../protocol, type Secure* = ref object of LPProtocol # base type for secure managers + SecureConn* = ref object of Connection method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = @@ -32,8 +33,9 @@ method handshake(s: Secure, initiator: bool = false): Future[SecureConn] {.async, base.} = doAssert(false, "Not implemented!") -proc readLoop(sconn: SecureConn, stream: BufferStream) {.async.} = +proc readLoop(sconn: SecureConn, conn: Connection) {.async.} = try: + let stream = BufferStream(conn.stream) while not sconn.closed: let msg = await sconn.readMessage() if msg.len == 0: @@ -44,9 +46,14 @@ proc readLoop(sconn: SecureConn, stream: BufferStream) {.async.} = except CatchableError as exc: trace "Exception occurred Secure.readLoop", exc = exc.msg finally: + trace "closing conn", closed = conn.closed() + if not conn.closed: + await conn.close() + + trace "closing sconn", closed = sconn.closed() if not sconn.closed: await sconn.close() - trace "ending Secure readLoop", isclosed = sconn.closed() + trace "ending Secure readLoop" proc handleConn*(s: Secure, conn: Connection, initiator: bool = false): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn, initiator) @@ -54,14 +61,8 @@ proc handleConn*(s: Secure, conn: Connection, initiator: bool = false): Future[C trace "sending encrypted bytes", bytes = data.shortLog await sconn.writeMessage(data) - var stream = newBufferStream(writeHandler) - asyncCheck readLoop(sconn, stream) - result = newConnection(stream) - result.closeEvent.wait() - .addCallback do (udata: pointer): - trace "wrapped connection closed, closing upstream" - if not isNil(sconn) and not sconn.closed: - asyncCheck sconn.close() + result = newConnection(newBufferStream(writeHandler)) + asyncCheck readLoop(sconn, result) if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index f4a7bbdce2..4c33f4f84c 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -198,6 +198,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = # handle subsequent requests await ms.handle(sconn) + await sconn.close() if (await ms.select(conn)): # just handshake # add the secure handlers @@ -289,9 +290,7 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = except CatchableError as exc: trace "Exception occurred in Switch.start", exc = exc.msg finally: - if not isNil(conn) and not conn.closed: - await conn.close() - + await conn.close() await s.cleanupConn(conn) var startFuts: seq[Future[void]] diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 70d1d76c81..cbd2985e0a 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -21,6 +21,10 @@ logScope: type TcpTransport* = ref object of Transport server*: StreamServer +proc cleanup(t: Transport, conn: Connection) {.async.} = + await conn.closeEvent.wait() + t.connections.keepItIf(it != conn) + proc connHandler*(t: Transport, server: StreamServer, client: StreamTransport, @@ -30,10 +34,12 @@ proc connHandler*(t: Transport, let conn: Connection = newConnection(newChronosStream(server, client)) conn.observedAddrs = MultiAddress.init(client.remoteAddress) if not initiator: - let handlerFut = if isNil(t.handler): nil else: t.handler(conn) - let connHolder: ConnHolder = ConnHolder(connection: conn, - connFuture: handlerFut) - t.connections.add(connHolder) + if not isNil(t.handler): + asyncCheck t.handler(conn) + + t.connections.add(conn) + asyncCheck t.cleanup(conn) + result = conn proc connCb(server: StreamServer, @@ -51,10 +57,11 @@ method close*(t: TcpTransport): Future[void] {.async, gcsafe.} = await procCall Transport(t).close() # call base # server can be nil - if t.server != nil: + if not isNil(t.server): t.server.stop() t.server.close() - trace "transport stopped" + await t.server.join() + trace "transport stopped" method listen*(t: TcpTransport, ma: MultiAddress, diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 8d5a2c4ccb..564afd7026 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -16,13 +16,9 @@ import ../connection, type ConnHandler* = proc (conn: Connection): Future[void] {.gcsafe.} - ConnHolder* = object - connection*: Connection - connFuture*: Future[void] - Transport* = ref object of RootObj ma*: Multiaddress - connections*: seq[ConnHolder] + connections*: seq[Connection] handler*: ConnHandler multicodec*: MultiCodec @@ -37,7 +33,7 @@ proc newTransport*(t: typedesc[Transport]): t {.gcsafe.} = method close*(t: Transport) {.base, async, gcsafe.} = ## stop and cleanup the transport ## including all outstanding connections - await allFutures(t.connections.mapIt(it.connection.close())) + await allFutures(t.connections.mapIt(it.close())) method listen*(t: Transport, ma: MultiAddress, diff --git a/tests/testconnection.nim b/tests/testconnection.nim new file mode 100644 index 0000000000..0dfb3f02db --- /dev/null +++ b/tests/testconnection.nim @@ -0,0 +1,50 @@ +import unittest +import chronos, nimcrypto/utils +import ../libp2p/[connection, + stream/lpstream, + stream/bufferstream] + +suite "Connection": + test "close": + proc test(): Future[bool] {.async.} = + var conn = newConnection(newBufferStream()) + await conn.close() + check: + conn.closed == true + + result = true + + check: + waitFor(test()) == true + + test "parent close": + proc test(): Future[bool] {.async.} = + var buf = newBufferStream() + var conn = newConnection(buf) + + await conn.close() + check: + conn.closed == true + buf.closed == true + + await sleepAsync(1.seconds) + result = true + + check: + waitFor(test()) == true + + test "child close": + proc test(): Future[bool] {.async.} = + var buf = newBufferStream() + var conn = newConnection(buf) + + await buf.close() + check: + conn.closed == true + buf.closed == true + + await sleepAsync(1.seconds) + result = true + + check: + waitFor(test()) == true