diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 418c6907..fe9ba3b0 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -204,6 +204,10 @@ const MAX_BUF_LEN: usize = ssize_t::MAX as usize; #[cfg(target_vendor = "apple")] const MAX_BUF_LEN: usize = c_int::MAX as usize - 1; +// TCP_CA_NAME_MAX isn't defined in user space include files(not in libc) +#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))] +const TCP_CA_NAME_MAX: usize = 16; + #[cfg(any( all( target_os = "linux", @@ -2154,6 +2158,55 @@ impl crate::Socket { ) } } + + /// Get the value of the `TCP_CONGESTION` option for this socket. + /// + /// For more information about this option, see [`set_tcp_congestion`]. + /// + /// [`set_tcp_congestion`]: Socket::set_tcp_congestion + #[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))] + #[cfg_attr( + docsrs, + doc(cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))) + )] + pub fn tcp_congestion(&self) -> io::Result> { + let mut payload: [u8; TCP_CA_NAME_MAX] = [0; TCP_CA_NAME_MAX]; + let mut len = payload.len() as libc::socklen_t; + syscall!(getsockopt( + self.as_raw(), + IPPROTO_TCP, + libc::TCP_CONGESTION, + payload.as_mut_ptr().cast(), + &mut len, + )) + .map(|_| { + let buf = &payload[..len as usize]; + // TODO: use `MaybeUninit::slice_assume_init_ref` once stable. + unsafe { &*(buf as *const [_] as *const [u8]) }.into() + }) + } + + /// Set the value of the `TCP_CONGESTION` option for this socket. + /// + /// Specifies the TCP congestion control algorithm to use for this socket. + /// + /// The value must be a valid TCP congestion control algorithm name of the + /// platform. For example, Linux may supports "reno", "cubic". + #[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))] + #[cfg_attr( + docsrs, + doc(cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))) + )] + pub fn set_tcp_congestion(&self, tcp_ca_name: &[u8]) -> io::Result<()> { + syscall!(setsockopt( + self.as_raw(), + IPPROTO_TCP, + libc::TCP_CONGESTION, + tcp_ca_name.as_ptr() as *const _, + tcp_ca_name.len() as libc::socklen_t, + )) + .map(|_| ()) + } } #[cfg_attr(docsrs, doc(cfg(unix)))] diff --git a/tests/socket.rs b/tests/socket.rs index 39838e4c..c3872bd1 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -1341,3 +1341,49 @@ fn original_dst_ipv6() { Err(err) => assert_eq!(err.raw_os_error(), Some(libc::EOPNOTSUPP)), } } + +#[test] +#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))] +fn tcp_congestion() { + let socket: Socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); + // Get and set current tcp_ca + let origin_tcp_ca = socket + .tcp_congestion() + .expect("failed to get tcp congestion algorithm"); + socket + .set_tcp_congestion(&origin_tcp_ca) + .expect("failed to set tcp congestion algorithm"); + // Return a Err when set a non-exist tcp_ca + socket + .set_tcp_congestion(b"tcp_congestion_does_not_exist") + .unwrap_err(); + let cur_tcp_ca = socket.tcp_congestion().unwrap(); + assert_eq!( + cur_tcp_ca, origin_tcp_ca, + "expected {origin_tcp_ca:?} but get {cur_tcp_ca:?}" + ); + let cur_tcp_ca = cur_tcp_ca.splitn(2, |num| *num == 0).next().unwrap(); + const OPTIONS: [&[u8]; 2] = [ + b"cubic", + #[cfg(target_os = "linux")] // or Android. + b"reno", + #[cfg(target_os = "freebsd")] + b"newreno", + ]; + // Set a new tcp ca + #[cfg(target_os = "linux")] + let new_tcp_ca = if cur_tcp_ca == OPTIONS[0] { + OPTIONS[1] + } else { + OPTIONS[0] + }; + #[cfg(target_os = "freebsd")] + let new_tcp_ca = OPTIONS[1]; + socket.set_tcp_congestion(new_tcp_ca).unwrap(); + // Check if new tcp ca is successfully set + let cur_tcp_ca = socket.tcp_congestion().unwrap(); + assert_eq!( + cur_tcp_ca.splitn(2, |num| *num == 0).next().unwrap(), + new_tcp_ca, + ); +}