Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows: move BSD socket shims to netc #127734

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 1 addition & 96 deletions library/std/src/sys/pal/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use crate::ffi::CStr;
use crate::mem;
use crate::os::raw::{c_char, c_int, c_uint, c_ulong, c_ushort, c_void};
use crate::os::raw::{c_uint, c_ulong, c_ushort, c_void};
use crate::os::windows::io::{AsRawHandle, BorrowedHandle};
use crate::ptr;

Expand All @@ -19,12 +19,6 @@ pub use windows_sys::*;

pub type WCHAR = u16;

pub type socklen_t = c_int;
pub type ADDRESS_FAMILY = c_ushort;
pub use FD_SET as fd_set;
pub use LINGER as linger;
pub use TIMEVAL as timeval;

pub const INVALID_HANDLE_VALUE: HANDLE = ::core::ptr::without_provenance_mut(-1i32 as _);

// https://learn.microsoft.com/en-us/cpp/c-runtime-library/exit-success-exit-failure?view=msvc-170
Expand All @@ -42,20 +36,6 @@ pub const INIT_ONCE_STATIC_INIT: INIT_ONCE = INIT_ONCE { Ptr: ptr::null_mut() };
pub const OBJ_DONT_REPARSE: u32 = windows_sys::OBJ_DONT_REPARSE as u32;
pub const FRS_ERR_SYSVOL_POPULATE_TIMEOUT: u32 =
windows_sys::FRS_ERR_SYSVOL_POPULATE_TIMEOUT as u32;
pub const AF_INET: c_int = windows_sys::AF_INET as c_int;
pub const AF_INET6: c_int = windows_sys::AF_INET6 as c_int;

#[repr(C)]
pub struct ip_mreq {
pub imr_multiaddr: in_addr,
pub imr_interface: in_addr,
}

#[repr(C)]
pub struct ipv6_mreq {
pub ipv6mr_multiaddr: in6_addr,
pub ipv6mr_interface: c_uint,
}

// Equivalent to the `NT_SUCCESS` C preprocessor macro.
// See: https://docs.microsoft.com/en-us/windows-hardware/drivers/kernel/using-ntstatus-values
Expand Down Expand Up @@ -127,45 +107,6 @@ pub struct MOUNT_POINT_REPARSE_BUFFER {
pub PathBuffer: WCHAR,
}

#[repr(C)]
pub struct SOCKADDR_STORAGE_LH {
pub ss_family: ADDRESS_FAMILY,
pub __ss_pad1: [c_char; 6],
pub __ss_align: i64,
pub __ss_pad2: [c_char; 112],
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in {
pub sin_family: ADDRESS_FAMILY,
pub sin_port: c_ushort,
pub sin_addr: in_addr,
pub sin_zero: [c_char; 8],
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in6 {
pub sin6_family: ADDRESS_FAMILY,
pub sin6_port: c_ushort,
pub sin6_flowinfo: c_ulong,
pub sin6_addr: in6_addr,
pub sin6_scope_id: c_ulong,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct in_addr {
pub s_addr: u32,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct in6_addr {
pub s6_addr: [u8; 16],
}

// Desktop specific functions & types
cfg_if::cfg_if! {
if #[cfg(not(target_vendor = "uwp"))] {
Expand Down Expand Up @@ -205,42 +146,6 @@ pub unsafe extern "system" fn ReadFileEx(
)
}

// POSIX compatibility shims.
pub unsafe fn recv(socket: SOCKET, buf: *mut c_void, len: c_int, flags: c_int) -> c_int {
windows_sys::recv(socket, buf.cast::<u8>(), len, flags)
}
pub unsafe fn send(socket: SOCKET, buf: *const c_void, len: c_int, flags: c_int) -> c_int {
windows_sys::send(socket, buf.cast::<u8>(), len, flags)
}
pub unsafe fn recvfrom(
socket: SOCKET,
buf: *mut c_void,
len: c_int,
flags: c_int,
addr: *mut SOCKADDR,
addrlen: *mut c_int,
) -> c_int {
windows_sys::recvfrom(socket, buf.cast::<u8>(), len, flags, addr, addrlen)
}
pub unsafe fn sendto(
socket: SOCKET,
buf: *const c_void,
len: c_int,
flags: c_int,
addr: *const SOCKADDR,
addrlen: c_int,
) -> c_int {
windows_sys::sendto(socket, buf.cast::<u8>(), len, flags, addr, addrlen)
}
pub unsafe fn getaddrinfo(
node: *const c_char,
service: *const c_char,
hints: *const ADDRINFOA,
res: *mut *mut ADDRINFOA,
) -> c_int {
windows_sys::getaddrinfo(node.cast::<u8>(), service.cast::<u8>(), hints, res)
}

cfg_if::cfg_if! {
if #[cfg(not(target_vendor = "uwp"))] {
pub unsafe fn NtReadFile(
Expand Down
1 change: 1 addition & 0 deletions library/std/src/sys/pal/windows/c/bindings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,7 @@ Windows.Win32.Networking.WinSock.SOCK_RDM
Windows.Win32.Networking.WinSock.SOCK_SEQPACKET
Windows.Win32.Networking.WinSock.SOCK_STREAM
Windows.Win32.Networking.WinSock.SOCKADDR
Windows.Win32.Networking.WinSock.SOCKADDR_STORAGE
Windows.Win32.Networking.WinSock.SOCKADDR_UN
Windows.Win32.Networking.WinSock.SOCKET
Windows.Win32.Networking.WinSock.SOCKET_ERROR
Expand Down
8 changes: 8 additions & 0 deletions library/std/src/sys/pal/windows/c/windows_sys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2890,6 +2890,14 @@ pub struct SOCKADDR {
}
#[repr(C)]
#[derive(Clone, Copy)]
pub struct SOCKADDR_STORAGE {
pub ss_family: ADDRESS_FAMILY,
pub __ss_pad1: [i8; 6],
pub __ss_align: i64,
pub __ss_pad2: [i8; 112],
}
#[repr(C)]
#[derive(Clone, Copy)]
pub struct SOCKADDR_UN {
pub sun_family: ADDRESS_FAMILY,
pub sun_path: [i8; 108],
Expand Down
112 changes: 99 additions & 13 deletions library/std/src/sys/pal/windows/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,100 @@ use crate::time::Duration;

use core::ffi::{c_int, c_long, c_ulong, c_ushort};

#[allow(non_camel_case_types)]
pub type wrlen_t = i32;

pub mod netc {
pub use crate::sys::c::ADDRESS_FAMILY as sa_family_t;
pub use crate::sys::c::ADDRINFOA as addrinfo;
pub use crate::sys::c::SOCKADDR as sockaddr;
pub use crate::sys::c::SOCKADDR_STORAGE_LH as sockaddr_storage;
pub use crate::sys::c::*;
//! BSD socket compatibility shim
//!
//! Some Windows API types are not quite what's expected by our cross-platform
//! net code. E.g. naming differences or different pointer types.
use crate::sys::c::{self, ADDRESS_FAMILY, ADDRINFOA, SOCKADDR, SOCKET};
use core::ffi::{c_char, c_int, c_uint, c_ulong, c_ushort, c_void};

// re-exports from Windows API bindings.
pub use crate::sys::c::{
bind, connect, freeaddrinfo, getpeername, getsockname, getsockopt, listen, setsockopt,
ADDRESS_FAMILY as sa_family_t, ADDRINFOA as addrinfo, IPPROTO_IP, IPPROTO_IPV6,
IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_MULTICAST_LOOP, IPV6_V6ONLY,
IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, IP_TTL,
SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage, SOCK_DGRAM, SOCK_STREAM,
SOL_SOCKET, SO_BROADCAST, SO_RCVTIMEO, SO_SNDTIMEO,
};

#[allow(non_camel_case_types)]
pub type socklen_t = c_int;

pub const AF_INET: i32 = c::AF_INET as i32;
pub const AF_INET6: i32 = c::AF_INET6 as i32;

// The following two structs use a union in the generated bindings but
// our cross-platform code expects a normal field so it's redefined here.
// As a consequence, we also need to redefine other structs that use this struct.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct in_addr {
pub s_addr: u32,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct in6_addr {
pub s6_addr: [u8; 16],
}

#[repr(C)]
pub struct ip_mreq {
pub imr_multiaddr: in_addr,
pub imr_interface: in_addr,
}

#[repr(C)]
pub struct ipv6_mreq {
pub ipv6mr_multiaddr: in6_addr,
pub ipv6mr_interface: c_uint,
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in {
pub sin_family: ADDRESS_FAMILY,
pub sin_port: c_ushort,
pub sin_addr: in_addr,
pub sin_zero: [c_char; 8],
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct sockaddr_in6 {
pub sin6_family: ADDRESS_FAMILY,
pub sin6_port: c_ushort,
pub sin6_flowinfo: c_ulong,
pub sin6_addr: in6_addr,
pub sin6_scope_id: c_ulong,
}

pub unsafe fn send(socket: SOCKET, buf: *const c_void, len: c_int, flags: c_int) -> c_int {
unsafe { c::send(socket, buf.cast::<u8>(), len, flags) }
}
pub unsafe fn sendto(
socket: SOCKET,
buf: *const c_void,
len: c_int,
flags: c_int,
addr: *const SOCKADDR,
addrlen: c_int,
) -> c_int {
unsafe { c::sendto(socket, buf.cast::<u8>(), len, flags, addr, addrlen) }
}
pub unsafe fn getaddrinfo(
node: *const c_char,
service: *const c_char,
hints: *const ADDRINFOA,
res: *mut *mut ADDRINFOA,
) -> c_int {
unsafe { c::getaddrinfo(node.cast::<u8>(), service.cast::<u8>(), hints, res) }
}
}

pub struct Socket(OwnedSocket);
Expand Down Expand Up @@ -102,8 +188,8 @@ where
impl Socket {
pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
let family = match *addr {
SocketAddr::V4(..) => c::AF_INET,
SocketAddr::V6(..) => c::AF_INET6,
SocketAddr::V4(..) => netc::AF_INET,
SocketAddr::V6(..) => netc::AF_INET6,
};
let socket = unsafe {
c::WSASocketW(
Expand Down Expand Up @@ -157,7 +243,7 @@ impl Socket {
return Err(io::Error::ZERO_TIMEOUT);
}

let mut timeout = c::timeval {
let mut timeout = c::TIMEVAL {
tv_sec: cmp::min(timeout.as_secs(), c_long::MAX as u64) as c_long,
tv_usec: timeout.subsec_micros() as c_long,
};
Expand All @@ -167,7 +253,7 @@ impl Socket {
}

let fds = {
let mut fds = unsafe { mem::zeroed::<c::fd_set>() };
let mut fds = unsafe { mem::zeroed::<c::FD_SET>() };
fds.fd_count = 1;
fds.fd_array[0] = self.as_raw();
fds
Expand Down Expand Up @@ -295,8 +381,8 @@ impl Socket {
buf: &mut [u8],
flags: c_int,
) -> io::Result<(usize, SocketAddr)> {
let mut storage = unsafe { mem::zeroed::<c::SOCKADDR_STORAGE_LH>() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let mut storage = unsafe { mem::zeroed::<c::SOCKADDR_STORAGE>() };
let mut addrlen = mem::size_of_val(&storage) as netc::socklen_t;
let length = cmp::min(buf.len(), <wrlen_t>::MAX as usize) as wrlen_t;

// On unix when a socket is shut down all further reads return 0, so we
Expand Down Expand Up @@ -399,7 +485,7 @@ impl Socket {
}

pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
let linger = c::linger {
let linger = c::LINGER {
l_onoff: linger.is_some() as c_ushort,
l_linger: linger.unwrap_or_default().as_secs() as c_ushort,
};
Expand All @@ -408,7 +494,7 @@ impl Socket {
}

pub fn linger(&self) -> io::Result<Option<Duration>> {
let val: c::linger = net::getsockopt(self, c::SOL_SOCKET, c::SO_LINGER)?;
let val: c::LINGER = net::getsockopt(self, c::SOL_SOCKET, c::SO_LINGER)?;

Ok((val.l_onoff != 0).then(|| Duration::from_secs(val.l_linger as u64)))
}
Expand Down
Loading