Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite Windows compat_fn macro #99553

Merged
merged 1 commit into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions library/std/src/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#![cfg_attr(test, allow(dead_code))]
#![unstable(issue = "none", feature = "windows_c")]

use crate::ffi::CStr;
use crate::mem;
use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort};
use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull};
Expand Down Expand Up @@ -1219,8 +1220,8 @@ extern "system" {

// Functions that aren't available on every version of Windows that we support,
// but we still use them and just provide some form of a fallback implementation.
compat_fn! {
"kernel32":
compat_fn_with_fallback! {
pub static KERNEL32: &CStr = ansi_str!("kernel32");

// >= Win10 1607
// https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-setthreaddescription
Expand All @@ -1243,8 +1244,8 @@ compat_fn! {
}
}

compat_fn! {
"api-ms-win-core-synch-l1-2-0":
compat_fn_optional! {
pub static SYNCH_API: &CStr = ansi_str!("api-ms-win-core-synch-l1-2-0");

// >= Windows 8 / Server 2012
// https://docs.microsoft.com/en-us/windows/win32/api/synchapi/nf-synchapi-waitonaddress
Expand All @@ -1253,17 +1254,13 @@ compat_fn! {
CompareAddress: LPVOID,
AddressSize: SIZE_T,
dwMilliseconds: DWORD
) -> BOOL {
panic!("WaitOnAddress not available")
}
pub fn WakeByAddressSingle(Address: LPVOID) -> () {
// If this api is unavailable, there cannot be anything waiting, because
// WaitOnAddress would've panicked. So it's fine to do nothing here.
}
) -> BOOL;
pub fn WakeByAddressSingle(Address: LPVOID) -> ();
}

compat_fn! {
"ntdll":
compat_fn_with_fallback! {
pub static NTDLL: &CStr = ansi_str!("ntdll");

pub fn NtCreateFile(
FileHandle: *mut HANDLE,
DesiredAccess: ACCESS_MASK,
Expand Down
246 changes: 195 additions & 51 deletions library/std/src/sys/windows/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,81 +49,225 @@
//! * call any Rust function or CRT function that touches any static
//! (global) state.

macro_rules! compat_fn {
($module:literal: $(
use crate::ffi::{c_void, CStr};
use crate::ptr::NonNull;
use crate::sys::c;

/// Helper macro for creating CStrs from literals and symbol names.
macro_rules! ansi_str {
(sym $ident:ident) => {{
#[allow(unused_unsafe)]
crate::sys::compat::const_cstr_from_bytes(concat!(stringify!($ident), "\0").as_bytes())
}};
($lit:literal) => {{ crate::sys::compat::const_cstr_from_bytes(concat!($lit, "\0").as_bytes()) }};
}

/// Creates a C string wrapper from a byte slice, in a constant context.
///
/// This is a utility function used by the [`ansi_str`] macro.
///
/// # Panics
///
/// Panics if the slice is not null terminated or contains nulls, except as the last item
pub(crate) const fn const_cstr_from_bytes(bytes: &'static [u8]) -> &'static CStr {
if !matches!(bytes.last(), Some(&0)) {
panic!("A CStr must be null terminated");
}
let mut i = 0;
// At this point `len()` is at least 1.
while i < bytes.len() - 1 {
if bytes[i] == 0 {
panic!("A CStr must not have interior nulls")
}
i += 1;
}
// SAFETY: The safety is ensured by the above checks.
unsafe { crate::ffi::CStr::from_bytes_with_nul_unchecked(bytes) }
}

#[used]
#[link_section = ".CRT$XCU"]
static INIT_TABLE_ENTRY: unsafe extern "C" fn() = init;

/// This is where the magic preloading of symbols happens.
///
/// Note that any functions included here will be unconditionally included in
/// the final binary, regardless of whether or not they're actually used.
///
/// Therefore, this is limited to `compat_fn_optional` functions which must be
/// preloaded and any functions which may be more time sensitive, even for the first call.
unsafe extern "C" fn init() {
// There is no locking here. This code is executed before main() is entered, and
// is guaranteed to be single-threaded.
//
// DO NOT do anything interesting or complicated in this function! DO NOT call
// any Rust functions or CRT functions if those functions touch any global state,
// because this function runs during global initialization. For example, DO NOT
// do any dynamic allocation, don't call LoadLibrary, etc.

if let Some(synch) = Module::new(c::SYNCH_API) {
// These are optional and so we must manually attempt to load them
// before they can be used.
c::WaitOnAddress::preload(synch);
c::WakeByAddressSingle::preload(synch);
}

if let Some(kernel32) = Module::new(c::KERNEL32) {
// Preloading this means getting a precise time will be as fast as possible.
c::GetSystemTimePreciseAsFileTime::preload(kernel32);
}
}

/// Represents a loaded module.
///
/// Note that the modules std depends on must not be unloaded.
/// Therefore a `Module` is always valid for the lifetime of std.
#[derive(Copy, Clone)]
pub(in crate::sys) struct Module(NonNull<c_void>);
impl Module {
/// Try to get a handle to a loaded module.
///
/// # SAFETY
///
/// This should only be use for modules that exist for the lifetime of std
/// (e.g. kernel32 and ntdll).
pub unsafe fn new(name: &CStr) -> Option<Self> {
// SAFETY: A CStr is always null terminated.
let module = c::GetModuleHandleA(name.as_ptr());
NonNull::new(module).map(Self)
}

// Try to get the address of a function.
pub fn proc_address(self, name: &CStr) -> Option<NonNull<c_void>> {
// SAFETY:
// `self.0` will always be a valid module.
// A CStr is always null terminated.
let proc = unsafe { c::GetProcAddress(self.0.as_ptr(), name.as_ptr()) };
NonNull::new(proc)
}
}

/// Load a function or use a fallback implementation if that fails.
macro_rules! compat_fn_with_fallback {
(pub static $module:ident: &CStr = $name:expr; $(
$(#[$meta:meta])*
pub fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty $fallback_body:block
)*) => ($(
)*) => (
pub static $module: &CStr = $name;
$(
$(#[$meta])*
pub mod $symbol {
#[allow(unused_imports)]
use super::*;
use crate::mem;
use crate::ffi::CStr;
use crate::sync::atomic::{AtomicPtr, Ordering};
use crate::sys::compat::Module;

type F = unsafe extern "system" fn($($argtype),*) -> $rettype;

/// Points to the DLL import, or the fallback function.
///
/// This static can be an ordinary, unsynchronized, mutable static because
/// we guarantee that all of the writes finish during CRT initialization,
/// and all of the reads occur after CRT initialization.
static mut PTR: Option<F> = None;

/// This symbol is what allows the CRT to find the `init` function and call it.
/// It is marked `#[used]` because otherwise Rust would assume that it was not
/// used, and would remove it.
#[used]
#[link_section = ".CRT$XCU"]
static INIT_TABLE_ENTRY: unsafe extern "C" fn() = init;

unsafe extern "C" fn init() {
PTR = get_f();
/// `PTR` contains a function pointer to one of three functions.
/// It starts with the `load` function.
/// When that is called it attempts to load the requested symbol.
/// If it succeeds, `PTR` is set to the address of that symbol.
/// If it fails, then `PTR` is set to `fallback`.
static PTR: AtomicPtr<c_void> = AtomicPtr::new(load as *mut _);

unsafe extern "system" fn load($($argname: $argtype),*) -> $rettype {
let func = load_from_module(Module::new($module));
func($($argname),*)
}

unsafe extern "C" fn get_f() -> Option<F> {
// There is no locking here. This code is executed before main() is entered, and
// is guaranteed to be single-threaded.
//
// DO NOT do anything interesting or complicated in this function! DO NOT call
// any Rust functions or CRT functions, if those functions touch any global state,
// because this function runs during global initialization. For example, DO NOT
// do any dynamic allocation, don't call LoadLibrary, etc.
let module_name: *const u8 = concat!($module, "\0").as_ptr();
let symbol_name: *const u8 = concat!(stringify!($symbol), "\0").as_ptr();
let module_handle = $crate::sys::c::GetModuleHandleA(module_name as *const i8);
if !module_handle.is_null() {
let ptr = $crate::sys::c::GetProcAddress(module_handle, symbol_name as *const i8);
if !ptr.is_null() {
// Transmute to the right function pointer type.
return Some(mem::transmute(ptr));
fn load_from_module(module: Option<Module>) -> F {
unsafe {
static symbol_name: &CStr = ansi_str!(sym $symbol);
if let Some(f) = module.and_then(|m| m.proc_address(symbol_name)) {
PTR.store(f.as_ptr(), Ordering::Relaxed);
mem::transmute(f)
} else {
PTR.store(fallback as *mut _, Ordering::Relaxed);
fallback
}
}
return None;
}

#[allow(dead_code)]
#[allow(unused_variables)]
unsafe extern "system" fn fallback($($argname: $argtype),*) -> $rettype {
$fallback_body
}

#[allow(unused)]
pub(in crate::sys) fn preload(module: Module) {
load_from_module(Some(module));
}

#[inline(always)]
pub unsafe fn call($($argname: $argtype),*) -> $rettype {
let func: F = mem::transmute(PTR.load(Ordering::Relaxed));
func($($argname),*)
}
}
$(#[$meta])*
pub use $symbol::call as $symbol;
)*)
}

/// A function that either exists or doesn't.
///
/// NOTE: Optional functions must be preloaded in the `init` function above, or they will always be None.
macro_rules! compat_fn_optional {
(pub static $module:ident: &CStr = $name:expr; $(
$(#[$meta:meta])*
pub fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty;
)*) => (
pub static $module: &CStr = $name;
$(
$(#[$meta])*
pub mod $symbol {
#[allow(unused_imports)]
use super::*;
use crate::mem;
use crate::sync::atomic::{AtomicPtr, Ordering};
use crate::sys::compat::Module;
use crate::ptr::{self, NonNull};

type F = unsafe extern "system" fn($($argtype),*) -> $rettype;

/// `PTR` will either be `null()` or set to the loaded function.
static PTR: AtomicPtr<c_void> = AtomicPtr::new(ptr::null_mut());

/// Only allow access to the function if it has loaded successfully.
#[inline(always)]
#[cfg(not(miri))]
pub fn option() -> Option<F> {
unsafe {
if cfg!(miri) {
// Miri does not run `init`, so we just call `get_f` each time.
get_f()
} else {
PTR
}
NonNull::new(PTR.load(Ordering::Relaxed)).map(|f| mem::transmute(f))
}
}

#[allow(dead_code)]
pub unsafe fn call($($argname: $argtype),*) -> $rettype {
if let Some(ptr) = option() {
return ptr($($argname),*);
// Miri does not understand the way we do preloading
// therefore load the function here instead.
#[cfg(miri)]
pub fn option() -> Option<F> {
let mut func = NonNull::new(PTR.load(Ordering::Relaxed));
if func.is_none() {
Module::new($module).map(preload);
func = NonNull::new(PTR.load(Ordering::Relaxed));
}
unsafe {
func.map(|f| mem::transmute(f))
}
$fallback_body
}
}

$(#[$meta])*
pub use $symbol::call as $symbol;
#[allow(unused)]
pub(in crate::sys) fn preload(module: Module) {
unsafe {
let symbol_name = ansi_str!(sym $symbol);
if let Some(f) = module.proc_address(symbol_name) {
PTR.store(f.as_ptr(), Ordering::Relaxed);
}
}
}
}
)*)
}