Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't assume memory layout of std::net::SocketAddr #39

Merged
merged 4 commits into from
Nov 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ default-target = "x86_64-pc-windows-msvc"
targets = ["aarch64-pc-windows-msvc", "i686-pc-windows-msvc", "x86_64-pc-windows-msvc"]

[dependencies]
socket2 = "0.3"
socket2 = "0.3.16"

[dependencies.winapi]
version = "0.3.3"
Expand Down
120 changes: 108 additions & 12 deletions src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::sync::atomic::{AtomicUsize, Ordering};

use winapi::ctypes::*;
use winapi::shared::guiddef::*;
use winapi::shared::in6addr::{in6_addr_u, IN6_ADDR};
use winapi::shared::inaddr::{in_addr_S_un, IN_ADDR};
use winapi::shared::minwindef::*;
use winapi::shared::minwindef::{FALSE, TRUE};
use winapi::shared::ntdef::*;
Expand Down Expand Up @@ -456,16 +458,64 @@ fn cvt(i: c_int, size: DWORD) -> io::Result<Option<usize>> {
}
}

fn socket_addr_to_ptrs(addr: &SocketAddr) -> (*const SOCKADDR, c_int) {
/// A type with the same memory layout as `SOCKADDR`. Used in converting Rust level
/// SocketAddr* types into their system representation. The benefit of this specific
/// type over using `SOCKADDR_STORAGE` is that this type is exactly as large as it
/// needs to be and not a lot larger. And it can be initialized cleaner from Rust.
#[repr(C)]
pub(crate) union SocketAddrCRepr {
v4: SOCKADDR_IN,
v6: SOCKADDR_IN6_LH,
}

impl SocketAddrCRepr {
pub(crate) fn as_ptr(&self) -> *const SOCKADDR {
self as *const _ as *const SOCKADDR
}
}

fn socket_addr_to_ptrs(addr: &SocketAddr) -> (SocketAddrCRepr, c_int) {
match *addr {
SocketAddr::V4(ref a) => (
a as *const _ as *const _,
mem::size_of::<SOCKADDR_IN>() as c_int,
),
SocketAddr::V6(ref a) => (
a as *const _ as *const _,
mem::size_of::<SOCKADDR_IN6_LH>() as c_int,
),
SocketAddr::V4(ref a) => {
let sin_addr = unsafe {
let mut s_un = mem::zeroed::<in_addr_S_un>();
*s_un.S_addr_mut() = u32::from_ne_bytes(a.ip().octets());
IN_ADDR { S_un: s_un }
};

let sockaddr_in = SOCKADDR_IN {
sin_family: AF_INET as ADDRESS_FAMILY,
sin_port: a.port().to_be(),
sin_addr,
sin_zero: [0; 8],
};

let sockaddr = SocketAddrCRepr { v4: sockaddr_in };
(sockaddr, mem::size_of::<SOCKADDR_IN>() as c_int)
}
SocketAddr::V6(ref a) => {
let sin6_addr = unsafe {
let mut u = mem::zeroed::<in6_addr_u>();
*u.Byte_mut() = a.ip().octets();
IN6_ADDR { u }
};
let u = unsafe {
let mut u = mem::zeroed::<SOCKADDR_IN6_LH_u>();
*u.sin6_scope_id_mut() = a.scope_id();
u
};

let sockaddr_in6 = SOCKADDR_IN6_LH {
sin6_family: AF_INET6 as ADDRESS_FAMILY,
sin6_port: a.port().to_be(),
sin6_addr,
sin6_flowinfo: a.flowinfo(),
u,
};

let sockaddr = SocketAddrCRepr { v6: sockaddr_in6 };
(sockaddr, mem::size_of::<SOCKADDR_IN6_LH>() as c_int)
}
}
}

Expand Down Expand Up @@ -650,7 +700,7 @@ unsafe fn connect_overlapped(
let mut bytes_sent: DWORD = 0;
let r = connect_ex(
socket,
addr_buf,
addr_buf.as_ptr(),
addr_len,
buf.as_ptr() as *mut _,
buf.len() as u32,
Expand Down Expand Up @@ -723,7 +773,7 @@ impl UdpSocketExt for UdpSocket {
1,
&mut sent_bytes,
0,
addr_buf as *const _,
addr_buf.as_ptr() as *const _,
addr_len,
overlapped,
None,
Expand Down Expand Up @@ -970,7 +1020,10 @@ impl WsaExtension {
#[cfg(test)]
mod tests {
use std::io::prelude::*;
use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::net::{
IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV6, TcpListener, TcpStream, UdpSocket,
};
use std::slice;
use std::thread;

use socket2::{Domain, Socket, Type};
Expand Down Expand Up @@ -1239,4 +1292,47 @@ mod tests {
assert_eq!(addrs.remote(), Some(remote));
})
}

#[test]
fn sockaddr_convert_4() {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(3, 4, 5, 6)), 0xabcd);
let (raw_addr, addr_len) = super::socket_addr_to_ptrs(&addr);
assert_eq!(addr_len, 16);
let addr_bytes =
unsafe { slice::from_raw_parts(raw_addr.as_ptr() as *const u8, addr_len as usize) };
assert_eq!(
addr_bytes,
&[2, 0, 0xab, 0xcd, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0]
);
}

#[test]
fn sockaddr_convert_v6() {
let port = 0xabcd;
let flowinfo = 0x12345678;
let scope_id = 0x87654321;
let addr = SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(
0x0102, 0x0304, 0x0506, 0x0708, 0x090a, 0x0b0c, 0x0d0e, 0x0f10,
),
port,
flowinfo,
scope_id,
));
let (raw_addr, addr_len) = super::socket_addr_to_ptrs(&addr);
assert_eq!(addr_len, 28);
let addr_bytes =
unsafe { slice::from_raw_parts(raw_addr.as_ptr() as *const u8, addr_len as usize) };
assert_eq!(
addr_bytes,
&[
23, 0, // AF_INET6
0xab, 0xcd, // Port
0x78, 0x56, 0x34, 0x12, // flowinfo
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
0x0f, 0x10, // IP
0x21, 0x43, 0x65, 0x87, // scope_id
]
);
}
}