From 5ca8a735ca36219abbf601624606c41148b95210 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 13 Mar 2017 22:27:12 -0700 Subject: [PATCH] std: Don't cache stdio handles on Windows This alters the stdio code on Windows to always call `GetStdHandle` whenever the stdio read/write functions are called as this allows us to track changes to the value over time (such as if a process calls `SetStdHandle` while it's running). Closes #40490 --- src/libstd/sys/windows/process.rs | 9 +- src/libstd/sys/windows/stdio.rs | 92 +++++++++------------ src/test/run-pass-fulldeps/switch-stdout.rs | 64 ++++++++++++++ 3 files changed, 110 insertions(+), 55 deletions(-) create mode 100644 src/test/run-pass-fulldeps/switch-stdout.rs diff --git a/src/libstd/sys/windows/process.rs b/src/libstd/sys/windows/process.rs index 1afb3728c9d72..dfbc1b581ee55 100644 --- a/src/libstd/sys/windows/process.rs +++ b/src/libstd/sys/windows/process.rs @@ -257,8 +257,13 @@ impl Stdio { // INVALID_HANDLE_VALUE. Stdio::Inherit => { match stdio::get(stdio_id) { - Ok(io) => io.handle().duplicate(0, true, - c::DUPLICATE_SAME_ACCESS), + Ok(io) => { + let io = Handle::new(io.handle()); + let ret = io.duplicate(0, true, + c::DUPLICATE_SAME_ACCESS); + io.into_raw(); + return ret + } Err(..) => Ok(Handle::new(c::INVALID_HANDLE_VALUE)), } } diff --git a/src/libstd/sys/windows/stdio.rs b/src/libstd/sys/windows/stdio.rs index b1a57c349fbb9..d72e4b4438b7b 100644 --- a/src/libstd/sys/windows/stdio.rs +++ b/src/libstd/sys/windows/stdio.rs @@ -22,42 +22,43 @@ use sys::cvt; use sys::handle::Handle; use sys_common::io::read_to_end_uninitialized; -pub struct NoClose(Option); - pub enum Output { - Console(NoClose), - Pipe(NoClose), + Console(c::HANDLE), + Pipe(c::HANDLE), } pub struct Stdin { - handle: Output, utf8: Mutex>>, } -pub struct Stdout(Output); -pub struct Stderr(Output); +pub struct Stdout; +pub struct Stderr; pub fn get(handle: c::DWORD) -> io::Result { let handle = unsafe { c::GetStdHandle(handle) }; if handle == c::INVALID_HANDLE_VALUE { Err(io::Error::last_os_error()) } else if handle.is_null() { - Err(io::Error::new(io::ErrorKind::Other, - "no stdio handle available for this process")) + Err(io::Error::from_raw_os_error(c::ERROR_INVALID_HANDLE as i32)) } else { - let ret = NoClose::new(handle); let mut out = 0; match unsafe { c::GetConsoleMode(handle, &mut out) } { - 0 => Ok(Output::Pipe(ret)), - _ => Ok(Output::Console(ret)), + 0 => Ok(Output::Pipe(handle)), + _ => Ok(Output::Console(handle)), } } } -fn write(out: &Output, data: &[u8]) -> io::Result { - let handle = match *out { - Output::Console(ref c) => c.get().raw(), - Output::Pipe(ref p) => return p.get().write(data), +fn write(handle: c::DWORD, data: &[u8]) -> io::Result { + let handle = match try!(get(handle)) { + Output::Console(c) => c, + Output::Pipe(p) => { + let handle = Handle::new(p); + let ret = handle.write(data); + handle.into_raw(); + return ret + } }; + // As with stdin on windows, stdout often can't handle writes of large // sizes. For an example, see #14940. For this reason, don't try to // write the entire output buffer on windows. @@ -93,18 +94,20 @@ fn write(out: &Output, data: &[u8]) -> io::Result { impl Stdin { pub fn new() -> io::Result { - get(c::STD_INPUT_HANDLE).map(|handle| { - Stdin { - handle: handle, - utf8: Mutex::new(Cursor::new(Vec::new())), - } + Ok(Stdin { + utf8: Mutex::new(Cursor::new(Vec::new())), }) } pub fn read(&self, buf: &mut [u8]) -> io::Result { - let handle = match self.handle { - Output::Console(ref c) => c.get().raw(), - Output::Pipe(ref p) => return p.get().read(buf), + let handle = match try!(get(c::STD_INPUT_HANDLE)) { + Output::Console(c) => c, + Output::Pipe(p) => { + let handle = Handle::new(p); + let ret = handle.read(buf); + handle.into_raw(); + return ret + } }; let mut utf8 = self.utf8.lock().unwrap(); // Read more if the buffer is empty @@ -125,11 +128,9 @@ impl Stdin { Ok(utf8) => utf8.into_bytes(), Err(..) => return Err(invalid_encoding()), }; - if let Output::Console(_) = self.handle { - if let Some(&last_byte) = data.last() { - if last_byte == CTRL_Z { - data.pop(); - } + if let Some(&last_byte) = data.last() { + if last_byte == CTRL_Z { + data.pop(); } } *utf8 = Cursor::new(data); @@ -158,11 +159,11 @@ impl<'a> Read for &'a Stdin { impl Stdout { pub fn new() -> io::Result { - get(c::STD_OUTPUT_HANDLE).map(Stdout) + Ok(Stdout) } pub fn write(&self, data: &[u8]) -> io::Result { - write(&self.0, data) + write(c::STD_OUTPUT_HANDLE, data) } pub fn flush(&self) -> io::Result<()> { @@ -172,11 +173,11 @@ impl Stdout { impl Stderr { pub fn new() -> io::Result { - get(c::STD_ERROR_HANDLE).map(Stderr) + Ok(Stderr) } pub fn write(&self, data: &[u8]) -> io::Result { - write(&self.0, data) + write(c::STD_ERROR_HANDLE, data) } pub fn flush(&self) -> io::Result<()> { @@ -197,27 +198,12 @@ impl io::Write for Stderr { } } -impl NoClose { - fn new(handle: c::HANDLE) -> NoClose { - NoClose(Some(Handle::new(handle))) - } - - fn get(&self) -> &Handle { self.0.as_ref().unwrap() } -} - -impl Drop for NoClose { - fn drop(&mut self) { - self.0.take().unwrap().into_raw(); - } -} - impl Output { - pub fn handle(&self) -> &Handle { - let nc = match *self { - Output::Console(ref c) => c, - Output::Pipe(ref c) => c, - }; - nc.0.as_ref().unwrap() + pub fn handle(&self) -> c::HANDLE { + match *self { + Output::Console(c) => c, + Output::Pipe(c) => c, + } } } diff --git a/src/test/run-pass-fulldeps/switch-stdout.rs b/src/test/run-pass-fulldeps/switch-stdout.rs new file mode 100644 index 0000000000000..4542e27545a4c --- /dev/null +++ b/src/test/run-pass-fulldeps/switch-stdout.rs @@ -0,0 +1,64 @@ +// Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![feature(rustc_private)] + +extern crate rustc_back; + +use std::fs::File; +use std::io::{Read, Write}; + +use rustc_back::tempdir::TempDir; + +#[cfg(unix)] +fn switch_stdout_to(file: File) { + use std::os::unix::prelude::*; + + extern { + fn dup2(old: i32, new: i32) -> i32; + } + + unsafe { + assert_eq!(dup2(file.as_raw_fd(), 1), 1); + } +} + +#[cfg(windows)] +fn switch_stdout_to(file: File) { + use std::os::windows::prelude::*; + + extern "system" { + fn SetStdHandle(nStdHandle: u32, handle: *mut u8) -> i32; + } + + const STD_OUTPUT_HANDLE: u32 = (-11i32) as u32; + + unsafe { + let rc = SetStdHandle(STD_OUTPUT_HANDLE, + file.into_raw_handle() as *mut _); + assert!(rc != 0); + } +} + +fn main() { + let td = TempDir::new("foo").unwrap(); + let path = td.path().join("bar"); + let f = File::create(&path).unwrap(); + + println!("foo"); + std::io::stdout().flush().unwrap(); + switch_stdout_to(f); + println!("bar"); + std::io::stdout().flush().unwrap(); + + let mut contents = String::new(); + File::open(&path).unwrap().read_to_string(&mut contents).unwrap(); + assert_eq!(contents, "bar\n"); +}