Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Clone for TcpStream #689

Merged
3 commits merged into from Jan 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/tcp-echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ use async_std::task;
async fn process(stream: TcpStream) -> io::Result<()> {
println!("Accepted from: {}", stream.peer_addr()?);

let (reader, writer) = &mut (&stream, &stream);
io::copy(reader, writer).await?;
let mut reader = stream.clone();
let mut writer = stream;
io::copy(&mut reader, &mut writer).await?;

Ok(())
}
Expand Down
5 changes: 3 additions & 2 deletions examples/tcp-ipv4-and-6-echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use async_std::task;
async fn process(stream: TcpStream) -> io::Result<()> {
println!("Accepted from: {}", stream.peer_addr()?);

let (reader, writer) = &mut (&stream, &stream);
io::copy(reader, writer).await?;
let mut reader = stream.clone();
let mut writer = stream;
io::copy(&mut reader, &mut writer).await?;

Ok(())
}
Expand Down
7 changes: 3 additions & 4 deletions src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use crate::future;
use crate::io;
Expand Down Expand Up @@ -75,9 +76,7 @@ impl TcpListener {
/// [`local_addr`]: #method.local_addr
pub async fn bind<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpListener> {
let mut last_err = None;
let addrs = addrs
.to_socket_addrs()
.await?;
let addrs = addrs.to_socket_addrs().await?;

for addr in addrs {
match mio::net::TcpListener::bind(&addr) {
Expand Down Expand Up @@ -121,7 +120,7 @@ impl TcpListener {

let mio_stream = mio::net::TcpStream::from_stream(io)?;
let stream = TcpStream {
watcher: Watcher::new(mio_stream),
watcher: Arc::new(Watcher::new(mio_stream)),
};
Ok((stream, addr))
}
Expand Down
26 changes: 16 additions & 10 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::{IoSlice, IoSliceMut, Read as _, Write as _};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use crate::future;
use crate::io::{self, Read, Write};
Expand Down Expand Up @@ -44,9 +45,9 @@ use crate::task::{Context, Poll};
/// #
/// # Ok(()) }) }
/// ```
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct TcpStream {
pub(super) watcher: Watcher<mio::net::TcpStream>,
pub(super) watcher: Arc<Watcher<mio::net::TcpStream>>,
}

impl TcpStream {
Expand All @@ -71,9 +72,7 @@ impl TcpStream {
/// ```
pub async fn connect<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpStream> {
let mut last_err = None;
let addrs = addrs
.to_socket_addrs()
.await?;
let addrs = addrs.to_socket_addrs().await?;

for addr in addrs {
// mio's TcpStream::connect is non-blocking and may just be in progress
Expand All @@ -84,16 +83,20 @@ impl TcpStream {
Ok(s) => Watcher::new(s),
Err(e) => {
last_err = Some(e);
continue
continue;
}
};

future::poll_fn(|cx| watcher.poll_write_ready(cx)).await;

match watcher.get_ref().take_error() {
Ok(None) => return Ok(TcpStream { watcher }),
Ok(None) => {
return Ok(TcpStream {
watcher: Arc::new(watcher),
});
}
Ok(Some(e)) => last_err = Some(e),
Err(e) => last_err = Some(e)
Err(e) => last_err = Some(e),
}
}

Expand Down Expand Up @@ -369,7 +372,7 @@ impl From<std::net::TcpStream> for TcpStream {
fn from(stream: std::net::TcpStream) -> TcpStream {
let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap();
TcpStream {
watcher: Watcher::new(mio_stream),
watcher: Arc::new(Watcher::new(mio_stream)),
}
}
}
Expand All @@ -391,7 +394,10 @@ cfg_unix! {

impl IntoRawFd for TcpStream {
fn into_raw_fd(self) -> RawFd {
self.watcher.into_inner().into_raw_fd()
// TODO(stjepang): This does not mean `RawFd` is now the sole owner of the file
// descriptor because it's possible that there are other clones of this `TcpStream`
// using it at the same time. We should probably document that behavior.
self.as_raw_fd()
}
}
}
Expand Down
22 changes: 22 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ fn smoke_async_stream_to_std_listener() -> io::Result<()> {

Ok(())
}

#[test]
fn cloned_streams() -> io::Result<()> {
task::block_on(async {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;

let mut stream = TcpStream::connect(&addr).await?;
let mut cloned_stream = stream.clone();
let mut incoming = listener.incoming();
let mut write_stream = incoming.next().await.unwrap()?;
write_stream.write_all(b"Each your doing").await?;

let mut buf = [0; 15];
stream.read_exact(&mut buf[..8]).await?;
cloned_stream.read_exact(&mut buf[8..]).await?;

assert_eq!(&buf[..15], b"Each your doing");

Ok(())
})
}