Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for TCP_CONGESTION socketopt #371

Merged
merged 8 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ const MAX_BUF_LEN: usize = ssize_t::MAX as usize;
#[cfg(target_vendor = "apple")]
const MAX_BUF_LEN: usize = c_int::MAX as usize - 1;

#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
const TCP_CA_NAME_MAX: usize = 16;
BobAnkh marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(any(
all(
target_os = "linux",
Expand Down Expand Up @@ -2154,6 +2157,55 @@ impl crate::Socket {
)
}
}

/// Get the value of the `TCP_CONGESTION` option for this socket.
///
/// For more information about this option, see [`set_tcp_congestion`].
///
/// [`set_tcp_congestion`]: Socket::set_tcp_congestion
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
#[cfg_attr(
docsrs,
doc(cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux"))))
)]
pub fn tcp_congestion(&self) -> io::Result<Vec<u8>> {
let mut payload: MaybeUninit<[u8; TCP_CA_NAME_MAX]> = MaybeUninit::uninit();
let mut len = size_of::<[u8; TCP_CA_NAME_MAX]>() as libc::socklen_t;
BobAnkh marked this conversation as resolved.
Show resolved Hide resolved
syscall!(getsockopt(
self.as_raw(),
IPPROTO_TCP,
libc::TCP_CONGESTION,
payload.as_mut_ptr().cast(),
&mut len,
))
.map(|_| {
let buf = unsafe { payload.assume_init() };
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is UB. Only &buf[0..len] bytes are initialised by the kernel.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not resolved. You can't call payload.assume_init, it's not fully initialised. I think you can simply remove this line and we're good.

Copy link
Contributor Author

@BobAnkh BobAnkh Jan 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so now I use directly [u8; TCP_CA_NAME_MAX](initialized as [0; TCP_CA_NAME_MAX]) instead of MaybeUninit

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not what I asked for, but fine for now.

Allow me to explain the UB in the old code one more time. You create an array of size TCP_CA_NAME_MAX, accessing any of those by would be UB. Then you call the kernel to fill the bytes, but it doesn't fill all of them just n bytes (return by the system call). With the old code you called payload.assume_init(), which requires the entire array all (TCP_CA_NAME_MAX bytes) to be initialised, but that's not always the case as it's only up to n bytes.

So the only change needed to the old code was to remove the payload.assume_init() line. Because the next line creating a slice of only the initialised bytes and the line after that cast it to a slice of bytes (MaybeUninit<u8> -> u8).

let name = buf.splitn(2, |num| *num == 0).next().unwrap().to_vec();
name
BobAnkh marked this conversation as resolved.
Show resolved Hide resolved
})
}

/// Set the value of the `TCP_CONGESTION` option for this socket.
///
/// Specifies the TCP congestion control algorithm to use for this socket.
///
/// The value must be a valid TCP congestion control algorithm name of the
/// platform. For example, Linux may supports "reno", "cubic".
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
#[cfg_attr(
docsrs,
doc(cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux"))))
)]
pub fn set_tcp_congestion(&self, tcp_ca_name: &[u8]) -> io::Result<()> {
syscall!(setsockopt(
self.as_raw(),
IPPROTO_TCP,
libc::TCP_CONGESTION,
tcp_ca_name.as_ptr() as *const std::os::raw::c_void,
BobAnkh marked this conversation as resolved.
Show resolved Hide resolved
tcp_ca_name.len() as libc::socklen_t,
))
.map(|_| ())
}
}

#[cfg_attr(docsrs, doc(cfg(unix)))]
Expand Down
22 changes: 22 additions & 0 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1341,3 +1341,25 @@ fn original_dst_ipv6() {
Err(err) => assert_eq!(err.raw_os_error(), Some(libc::EOPNOTSUPP)),
}
}

#[test]
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
fn tcp_congestion() {
let socket: Socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
// Get and set current tcp_ca
let origin_tcp_ca = socket
.tcp_congestion()
.expect("failed to get tcp congestion algorithm");
socket
.set_tcp_congestion(&origin_tcp_ca)
.expect("failed to set tcp congestion algorithm");
// Return a Err when set a non-exist tcp_ca
socket
.set_tcp_congestion(b"tcp_congestion_does_not_exist")
.unwrap_err();
BobAnkh marked this conversation as resolved.
Show resolved Hide resolved
let cur_tcp_ca = socket.tcp_congestion().unwrap();
assert_eq!(
cur_tcp_ca, origin_tcp_ca,
"expected {origin_tcp_ca:?} but get {cur_tcp_ca:?}"
BobAnkh marked this conversation as resolved.
Show resolved Hide resolved
);
}