diff --git a/Cargo.toml b/Cargo.toml
index ec19b84..f6355be 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,8 +19,18 @@ doctest = false
 test = false
 
 [dependencies]
-cfg-if = "0.1"
-libc = "0.2"
+cfg-if = "0.1.6"
+libc = "0.2.45"
+
+[target.'cfg(windows)'.dependencies.winapi]
+version = "0.3.6"
+features = [
+    'memoryapi',
+    'winbase',
+    'fibersapi',
+    'processthreadsapi',
+    'minwindef',
+]
 
 [build-dependencies]
 cc = "1.0"
diff --git a/README.md b/README.md
index c5ac5e8..ad568e7 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
 
 A stack-growth library for Rust. Enables annotating fixed points in programs
 where the stack may want to grow larger. Spills over to the heap if the stack
-has it its limit.
+has hit its limit.
 
 This library is intended on helping implement recursive algorithms.
diff --git a/build.rs b/build.rs
index 89ca9f7..132053f 100644
--- a/build.rs
+++ b/build.rs
@@ -14,6 +14,7 @@ fn main() {
         cfg.define("APPLE", None);
     } else if target.contains("windows") {
         cfg.define("WINDOWS", None);
+        cfg.file("src/arch/windows.c");
     } else {
         panic!("\n\nusing currently unsupported target triple with \
                 stacker: {}\n\n", target);
diff --git a/src/arch/i686.S b/src/arch/i686.S
index dd143c5..49c4708 100644
--- a/src/arch/i686.S
+++ b/src/arch/i686.S
@@ -2,26 +2,25 @@
 
 .text
 
-GLOBAL(__stacker_black_box):
-    ret
-
 GLOBAL(__stacker_stack_pointer):
     mov %esp, %eax
     ret
 
-#if defined(WINDOWS)
-GLOBAL(__stacker_get_tib_32):
-    mov %fs:0x18, %eax
-    ret
-#endif
-
 GLOBAL(__stacker_switch_stacks):
+    // CFI instructions tell the unwinder how to unwind this function.
+    // This enables unwinding through our extended stacks and also
+    // backtraces.
+    .cfi_startproc
    push %ebp
+    .cfi_def_cfa_offset 8       // restore esp by adding 8
+    .cfi_offset ebp, -8         // restore ebp from the stack
    mov %esp, %ebp
-    mov 8(%ebp), %esp   // switch to our new stack
+    .cfi_def_cfa_register ebp   // restore esp from ebp
+    mov 16(%ebp), %esp  // switch to our new stack
    mov 12(%ebp), %eax  // load function we're going to call
-    push 16(%ebp)       // push argument to first function
+    push 8(%ebp)        // push argument to first function
    call *%eax          // call our function pointer
    mov %ebp, %esp      // restore the old stack pointer
    pop %ebp
    ret
+    .cfi_endproc
diff --git a/src/arch/i686.asm b/src/arch/i686.asm
index 0bb1333..425d42c 100644
--- a/src/arch/i686.asm
+++ b/src/arch/i686.asm
@@ -2,32 +2,9 @@
 .MODEL FLAT, C
 .CODE
 
-__stacker_black_box PROC
-    RET
-__stacker_black_box ENDP
-
 __stacker_stack_pointer PROC
     MOV EAX, ESP
     RET
 __stacker_stack_pointer ENDP
 
-__stacker_get_tib_32 PROC
-    ASSUME FS:NOTHING
-    MOV EAX, FS:[24]
-    ASSUME FS:ERROR
-    RET
-__stacker_get_tib_32 ENDP
-
-__stacker_switch_stacks PROC
-    PUSH EBP
-    MOV EBP, ESP
-    MOV ESP, [EBP + 8]   ; switch stacks
-    MOV EAX, [EBP + 12]  ; load the function we're going to call
-    PUSH [EBP + 16]      ; push the argument to this function
-    CALL EAX             ; call the next function
-    MOV ESP, EBP         ; restore the old stack pointer
-    POP EBP
-    RET
-__stacker_switch_stacks ENDP
-
 END
diff --git a/src/arch/windows.c b/src/arch/windows.c
new file mode 100644
index 0000000..89485a0
--- /dev/null
+++ b/src/arch/windows.c
@@ -0,0 +1,5 @@
+#include <windows.h>
+
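+// GetCurrentFiber() is defined as an inline function in the Windows headers,
+// so there is no symbol for Rust to link against; this C shim exports one.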
+PVOID __stacker_get_current_fiber() {
+    return GetCurrentFiber();
+}
diff --git a/src/arch/x86_64.S b/src/arch/x86_64.S
index cbdf016..598efa1 100644
--- a/src/arch/x86_64.S
+++ b/src/arch/x86_64.S
@@ -2,35 +2,23 @@
 
 .text
 
-GLOBAL(__stacker_black_box):
-    ret
-
 GLOBAL(__stacker_stack_pointer):
     movq %rsp, %rax
     ret
 
-#if defined(WINDOWS)
-#define ARG1 %rcx
-#define ARG2 %rdx
-#define ARG3 %r8
-#else
-#define ARG1 %rdi
-#define ARG2 %rsi
-#define ARG3 %rdx
-#endif
-
-#if defined(WINDOWS)
-GLOBAL(__stacker_get_tib_64):
-    mov %gs:0x30, %rax
-    ret
-#endif
-
 GLOBAL(__stacker_switch_stacks):
+    // CFI instructions tell the unwinder how to unwind this function.
+    // This enables unwinding through our extended stacks and also
+    // backtraces.
+    .cfi_startproc
    push %rbp
+    .cfi_def_cfa_offset 16      // restore rsp by adding 16
+    .cfi_offset rbp, -16        // restore rbp from the stack
    mov %rsp, %rbp
-    mov ARG1, %rsp      // switch to our new stack
-    mov ARG3, ARG1      // move the data pointer to the first argument
-    call *ARG2          // call our function pointer
+    .cfi_def_cfa_register rbp   // restore rsp from rbp
+    mov %rdx, %rsp      // switch to our new stack
+    call *%rsi          // call our function pointer, data argument in %rdi
    mov %rbp, %rsp      // restore the old stack pointer
    pop %rbp
    ret
+    .cfi_endproc
diff --git a/src/arch/x86_64.asm b/src/arch/x86_64.asm
index c14696c..ad5f470 100644
--- a/src/arch/x86_64.asm
+++ b/src/arch/x86_64.asm
@@ -1,23 +1,8 @@
 _text SEGMENT
 
-__stacker_black_box PROC
-    RET
-__stacker_black_box ENDP
-
 __stacker_stack_pointer PROC
     MOV RAX, RSP
     RET
 __stacker_stack_pointer ENDP
 
-__stacker_switch_stacks PROC
-    PUSH RBP
-    MOV RBP, RSP
-    MOV RSP, RCX    ; switch to our new stack
-    MOV RCX, R8     ; move the data pointer to the first argument
-    CALL RDX        ; call our function pointer
-    MOV RSP, RBP    ; restore the old stack pointer
-    POP RBP
-    RET
-__stacker_switch_stacks ENDP
-
-END
+END
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 9420627..acff7b9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -27,30 +27,11 @@
 #[macro_use]
 extern crate cfg_if;
 extern crate libc;
+#[cfg(windows)]
+extern crate winapi;
 
 use std::cell::Cell;
 
-extern {
-    fn __stacker_stack_pointer() -> usize;
-    fn __stacker_switch_stacks(new_stack: usize,
-                               fnptr: *const u8,
-                               dataptr: *mut u8);
-}
-
-thread_local! {
-    static STACK_LIMIT: Cell<usize> = Cell::new(unsafe {
-        guess_os_stack_limit()
-    })
-}
-
-fn get_stack_limit() -> usize {
-    STACK_LIMIT.with(|s| s.get())
-}
-
-fn set_stack_limit(l: usize) {
-    STACK_LIMIT.with(|s| s.set(l))
-}
-
 /// Grows the call stack if necessary.
 ///
 /// This function is intended to be called at manually instrumented points in a
@@ -60,103 +41,308 @@ fn set_stack_limit(l: usize) {
 ///
 /// The closure `f` is guaranteed to run on a stack with at least `red_zone`
 /// bytes, and it will be run on the current stack if there's space available.
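+///
+/// # Example
+///
+/// ```
+/// // A minimal sketch: grow if we are within 32 KiB of the stack limit,
+/// // allocating at least 1 MiB of new stack if we do grow.
+/// stacker::maybe_grow(32 * 1024, 1024 * 1024, || {
+///     // guaranteed to have at least 32 KiB of stack
+/// });
+/// ```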
-pub fn maybe_grow<R, F: FnOnce() -> R>(red_zone: usize,
-                                       stack_size: usize,
-                                       f: F) -> R {
-    if remaining_stack() >= red_zone {
+#[inline(always)]
+pub fn maybe_grow<R, F: FnOnce() -> R>(
+    red_zone: usize,
+    stack_size: usize,
+    f: F,
+) -> R {
+    // If we can't guess the remaining stack (unsupported on some platforms),
+    // we immediately grow the stack and then cache the new stack size (which
+    // we do know now because we know by how much we grew the stack).
+    if remaining_stack().map_or(false, |remaining| remaining >= red_zone) {
         f()
     } else {
-        grow_the_stack(stack_size, f)
+        grow(stack_size, f)
     }
 }
 
+extern {
+    fn __stacker_stack_pointer() -> usize;
+}
+
 /// Queries the amount of remaining stack as interpreted by this library.
 ///
 /// This function will return the amount of stack space left which will be used
 /// to determine whether a stack switch should be made or not.
-pub fn remaining_stack() -> usize {
-    unsafe {
-        __stacker_stack_pointer() - get_stack_limit()
-    }
+#[inline(always)]
+pub fn remaining_stack() -> Option<usize> {
+    get_stack_limit().map(|limit| unsafe { __stacker_stack_pointer() - limit })
 }
 
+/// Always creates a new stack for the passed closure to run on.
+/// The closure will still be on the same thread as the caller of `grow`.
+/// This will allocate a new stack with at least `stack_size` bytes.
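+///
+/// # Example
+///
+/// ```
+/// // A minimal sketch: the closure runs on a freshly allocated stack, so a
+/// // large local array is safe even when the caller is deep in recursion.
+/// stacker::grow(1024 * 1024, || {
+///     let _buf = [0u8; 64 * 1024];
+/// });
+/// ```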
 #[inline(never)]
-fn grow_the_stack<R, F: FnOnce() -> R>(stack_size: usize, f: F) -> R {
+pub fn grow<R, F: FnOnce() -> R>(stack_size: usize, f: F) -> R {
     let mut f = Some(f);
     let mut ret = None;
-    unsafe {
-        _grow_the_stack(stack_size, &mut || {
-            ret = Some(f.take().unwrap()());
-        });
-    }
+    _grow(stack_size, &mut || {
+        ret = Some(f.take().unwrap()());
+    });
     ret.unwrap()
 }
 
-unsafe fn _grow_the_stack(stack_size: usize, mut f: &mut FnMut()) {
-    // Align to 16-bytes (see below for why)
-    let stack_size = (stack_size + 15) / 16 * 16;
+thread_local! {
+    static STACK_LIMIT: Cell<Option<usize>> = Cell::new(unsafe {
+        guess_os_stack_limit()
+    })
+}
 
-    // Allocate some new stack for oureslves
-    let mut stack = Vec::<u8>::with_capacity(stack_size);
-    let new_limit = stack.as_ptr() as usize + 32 * 1024;
+#[inline(always)]
+fn get_stack_limit() -> Option<usize> {
+    STACK_LIMIT.with(|s| s.get())
+}
 
-    // Save off the old stack limits
-    let old_limit = get_stack_limit();
+fn set_stack_limit(l: Option<usize>) {
+    STACK_LIMIT.with(|s| s.set(l))
+}
 
-    // Prepare stack limits for the stack switch
-    set_stack_limit(new_limit);
+cfg_if! {
+    if #[cfg(not(windows))] {
+        extern {
+            fn __stacker_switch_stacks(dataptr: *mut u8,
+                                       fnptr: *const u8,
+                                       new_stack: usize);
+            fn getpagesize() -> libc::c_int;
+        }
 
-    // Make sure the stack is 16-byte aligned which should be enough for all
-    // platforms right now. Allocations on 64-bit are already 16-byte aligned
-    // and our switching routine doesn't push any other data, but the routine on
-    // 32-bit pushes an argument so we need a bit of an offset to get it 16-byte
-    // aligned when the call is made.
-    let offset = if cfg!(target_pointer_width = "32") {
-        12
-    } else {
-        0
-    };
-    __stacker_switch_stacks(stack.as_mut_ptr() as usize + stack_size - offset,
-                            doit as usize as *const _,
-                            &mut f as *mut &mut FnMut() as *mut u8);
-
-    // Once we've returned reset bothe stack limits and then return value same
-    // value the closure returned.
-    set_stack_limit(old_limit);
-
-    unsafe extern fn doit(f: &mut &mut FnMut()) {
-        f();
+        struct StackSwitch {
+            map: *mut libc::c_void,
+            stack_size: usize,
+            old_stack_limit: Option<usize>,
+        }
+
+        impl Drop for StackSwitch {
+            fn drop(&mut self) {
+                unsafe {
+                    libc::munmap(self.map, self.stack_size);
+                }
+                set_stack_limit(self.old_stack_limit);
+            }
+        }
+
+        fn _grow(stack_size: usize, mut f: &mut FnMut()) {
+            let page_size = unsafe { getpagesize() } as usize;
+
+            // Round the stack size up to a multiple of page_size
+            let rem = stack_size % page_size;
+            let stack_size = if rem == 0 {
+                stack_size
+            } else {
+                stack_size.checked_add(page_size - rem)
+                    .expect("stack size calculation overflowed")
+            };
+
+            // We need at least 2 pages
+            let stack_size = std::cmp::max(stack_size, page_size);
+
+            // Add a guard page
+            let stack_size = stack_size.checked_add(page_size)
+                .expect("stack size calculation overflowed");
+
+            // Allocate some new stack for ourselves
+            let map = unsafe {
+                libc::mmap(std::ptr::null_mut(),
+                           stack_size,
+                           libc::PROT_NONE,
+                           libc::MAP_PRIVATE |
+                           libc::MAP_ANON,
+                           0,
+                           0)
+            };
+            if map == -1isize as _ {
+                panic!("unable to allocate stack")
+            }
+            let _switch = StackSwitch {
+                map,
+                stack_size,
+                old_stack_limit: get_stack_limit(),
+            };
+            let result = unsafe {
+                libc::mprotect((map as usize + page_size) as *mut libc::c_void,
+                               stack_size - page_size,
+                               libc::PROT_READ | libc::PROT_WRITE)
+            };
+            if result == -1 {
+                panic!("unable to set stack permissions")
+            }
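+            // The mapping is now laid out with an inaccessible guard page at
+            // its low end, so running off the new stack faults cleanly instead
+            // of silently corrupting adjacent memory:
+            //
+            //   map .. map+page_size             guard page (PROT_NONE)
+            //   map+page_size .. map+stack_size  usable stack (read/write)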
+            let stack_low = map as usize;
+
+            // Prepare stack limits for the stack switch
+            set_stack_limit(Some(stack_low));
+
+            // Make sure the stack is 16-byte aligned which should be enough for all
+            // platforms right now. Allocations on 64-bit are already 16-byte aligned
+            // and our switching routine doesn't push any other data, but the routine on
+            // 32-bit pushes an argument so we need a bit of an offset to get it 16-byte
+            // aligned when the call is made.
+            let offset = if cfg!(target_pointer_width = "32") {
+                12
+            } else {
+                0
+            };
+
+            extern fn doit(f: &mut &mut FnMut()) {
+                f();
+            }
+
+            unsafe {
+                __stacker_switch_stacks(&mut f as *mut &mut FnMut() as *mut u8,
+                                        doit as usize as *const _,
+                                        stack_low + stack_size - offset);
+            }
+
+            // Dropping `_switch` frees the memory mapping and restores the old stack limit
+        }
     }
 }
 
 cfg_if! {
     if #[cfg(windows)] {
-        // See this for where all this logic is coming from.
-        //
-        // https://github.com/adobe/webkit/blob/0441266/Source/WTF/wtf
-        // /StackBounds.cpp
-        unsafe fn guess_os_stack_limit() -> usize {
-            #[cfg(target_pointer_width = "32")]
-            extern {
-                #[link_name = "__stacker_get_tib_32"]
-                fn get_tib_address() -> *const usize;
-            }
-            #[cfg(target_pointer_width = "64")]
-            extern "system" {
-                #[cfg_attr(target_env = "msvc", link_name = "NtCurrentTeb")]
-                #[cfg_attr(target_env = "gnu", link_name = "__stacker_get_tib_64")]
-                fn get_tib_address() -> *const usize;
-            }
+        use std::ptr;
+        use std::io;
+
+        use winapi::shared::basetsd::*;
+        use winapi::shared::minwindef::{LPVOID, BOOL};
+        use winapi::shared::ntdef::*;
+        use winapi::um::fibersapi::*;
+        use winapi::um::memoryapi::*;
+        use winapi::um::processthreadsapi::*;
+        use winapi::um::winbase::*;
+
+        extern {
+            fn __stacker_get_current_fiber() -> PVOID;
+        }
+
+        struct FiberInfo<'a> {
+            callback: &'a mut FnMut(),
+            result: Option<std::thread::Result<()>>,
+            parent_fiber: LPVOID,
+        }
+
+        unsafe extern "system" fn fiber_proc(info: LPVOID) {
+            let info = &mut *(info as *mut FiberInfo);
+
+            // Remember the old stack limit
+            let old_stack_limit = get_stack_limit();
+            // Update the limit to that of the fiber stack
+            set_stack_limit(guess_os_stack_limit());
+
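+            // Run the callback under catch_unwind: letting a panic unwind out
+            // of an `extern "system"` entry point would be undefined behavior,
+            // so any panic is captured here and resumed on the parent fiber.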
+            info.result = Some(std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+                (info.callback)();
+            })));
+
+            // Restore the stack limit of the previous fiber
+            set_stack_limit(old_stack_limit);
+
+            SwitchToFiber(info.parent_fiber);
+            return;
+        }
+
+        fn _grow(stack_size: usize, callback: &mut FnMut()) {
+            unsafe {
+                // Fibers (or stackful coroutines) are the only way to create new stacks on
+                // the same thread on Windows. So in order to extend the stack we create a
+                // fiber and switch to it so we can use its stack. After running
+                // `callback` we switch back to the current stack and destroy
+                // the fiber and its associated stack.
+
+                let was_fiber = IsThreadAFiber() == TRUE as BOOL;
+
+                let mut info = FiberInfo {
+                    callback,
+                    result: None,
+
+                    // We need a handle to the current stack / fiber so we can switch back to it
+                    parent_fiber: {
+                        // Is the current thread already a fiber? This is the case when we already
+                        // used a fiber to extend the stack
+                        if was_fiber {
+                            // Get a handle to the current fiber. We need to use C for this
+                            // as GetCurrentFiber is a header-only function.
+                            __stacker_get_current_fiber()
+                        } else {
+                            // Convert the current thread to a fiber, so we are able to switch back
+                            // to the current stack. Threads converted to fibers still act like
+                            // regular threads, but they have associated fiber data. We later
+                            // convert it back to a regular thread and free the fiber data.
+                            ConvertThreadToFiber(ptr::null_mut())
+                        }
+                    },
+                };
+                if info.parent_fiber.is_null() {
+                    // We don't have a handle to the fiber, so we can't switch back
+                    panic!("unable to convert thread to fiber: {}", io::Error::last_os_error());
+                }
+
+                let fiber = CreateFiber(
+                    stack_size as SIZE_T,
+                    Some(fiber_proc),
+                    &mut info as *mut FiberInfo as *mut _,
+                );
+                if fiber.is_null() {
+                    panic!("unable to allocate fiber: {}", io::Error::last_os_error());
+                }
+
+                // Switch to the fiber we created. This changes stacks and starts executing
+                // fiber_proc on it. fiber_proc will run `callback` and then switch back
+                SwitchToFiber(fiber);
+
+                // We are back on the old stack and can now destroy the fiber and its stack
+                DeleteFiber(fiber);
+
+                // If we started out on a non-fiber thread, we converted that thread to a fiber.
+                // Here we convert back.
+                if !was_fiber {
+                    if ConvertFiberToThread() == 0 {
+                        panic!("unable to convert back to thread: {}", io::Error::last_os_error());
+                    }
+                }
+
+                if let Err(payload) = info.result.unwrap() {
+                    std::panic::resume_unwind(payload);
+                }
+            }
-            // https://en.wikipedia.org/wiki/Win32_Thread_Information_Block for
-            // the struct layout of the 32-bit TIB. It looks like the struct
-            // layout of the 64-bit TIB is also the same for getting the stack
-            // limit: http://doxygen.reactos.org/d3/db0/structNT__TIB64.html
-            *get_tib_address().offset(2)
         }
+
+        #[inline(always)]
+        fn get_thread_stack_guarantee() -> usize {
+            let min_guarantee = if cfg!(target_pointer_width = "32") {
+                0x1000
+            } else {
+                0x2000
+            };
+            let mut stack_guarantee = 0;
+            unsafe {
+                // Read the current thread stack guarantee.
+                // This is the stack reserved for stack overflow
+                // exception handling.
+                // This doesn't return the true value so we need
+                // some further logic to calculate the real stack
+                // guarantee. This logic is what is used on x86-32 and
+                // x86-64 Windows 10. Other versions and platforms may differ.
+                SetThreadStackGuarantee(&mut stack_guarantee)
+            };
+            std::cmp::max(stack_guarantee, min_guarantee) as usize + 0x1000
+        }
+
+        #[inline(always)]
+        unsafe fn guess_os_stack_limit() -> Option<usize> {
+            let mut mi = std::mem::zeroed();
+            // Query the allocation which contains our stack pointer in order
+            // to discover the size of the stack
+            VirtualQuery(
+                __stacker_stack_pointer() as *const _,
+                &mut mi,
+                std::mem::size_of_val(&mi) as SIZE_T,
+            );
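+            // AllocationBase is the low end of the region reserved for this
+            // stack; the limit we report leaves room above it for the pages
+            // Windows keeps in reserve for stack-overflow handling.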
+            Some(mi.AllocationBase as usize + get_thread_stack_guarantee() + 0x1000)
+        }
     } else if #[cfg(target_os = "linux")] {
         use std::mem;
 
-        unsafe fn guess_os_stack_limit() -> usize {
+        unsafe fn guess_os_stack_limit() -> Option<usize> {
             let mut attr: libc::pthread_attr_t = mem::zeroed();
             assert_eq!(libc::pthread_attr_init(&mut attr), 0);
             assert_eq!(libc::pthread_getattr_np(libc::pthread_self(),
@@ -166,18 +352,19 @@ cfg_if! {
             assert_eq!(libc::pthread_attr_getstack(&attr, &mut stackaddr,
                                                    &mut stacksize), 0);
             assert_eq!(libc::pthread_attr_destroy(&mut attr), 0);
-            stackaddr as usize
+            Some(stackaddr as usize)
         }
     } else if #[cfg(target_os = "macos")] {
-        use libc::{c_void, pthread_t, size_t};
-
-        unsafe fn guess_os_stack_limit() -> usize {
-            libc::pthread_get_stackaddr_np(libc::pthread_self()) as usize -
-                libc::pthread_get_stacksize_np(libc::pthread_self()) as usize
+        unsafe fn guess_os_stack_limit() -> Option<usize> {
+            Some(libc::pthread_get_stackaddr_np(libc::pthread_self()) as usize -
+                 libc::pthread_get_stacksize_np(libc::pthread_self()) as usize)
         }
     } else {
-        unsafe fn guess_os_stack_limit() -> usize {
-            panic!("cannot guess the stack limit on this platform");
+        // The fallback for other platforms is to always grow the stack if we're on
+        // the root stack. After we have grown the stack once, we know the new stack
+        // size and don't need this pessimization anymore.
+        unsafe fn guess_os_stack_limit() -> Option<usize> {
+            None
         }
     }
 }
diff --git a/tests/simple.rs b/tests/simple.rs
new file mode 100644
index 0000000..4e4c46f
--- /dev/null
+++ b/tests/simple.rs
@@ -0,0 +1,24 @@
+extern crate stacker;
+
+const RED_ZONE: usize = 100*1024; // 100k
+const STACK_PER_RECURSION: usize = 1 * 1024 * 1024; // 1MB
+
+pub fn ensure_sufficient_stack<R, F: FnOnce() -> R + std::panic::UnwindSafe>(
+    f: F
+) -> R {
+    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
+}
+
+#[inline(never)]
+fn recurse(n: usize) {
+    let x = [42u8; 50000];
+    if n != 0 {
+        ensure_sufficient_stack(|| recurse(n - 1));
+    }
+    drop(x);
+}
+
+#[test]
+fn foo() {
+    recurse(10000);
+}
\ No newline at end of file
diff --git a/tests/smoke.rs b/tests/smoke.rs
index ea42149..41f41c7 100644
--- a/tests/smoke.rs
+++ b/tests/smoke.rs
@@ -3,19 +3,18 @@ extern crate stacker;
 use std::sync::mpsc;
 use std::thread;
 
-extern {
-    fn __stacker_black_box(t: *const u8);
-}
+#[inline(never)]
+fn __stacker_black_box(_: *const u8) {}
 
 #[test]
 fn deep() {
     fn foo(n: usize, s: &mut [u8]) {
-        unsafe { __stacker_black_box(s.as_ptr()); }
+        __stacker_black_box(s.as_ptr());
         if n > 0 {
             stacker::maybe_grow(64 * 1024, 1024 * 1024, || {
                 let mut s = [0u8; 1024];
                 foo(n - 1, &mut s);
-                unsafe { __stacker_black_box(s.as_ptr()); }
+                __stacker_black_box(s.as_ptr());
             })
         } else {
             println!("bottom");
@@ -26,15 +25,14 @@ fn deep() {
 }
 
 #[test]
-#[ignore]
 fn panic() {
     fn foo(n: usize, s: &mut [u8]) {
-        unsafe { __stacker_black_box(s.as_ptr()); }
+        __stacker_black_box(s.as_ptr());
         if n > 0 {
             stacker::maybe_grow(64 * 1024, 1024 * 1024, || {
                 let mut s = [0u8; 1024];
                 foo(n - 1, &mut s);
-                unsafe { __stacker_black_box(s.as_ptr()); }
+                __stacker_black_box(s.as_ptr());
             })
         } else {
             panic!("bottom");
@@ -49,3 +47,40 @@ fn panic() {
 
     assert!(rx.recv().is_err());
 }
+
+fn recursive<F: FnOnce()>(n: usize, f: F) -> usize {
+    if n > 0 {
+        stacker::grow(64 * 1024, || {
+            recursive(n - 1, f) + 1
+        })
+    } else {
+        f();
+        0
+    }
+}
+
+#[test]
+fn catch_panic() {
+    let panic_result = std::panic::catch_unwind(|| {
+        recursive(100, || panic!());
+    });
+    assert!(panic_result.is_err());
+}
+
+#[test]
+fn catch_panic_inside() {
+    stacker::grow(64 * 1024, || {
+        let panic_result = std::panic::catch_unwind(|| {
+            recursive(100, || panic!());
+        });
+        assert!(panic_result.is_err());
+    });
+}
+
+#[test]
+fn catch_panic_leaf() {
+    stacker::grow(64 * 1024, || {
+        let panic_result = std::panic::catch_unwind(|| panic!());
+        assert!(panic_result.is_err());
+    });
+}
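+
+// A minimal extra check (an illustrative addition, not part of the original
+// change): `grow` and `maybe_grow` compose, with the inner call running on
+// the already-grown stack.
+#[test]
+fn grow_then_maybe_grow() {
+    stacker::grow(64 * 1024, || {
+        stacker::maybe_grow(16 * 1024, 64 * 1024, || {});
+    });
+}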