From fadc662bfb65f65679e410f1f8b4de0dee6d84e9 Mon Sep 17 00:00:00 2001 From: driftluo Date: Fri, 24 Sep 2021 19:39:34 +0800 Subject: [PATCH 1/2] fix: fix tls and ws listen poll --- bench/Cargo.toml | 2 +- tentacle/Cargo.toml | 2 +- tentacle/src/transports/mod.rs | 35 +- tentacle/src/transports/tls.rs | 85 ++-- tentacle/src/transports/ws.rs | 84 ++-- .../tests/certificates/node2-wrong/ca.crt | 16 + .../tests/certificates/node2-wrong/server.crt | 17 + .../tests/certificates/node2-wrong/server.key | 5 + tentacle/tests/test_tls_reconnect.rs | 393 ++++++++++++++++++ 9 files changed, 551 insertions(+), 88 deletions(-) create mode 100644 tentacle/tests/certificates/node2-wrong/ca.crt create mode 100644 tentacle/tests/certificates/node2-wrong/server.crt create mode 100644 tentacle/tests/certificates/node2-wrong/server.key create mode 100644 tentacle/tests/test_tls_reconnect.rs diff --git a/bench/Cargo.toml b/bench/Cargo.toml index c2f9ac9d..c143075b 100644 --- a/bench/Cargo.toml +++ b/bench/Cargo.toml @@ -17,6 +17,6 @@ rand = "0.7.1" futures = { version = "0.3.0" } tokio = { version = "1.0.0", features = ["time", "io-util", "net", "rt-multi-thread"] } tokio-util = { version = "0.6.0", features = ["codec"] } -crossbeam-channel = "0.3.6" +crossbeam-channel = "0.5" env_logger = "0.6.0" bytes = "1.0.0" diff --git a/tentacle/Cargo.toml b/tentacle/Cargo.toml index c8f9d1f6..c666b857 100644 --- a/tentacle/Cargo.toml +++ b/tentacle/Cargo.toml @@ -64,7 +64,7 @@ socket2 = { version = "0.4.0", optional = true } [dev-dependencies] env_logger = "0.6.0" -crossbeam-channel = "0.3.6" +crossbeam-channel = "0.5" systemstat = "0.1.3" futures-test = "0.3.5" diff --git a/tentacle/src/transports/mod.rs b/tentacle/src/transports/mod.rs index c873b4d9..7c1eba0c 100644 --- a/tentacle/src/transports/mod.rs +++ b/tentacle/src/transports/mod.rs @@ -407,22 +407,27 @@ mod os { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match self.get_mut() { - MultiIncoming::Tcp(inner) => match inner.poll_accept(cx)? { - // Why can't get the peer address of the connected stream ? - // Error will be "Transport endpoint is not connected", - // so why incoming will appear unconnected stream ? - Poll::Ready((stream, _)) => match stream.peer_addr() { - Ok(remote_address) => Poll::Ready(Some(Ok(( - socketaddr_to_multiaddr(remote_address), - MultiStream::Tcp(stream), - )))), - Err(err) => { - debug!("stream get peer address error: {:?}", err); - Poll::Pending + MultiIncoming::Tcp(inner) => { + loop { + match inner.poll_accept(cx)? { + // Why can't get the peer address of the connected stream ? + // Error will be "Transport endpoint is not connected", + // so why incoming will appear unconnected stream ? + Poll::Ready((stream, _)) => match stream.peer_addr() { + Ok(remote_address) => { + break Poll::Ready(Some(Ok(( + socketaddr_to_multiaddr(remote_address), + MultiStream::Tcp(stream), + )))) + } + Err(err) => { + debug!("stream get peer address error: {:?}", err); + } + }, + Poll::Pending => break Poll::Pending, } - }, - Poll::Pending => Poll::Pending, - }, + } + } MultiIncoming::Memory(inner) => match inner.poll_next_unpin(cx)? { Poll::Ready(Some((addr, stream))) => { Poll::Ready(Some(Ok((addr, MultiStream::Memory(stream))))) diff --git a/tentacle/src/transports/tls.rs b/tentacle/src/transports/tls.rs index cd4a62e0..55639248 100644 --- a/tentacle/src/transports/tls.rs +++ b/tentacle/src/transports/tls.rs @@ -106,15 +106,48 @@ impl TlsListener { } } - fn poll_pending( - &mut self, - cx: &mut Context, - ) -> Poll>> { + fn poll_pending(&mut self, cx: &mut Context) -> Poll<(Multiaddr, TlsStream)> { match Pin::new(&mut self.pending_stream).as_mut().poll_next(cx) { - Poll::Ready(Some(res)) => Poll::Ready(Some(Ok(res))), + Poll::Ready(Some(res)) => Poll::Ready(res), Poll::Ready(None) | Poll::Pending => Poll::Pending, } } + + fn poll_listen(&mut self, cx: &mut Context) -> Poll> { + match self.inner.poll_accept(cx)? { + Poll::Ready((stream, _)) => { + match stream.peer_addr() { + Ok(remote_address) => { + let timeout = self.timeout; + let mut sender = self.sender.clone(); + let acceptor = TlsAcceptor::from(Arc::clone(&self.tls_config)); + crate::runtime::spawn(async move { + match crate::runtime::timeout(timeout, acceptor.accept(stream)).await { + Err(_) => warn!("accept tls server stream timeout"), + Ok(res) => match res { + Ok(stream) => { + let mut addr = socketaddr_to_multiaddr(remote_address); + addr.push(Protocol::Tls(Cow::Borrowed(""))); + if sender.send((addr, Box::new(stream))).await.is_err() { + warn!("receiver closed unexpectedly") + } + } + Err(err) => { + warn!("accept tls server stream err: {:?}", err); + } + }, + } + }); + } + Err(err) => { + warn!("stream get peer address error: {:?}", err); + } + } + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } } impl Stream for TlsListener { @@ -122,41 +155,21 @@ impl Stream for TlsListener { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if let Poll::Ready(res) = self.poll_pending(cx) { - return Poll::Ready(res); + return Poll::Ready(Some(Ok(res))); } - match self.inner.poll_accept(cx)? { - Poll::Ready((stream, _)) => match stream.peer_addr() { - Ok(remote_address) => { - let timeout = self.timeout; - let mut sender = self.sender.clone(); - let acceptor = TlsAcceptor::from(Arc::clone(&self.tls_config)); - crate::runtime::spawn(async move { - match crate::runtime::timeout(timeout, acceptor.accept(stream)).await { - Err(_) => warn!("accept tls server stream timeout"), - Ok(res) => match res { - Ok(stream) => { - let mut addr = socketaddr_to_multiaddr(remote_address); - addr.push(Protocol::Tls(Cow::Borrowed(""))); - if sender.send((addr, Box::new(stream))).await.is_err() { - warn!("receiver closed unexpectedly") - } - } - Err(err) => { - warn!("accept tls server stream err: {:?}", err); - } - }, - } - }); - self.poll_pending(cx) + loop { + let is_pending = self.poll_listen(cx)?.is_pending(); + match self.poll_pending(cx) { + Poll::Ready(res) => return Poll::Ready(Some(Ok(res))), + Poll::Pending => { + if is_pending { + break; + } } - Err(err) => { - warn!("stream get peer address error: {:?}", err); - Poll::Pending - } - }, - Poll::Pending => Poll::Pending, + } } + Poll::Pending } } diff --git a/tentacle/src/transports/ws.rs b/tentacle/src/transports/ws.rs index 70f66d2d..b3cddf5c 100644 --- a/tentacle/src/transports/ws.rs +++ b/tentacle/src/transports/ws.rs @@ -158,15 +158,48 @@ impl WebsocketListener { } } - fn poll_pending( - &mut self, - cx: &mut Context, - ) -> Poll>> { + fn poll_pending(&mut self, cx: &mut Context) -> Poll<(Multiaddr, WsStream)> { match Pin::new(&mut self.pending_stream).as_mut().poll_next(cx) { - Poll::Ready(Some(res)) => Poll::Ready(Some(Ok(res))), + Poll::Ready(Some(res)) => Poll::Ready(res), Poll::Ready(None) | Poll::Pending => Poll::Pending, } } + + fn poll_listen(&mut self, cx: &mut Context) -> Poll> { + match self.inner.poll_accept(cx)? { + Poll::Ready((stream, _)) => { + match stream.peer_addr() { + Ok(remote_address) => { + let timeout = self.timeout; + let mut sender = self.sender.clone(); + crate::runtime::spawn(async move { + match crate::runtime::timeout(timeout, accept_async(stream)).await { + Err(_) => debug!("accept websocket stream timeout"), + Ok(res) => match res { + Ok(stream) => { + let mut addr = socketaddr_to_multiaddr(remote_address); + addr.push(Protocol::Ws); + if sender.send((addr, WsStream::new(stream))).await.is_err() + { + debug!("receiver closed unexpectedly") + } + } + Err(err) => { + debug!("accept websocket stream err: {:?}", err); + } + }, + } + }); + } + Err(err) => { + debug!("stream get peer address error: {:?}", err); + } + } + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } } impl Stream for WebsocketListener { @@ -174,40 +207,21 @@ impl Stream for WebsocketListener { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if let Poll::Ready(res) = self.poll_pending(cx) { - return Poll::Ready(res); + return Poll::Ready(Some(Ok(res))); } - match self.inner.poll_accept(cx)? { - Poll::Ready((stream, _)) => match stream.peer_addr() { - Ok(remote_address) => { - let timeout = self.timeout; - let mut sender = self.sender.clone(); - crate::runtime::spawn(async move { - match crate::runtime::timeout(timeout, accept_async(stream)).await { - Err(_) => debug!("accept websocket stream timeout"), - Ok(res) => match res { - Ok(stream) => { - let mut addr = socketaddr_to_multiaddr(remote_address); - addr.push(Protocol::Ws); - if sender.send((addr, WsStream::new(stream))).await.is_err() { - debug!("receiver closed unexpectedly") - } - } - Err(err) => { - debug!("accept websocket stream err: {:?}", err); - } - }, - } - }); - self.poll_pending(cx) - } - Err(err) => { - debug!("stream get peer address error: {:?}", err); - Poll::Pending + loop { + let is_pending = self.poll_listen(cx)?.is_pending(); + match self.poll_pending(cx) { + Poll::Ready(res) => return Poll::Ready(Some(Ok(res))), + Poll::Pending => { + if is_pending { + break; + } } - }, - Poll::Pending => Poll::Pending, + } } + Poll::Pending } } diff --git a/tentacle/tests/certificates/node2-wrong/ca.crt b/tentacle/tests/certificates/node2-wrong/ca.crt new file mode 100644 index 00000000..acdea33b --- /dev/null +++ b/tentacle/tests/certificates/node2-wrong/ca.crt @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICmzCCAkGgAwIBAgIBKjAKBggqhkjOPQQDAjCBhzELMAkGA1UEBgwCQ04xCzAJ +BgNVBAgMAlpKMQswCQYDVQQHDAJIWjENMAsGA1UECgwEQ0lUQTEaMBgGA1UECwwR +QmxvY2tjaGFpbkRldmVsb3AxMzAxBgNVBAMMKjB4MzdkMWM3NDQ5YmZlNzZmZTlj +NDQ1ZTYyNmRhMDYyNjVlOTM3NzYwMTAgFw03NTAxMDEwMDAwMDBaGA80MDk2MDEw +MTAwMDAwMFowgYcxCzAJBgNVBAYMAkNOMQswCQYDVQQIDAJaSjELMAkGA1UEBwwC +SFoxDTALBgNVBAoMBENJVEExGjAYBgNVBAsMEUJsb2NrY2hhaW5EZXZlbG9wMTMw +MQYDVQQDDCoweDM3ZDFjNzQ0OWJmZTc2ZmU5YzQ0NWU2MjZkYTA2MjY1ZTkzNzc2 +MDEwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARqckVxxyXetaUkDFmotKVr+ryn +46Ab1HQMce1wavSMTXTfKFNF17yIbb/p+m/bMikDIDRlFdrPzCTItAxp2rvgo4GZ +MIGWMDUGA1UdEQQuMCyCKjB4MzdkMWM3NDQ5YmZlNzZmZTljNDQ1ZTYyNmRhMDYy +NjVlOTM3NzYwMTAdBgNVHSUEFjAUBggrBgEFBQcDAgYIKwYBBQUHAwEwHQYDVR0O +BBYEFCDzfo3cfAjZX6IipouRFKxtECeeMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0P +AQH/BAQDAgWgMAoGCCqGSM49BAMCA0gAMEUCIEyBoZBxC0CSAYMDGs9bB8gv2kmq +yTrIjeqz3znDjrj3AiEA5rzZpZtgf0/Vj74V6l0Cv8B7Pa/yQ8z7kf8qNXL8d3o= +-----END CERTIFICATE----- diff --git a/tentacle/tests/certificates/node2-wrong/server.crt b/tentacle/tests/certificates/node2-wrong/server.crt new file mode 100644 index 00000000..c7949d05 --- /dev/null +++ b/tentacle/tests/certificates/node2-wrong/server.crt @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICnTCCAkOgAwIBAgIBKjAKBggqhkjOPQQDAjCBhzELMAkGA1UEBgwCQ04xCzAJ +BgNVBAgMAlpKMQswCQYDVQQHDAJIWjENMAsGA1UECgwEQ0lUQTEaMBgGA1UECwwR +QmxvY2tjaGFpbkRldmVsb3AxMzAxBgNVBAMMKjB4MzdkMWM3NDQ5YmZlNzZmZTlj +NDQ1ZTYyNmRhMDYyNjVlOTM3NzYwMTAgFw03NTAxMDEwMDAwMDBaGA80MDk2MDEw +MTAwMDAwMFowgYcxCzAJBgNVBAYMAkNOMQswCQYDVQQIDAJaSjELMAkGA1UEBwwC +SFoxDTALBgNVBAoMBENJVEExGjAYBgNVBAsMEUJsb2NrY2hhaW5EZXZlbG9wMTMw +MQYDVQQDDCoweDMyOWY1Y2JmYTY3MWZiN2UzYmI2YmM0NTUxMGQ5NWY0YWM2YTZj +YjgwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASNZNyFgQF+nQPbM9Mhvq0CMNzj +BqWgG/vprroFWsllCfHM79VfYOx1Sn2FwKQWBoU5vNO4XFGnnhqbh6jXURPfo4Gb +MIGYMB8GA1UdIwQYMBaAFCDzfo3cfAjZX6IipouRFKxtECeeMDUGA1UdEQQuMCyC +KjB4MzI5ZjVjYmZhNjcxZmI3ZTNiYjZiYzQ1NTEwZDk1ZjRhYzZhNmNiODAdBgNV +HSUEFjAUBggrBgEFBQcDAgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUwAwEBADAOBgNV +HQ8BAf8EBAMCBaAwCgYIKoZIzj0EAwIDSAAwRQIgcsxlB9Hc8uhDt/pogD0kaZ+o +wTdCK+v9AKEOzNrmaQICIQCKf3bs5EwZbucSyr6/udPPXrALnKgo1+oAQQiwffY6 +0A== +-----END CERTIFICATE----- diff --git a/tentacle/tests/certificates/node2-wrong/server.key b/tentacle/tests/certificates/node2-wrong/server.key new file mode 100644 index 00000000..6399d70e --- /dev/null +++ b/tentacle/tests/certificates/node2-wrong/server.key @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgsD0e5mrPUvCcwczo +B4Ot+0nstgq+aJ5Pnhxn//oA9WuhRANCAASNZNyFgQF+nQPbM9Mhvq0CMNzjBqWg +G/vprroFWsllCfHM79VfYOx1Sn2FwKQWBoU5vNO4XFGnnhqbh6jXURPf +-----END PRIVATE KEY----- diff --git a/tentacle/tests/test_tls_reconnect.rs b/tentacle/tests/test_tls_reconnect.rs new file mode 100644 index 00000000..24bcd659 --- /dev/null +++ b/tentacle/tests/test_tls_reconnect.rs @@ -0,0 +1,393 @@ +#![cfg(feature = "tls")] +use crossbeam_channel::Receiver; +use std::io::BufReader; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use std::{fs, thread}; +use tentacle::bytes::Bytes; +use tentacle::service::ServiceControl; +use tentacle::{ + async_trait, + builder::{MetaBuilder, ServiceBuilder}, + context::{ProtocolContext, ProtocolContextMutRef}, + multiaddr::Multiaddr, + service::{ProtocolHandle, ProtocolMeta, Service, TargetProtocol, TlsConfig}, + traits::{ServiceHandle, ServiceProtocol}, + ProtocolId, +}; +use tokio_rustls::rustls::internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys}; +use tokio_rustls::rustls::{ + AllowAnyAuthenticatedClient, Certificate, ClientConfig, KeyLogFile, NoClientAuth, PrivateKey, + ProtocolVersion, RootCertStore, ServerConfig, SupportedCipherSuite, ALL_CIPHERSUITES, +}; + +pub fn create(meta: ProtocolMeta, shandle: F, cert_path: String) -> Service +where + F: ServiceHandle + Unpin, +{ + let mut builder = ServiceBuilder::default() + .insert_protocol(meta) + .forever(true); + + let tls_config = TlsConfig::new( + Some(make_server_config(&NetConfig::example(cert_path.clone()))), + Some(make_client_config(&NetConfig::example(cert_path))), + ); + builder = builder.tls_config(tls_config); + + builder.build(shandle) +} + +struct PHandle { + sender: crossbeam_channel::Sender, + send: bool, +} + +#[async_trait] +impl ServiceProtocol for PHandle { + async fn init(&mut self, _context: &mut ProtocolContext) {} + + async fn connected(&mut self, context: ProtocolContextMutRef<'_>, _version: &str) { + if !self.send { + context + .send_message(bytes::Bytes::from("hello world")) + .await + .unwrap(); + } + } + + async fn received(&mut self, _context: ProtocolContextMutRef<'_>, data: bytes::Bytes) { + if self.send { + self.sender.try_send(data).unwrap(); + } + } +} + +#[derive(Debug, Clone)] +pub struct NetConfig { + server_cert_chain: Option, + server_key: Option, + + ca_cert: Option, + + protocols: Option>, + cypher_suits: Option>, +} + +impl NetConfig { + fn example(node_dir: String) -> Self { + Self { + server_cert_chain: Some(node_dir.clone() + "server.crt"), + server_key: Some(node_dir.clone() + "server.key"), + ca_cert: Some(node_dir + "ca.crt"), + + protocols: None, + cypher_suits: None, + } + } +} + +fn create_meta( + id: ProtocolId, + send: bool, +) -> (ProtocolMeta, crossbeam_channel::Receiver) { + // NOTE: channel size must large, otherwise send will failed. + let (sender, receiver) = crossbeam_channel::unbounded(); + + let meta = MetaBuilder::new() + .id(id) + .service_handle(move || { + if id == 0.into() { + ProtocolHandle::None + } else { + let handle = Box::new(PHandle { sender, send }); + ProtocolHandle::Callback(handle) + } + }) + .build(); + + (meta, receiver) +} + +fn create_shandle() -> Box { + // NOTE: channel size must large, otherwise send will failed. + Box::new(()) +} + +fn find_suite(name: &str) -> Option<&'static SupportedCipherSuite> { + for suite in &ALL_CIPHERSUITES { + let cs_name = format!("{:?}", suite.suite).to_lowercase(); + + if cs_name == name.to_string().to_lowercase() { + return Some(suite); + } + } + + None +} + +fn lookup_suites(suites: &[String]) -> Vec<&'static SupportedCipherSuite> { + let mut out = Vec::new(); + + for cs_name in suites { + let scs = find_suite(cs_name); + match scs { + Some(s) => out.push(s), + None => panic!("cannot look up cipher suite '{}'", cs_name), + } + } + + out +} + +/// Make a vector of protocol versions named in `versions` +fn lookup_versions(versions: &[String]) -> Vec { + let mut out = Vec::new(); + + for vname in versions { + let version = match vname.as_ref() { + "1.2" => ProtocolVersion::TLSv1_2, + "1.3" => ProtocolVersion::TLSv1_3, + _ => panic!( + "cannot look up version '{}', valid are '1.2' and '1.3'", + vname + ), + }; + out.push(version); + } + + out +} + +fn load_certs(filename: &str) -> Vec { + let certfile = fs::File::open(filename).expect("cannot open certificate file"); + let mut reader = BufReader::new(certfile); + certs(&mut reader).unwrap() +} + +fn load_private_key(filename: &str) -> PrivateKey { + let rsa_keys = { + let keyfile = fs::File::open(filename).expect("cannot open private key file"); + let mut reader = BufReader::new(keyfile); + rsa_private_keys(&mut reader).expect("file contains invalid rsa private key") + }; + + let pkcs8_keys = { + let keyfile = fs::File::open(filename).expect("cannot open private key file"); + let mut reader = BufReader::new(keyfile); + pkcs8_private_keys(&mut reader) + .expect("file contains invalid pkcs8 private key (encrypted keys not supported)") + }; + + // prefer to load pkcs8 keys + if !pkcs8_keys.is_empty() { + pkcs8_keys[0].clone() + } else { + assert!(!rsa_keys.is_empty()); + rsa_keys[0].clone() + } +} + +fn load_key_and_cert(config: &mut ClientConfig, keyfile: &str, certsfile: &str, cafile: &str) { + let mut certs = load_certs(certsfile); + let cacerts = load_certs(cafile); + let privkey = load_private_key(keyfile); + + // Specially for server.crt not a cert-chain only one server certificate, so manually make + // a cert-chain. + if certs.len() == 1 && !cacerts.is_empty() { + certs.extend(cacerts); + } + + config + .set_single_client_cert(certs, privkey) + .expect("invalid certificate or private key"); +} + +/// Build a `ServerConfig` from our NetConfig +pub fn make_server_config(config: &NetConfig) -> ServerConfig { + let cacerts = load_certs(config.ca_cert.as_ref().unwrap()); + + // server could use `NoClientAuth` mod let client connect freely + let client_auth = if config.ca_cert.is_some() { + let mut client_auth_roots = RootCertStore::empty(); + for cacert in &cacerts { + client_auth_roots.add(cacert).unwrap(); + } + AllowAnyAuthenticatedClient::new(client_auth_roots) + } else { + NoClientAuth::new() + }; + + let mut server_config = ServerConfig::new(client_auth); + server_config.key_log = Arc::new(KeyLogFile::new()); + + let mut certs = load_certs( + config + .server_cert_chain + .as_ref() + .expect("server_cert_chain option missing"), + ); + let privkey = load_private_key( + config + .server_key + .as_ref() + .expect("server_key option missing"), + ); + + // Specially for server.crt not a cert-chain only one server certificate, so manually make + // a cert-chain. + if certs.len() == 1 && !cacerts.is_empty() { + certs.extend(cacerts); + } + + server_config + .set_single_cert_with_ocsp_and_sct(certs, privkey, vec![], vec![]) + .expect("bad certificates/private key"); + + if config.cypher_suits.is_some() { + server_config.ciphersuites = lookup_suites( + &config + .cypher_suits + .as_ref() + .expect("cypher_suits option error"), + ); + } + + if config.protocols.is_some() { + server_config.versions = lookup_versions(config.protocols.as_ref().unwrap()); + server_config.set_protocols( + &config + .protocols + .as_ref() + .unwrap() + .iter() + .map(|proto| proto.as_bytes().to_vec()) + .collect::>()[..], + ); + } + + server_config +} + +/// Build a `ClientConfig` from our NetConfig +pub fn make_client_config(config: &NetConfig) -> ClientConfig { + let mut client_config = ClientConfig::new(); + client_config.key_log = Arc::new(KeyLogFile::new()); + + if config.cypher_suits.is_some() { + client_config.ciphersuites = lookup_suites(config.cypher_suits.as_ref().unwrap()); + } + + if config.protocols.is_some() { + client_config.versions = lookup_versions(config.protocols.as_ref().unwrap()); + + client_config.set_protocols( + &config + .protocols + .as_ref() + .unwrap() + .iter() + .map(|proto| proto.as_bytes().to_vec()) + .collect::>()[..], + ); + } + + let cafile = config.ca_cert.as_ref().unwrap(); + + let certfile = fs::File::open(cafile).expect("Cannot open CA file"); + let mut reader = BufReader::new(certfile); + client_config.root_store.add_pem_file(&mut reader).unwrap(); + + if config.server_key.is_some() || config.server_cert_chain.is_some() { + load_key_and_cert( + &mut client_config, + config + .server_key + .as_ref() + .expect("must provide client_key with client_cert"), + config + .server_cert_chain + .as_ref() + .expect("must provide client_cert with client_key"), + cafile, + ); + } + + client_config +} + +fn server_node(path: String, listen_address: Multiaddr) -> (Receiver, Multiaddr) { + let (meta, receiver) = create_meta(1.into(), true); + let shandle = create_shandle(); + let (addr_sender, addr_receiver) = crossbeam_channel::unbounded(); + + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create(meta, shandle, path); + rt.block_on(async move { + let listen_addr = service.listen(listen_address).await.unwrap(); + let _res = addr_sender.send(listen_addr); + service.run().await + }); + }); + + (receiver, addr_receiver.recv().unwrap()) +} + +fn clint_node_connect(path: String, dial_address: Multiaddr) { + let (meta, _) = create_meta(1.into(), false); + let shandle = create_shandle(); + + let mut service = create(meta, shandle, path); + let control: ServiceControl = service.control().clone().into(); + let handle = thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + let _ = service.dial(dial_address, TargetProtocol::All).await; + service.run().await + }); + }); + thread::sleep(Duration::from_secs(3)); + + let _ignore = control.shutdown(); + handle.join().expect("test fail"); +} + +#[test] +// only node1 connect node0 +fn test_tls_reconnect_ok() { + let (receiver, dail_addr) = server_node( + "tests/certificates/node0/".to_string(), + Multiaddr::from_str("/ip4/127.0.0.1/tcp/0/tls/0x09cbaa785348dabd54c61f5f9964474f7bfad7df") + .unwrap(), + ); + + for _ in 0..2 { + clint_node_connect("tests/certificates/node1/".to_string(), dail_addr.clone()); + assert_eq!(receiver.recv(), Ok(bytes::Bytes::from("hello world"))); + } +} + +#[test] +// node1 and node2-wrong connect node1 +fn test_tls_reconnect_wrong() { + let (receiver, dail_addr) = server_node( + "tests/certificates/node0/".to_string(), + Multiaddr::from_str("/ip4/127.0.0.1/tcp/0/tls/0x09cbaa785348dabd54c61f5f9964474f7bfad7df") + .unwrap(), + ); + + // the first round everything is ok, but the second round node1 can't connect node0, and the + // test blocked + for _ in 0..2 { + clint_node_connect("tests/certificates/node1/".to_string(), dail_addr.clone()); + // due to error certificates the node2 would connect error + clint_node_connect( + "tests/certificates/node2-wrong/".to_string(), + dail_addr.clone(), + ); + assert_eq!(receiver.recv(), Ok(bytes::Bytes::from("hello world"))); + } +} From 3bc0036496deddbe1f1eaa791ded3a7400eefbdd Mon Sep 17 00:00:00 2001 From: driftluo Date: Sat, 25 Sep 2021 13:28:55 +0800 Subject: [PATCH 2/2] ci: fix ci --- tentacle/src/service/config.rs | 4 ++-- tentacle/tests/test_close.rs | 38 +++++++++++++++++----------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tentacle/src/service/config.rs b/tentacle/src/service/config.rs index 1a824383..fad25d09 100644 --- a/tentacle/src/service/config.rs +++ b/tentacle/src/service/config.rs @@ -121,7 +121,7 @@ pub enum TargetProtocol { /// Try open one protocol Single(ProtocolId), /// Try open some protocol, if return true, open it - Filter(Box bool + Send>), + Filter(Box bool + Sync + Send + 'static>), } impl From for TargetProtocol { @@ -143,7 +143,7 @@ pub enum TargetSession { /// Try send to only one Single(SessionId), /// Try send to some session, if return true, send to it - Filter(Box bool + Send>), + Filter(Box bool + Sync + Send + 'static>), } impl From for TargetSession { diff --git a/tentacle/tests/test_close.rs b/tentacle/tests/test_close.rs index 2210f778..7d3a8e22 100644 --- a/tentacle/tests/test_close.rs +++ b/tentacle/tests/test_close.rs @@ -131,9 +131,11 @@ fn test(secio: bool, shutdown: bool) { let (addr_sender, addr_receiver) = channel::(); - let handle = thread::spawn(|| { - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async move { + let rt = tokio::runtime::Runtime::new().unwrap(); + let async_handle = rt.handle().clone(); + + let handle = thread::spawn(move || { + async_handle.block_on(async move { let listen_addr = service_1 .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .await @@ -151,10 +153,11 @@ fn test(secio: bool, shutdown: bool) { let listen_addr_3 = listen_addr.clone(); let listen_addr_4 = listen_addr.clone(); - start_service(service_2, listen_addr_2); - start_service(service_3, listen_addr_3); - start_service(service_4, listen_addr_4); - start_service(service_5, listen_addr); + let async_handle = rt.handle(); + start_service(service_2, listen_addr_2, async_handle); + start_service(service_3, listen_addr_3, async_handle); + start_service(service_4, listen_addr_4, async_handle); + start_service(service_5, listen_addr, async_handle); handle.join().expect("test fail"); } @@ -162,21 +165,18 @@ fn test(secio: bool, shutdown: bool) { fn start_service( mut service: Service, listen_addr: Multiaddr, -) -> ::std::thread::JoinHandle<()> -where + handle: &tokio::runtime::Handle, +) where F: ServiceHandle + Unpin + Send + 'static, { - thread::spawn(move || { - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async move { - service - .dial(listen_addr, TargetProtocol::All) - .await - .unwrap(); + handle.spawn(async move { + service + .dial(listen_addr, TargetProtocol::All) + .await + .unwrap(); - service.run().await - }); - }) + service.run().await; + }); } #[test]