Skip to content

Commit

Permalink
Rollup merge of #69033 - jonas-schievink:resume-with-context, r=tmandry
Browse files Browse the repository at this point in the history
Use generator resume arguments in the async/await lowering

This removes the TLS requirement from async/await and enables it in `#![no_std]` crates.

Closes #56974

I'm not confident the HIR lowering is completely correct, there seem to be quite a few undocumented invariants in there. The `async-std` and tokio test suites are passing with these changes though.
  • Loading branch information
Centril authored Mar 21, 2020
2 parents 4b91729 + db0126a commit ef7c8a1
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 36 deletions.
78 changes: 78 additions & 0 deletions src/libcore/future/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,84 @@

//! Asynchronous values.
#[cfg(not(bootstrap))]
use crate::{
ops::{Generator, GeneratorState},
pin::Pin,
ptr::NonNull,
task::{Context, Poll},
};

mod future;
#[stable(feature = "futures_api", since = "1.36.0")]
pub use self::future::Future;

/// This type is needed because:
///
/// a) Generators cannot implement `for<'a, 'b> Generator<&'a mut Context<'b>>`, so we need to pass
/// a raw pointer (see https://github.com/rust-lang/rust/issues/68923).
/// b) Raw pointers and `NonNull` aren't `Send` or `Sync`, so that would make every single future
/// non-Send/Sync as well, and we don't want that.
///
/// It also simplifies the HIR lowering of `.await`.
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
#[derive(Debug, Copy, Clone)]
pub struct ResumeTy(NonNull<Context<'static>>);

#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
unsafe impl Send for ResumeTy {}

#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
unsafe impl Sync for ResumeTy {}

/// Wrap a generator in a future.
///
/// This function returns a `GenFuture` underneath, but hides it in `impl Trait` to give
/// better error messages (`impl Future` rather than `GenFuture<[closure.....]>`).
// This is `const` to avoid extra errors after we recover from `const async fn`
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
#[inline]
pub const fn from_generator<T>(gen: T) -> impl Future<Output = T::Return>
where
T: Generator<ResumeTy, Yield = ()>,
{
struct GenFuture<T: Generator<ResumeTy, Yield = ()>>(T);

// We rely on the fact that async/await futures are immovable in order to create
// self-referential borrows in the underlying generator.
impl<T: Generator<ResumeTy, Yield = ()>> !Unpin for GenFuture<T> {}

impl<T: Generator<ResumeTy, Yield = ()>> Future for GenFuture<T> {
type Output = T::Return;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Safety: Safe because we're !Unpin + !Drop, and this is just a field projection.
let gen = unsafe { Pin::map_unchecked_mut(self, |s| &mut s.0) };

// Resume the generator, turning the `&mut Context` into a `NonNull` raw pointer. The
// `.await` lowering will safely cast that back to a `&mut Context`.
match gen.resume(ResumeTy(NonNull::from(cx).cast::<Context<'static>>())) {
GeneratorState::Yielded(()) => Poll::Pending,
GeneratorState::Complete(x) => Poll::Ready(x),
}
}
}

GenFuture(gen)
}

#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[cfg(not(bootstrap))]
#[inline]
pub unsafe fn poll_with_context<F>(f: Pin<&mut F>, mut cx: ResumeTy) -> Poll<F::Output>
where
F: Future,
{
F::poll(f, cx.0.as_mut())
}
101 changes: 79 additions & 22 deletions src/librustc_ast_lowering/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
}

/// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
///
/// This results in:
///
/// ```text
/// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
/// <body>
/// })
/// ```
pub(super) fn make_async_expr(
&mut self,
capture_clause: CaptureBy,
Expand All @@ -480,17 +489,42 @@ impl<'hir> LoweringContext<'_, 'hir> {
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
let output = match ret_ty {
Some(ty) => FnRetTy::Ty(ty),
None => FnRetTy::Default(span),
Some(ty) => hir::FnRetTy::Return(self.lower_ty(&ty, ImplTraitContext::disallowed())),
None => hir::FnRetTy::DefaultReturn(span),
};
let ast_decl = FnDecl { inputs: vec![], output };
let decl = self.lower_fn_decl(&ast_decl, None, /* impl trait allowed */ false, None);
let body_id = self.lower_fn_body(&ast_decl, |this| {

// Resume argument type. We let the compiler infer this to simplify the lowering. It is
// fully constrained by `future::from_generator`.
let input_ty = hir::Ty { hir_id: self.next_id(), kind: hir::TyKind::Infer, span };

// The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
let decl = self.arena.alloc(hir::FnDecl {
inputs: arena_vec![self; input_ty],
output,
c_variadic: false,
implicit_self: hir::ImplicitSelfKind::None,
});

// Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
let (pat, task_context_hid) = self.pat_ident_binding_mode(
span,
Ident::with_dummy_span(sym::_task_context),
hir::BindingAnnotation::Mutable,
);
let param = hir::Param { attrs: &[], hir_id: self.next_id(), pat, span };
let params = arena_vec![self; param];

let body_id = self.lower_body(move |this| {
this.generator_kind = Some(hir::GeneratorKind::Async(async_gen_kind));
body(this)

let old_ctx = this.task_context;
this.task_context = Some(task_context_hid);
let res = body(this);
this.task_context = old_ctx;
(params, res)
});

// `static || -> <ret_ty> { body }`:
// `static |_task_context| -> <ret_ty> { body }`:
let generator_kind = hir::ExprKind::Closure(
capture_clause,
decl,
Expand Down Expand Up @@ -523,13 +557,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// ```rust
/// match <expr> {
/// mut pinned => loop {
/// match ::std::future::poll_with_tls_context(unsafe {
/// <::std::pin::Pin>::new_unchecked(&mut pinned)
/// }) {
/// match unsafe { ::std::future::poll_with_context(
/// <::std::pin::Pin>::new_unchecked(&mut pinned),
/// task_context,
/// ) } {
/// ::std::task::Poll::Ready(result) => break result,
/// ::std::task::Poll::Pending => {}
/// }
/// yield ();
/// task_context = yield ();
/// }
/// }
/// ```
Expand Down Expand Up @@ -561,12 +596,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
let (pinned_pat, pinned_pat_hid) =
self.pat_ident_binding_mode(span, pinned_ident, hir::BindingAnnotation::Mutable);

// ::std::future::poll_with_tls_context(unsafe {
// ::std::pin::Pin::new_unchecked(&mut pinned)
// })`
let task_context_ident = Ident::with_dummy_span(sym::_task_context);

// unsafe {
// ::std::future::poll_with_context(
// ::std::pin::Pin::new_unchecked(&mut pinned),
// task_context,
// )
// }
let poll_expr = {
let pinned = self.expr_ident(span, pinned_ident, pinned_pat_hid);
let ref_mut_pinned = self.expr_mut_addr_of(span, pinned);
let task_context = if let Some(task_context_hid) = self.task_context {
self.expr_ident_mut(span, task_context_ident, task_context_hid)
} else {
// Use of `await` outside of an async context, we cannot use `task_context` here.
self.expr_err(span)
};
let pin_ty_id = self.next_id();
let new_unchecked_expr_kind = self.expr_call_std_assoc_fn(
pin_ty_id,
Expand All @@ -575,14 +621,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
"new_unchecked",
arena_vec![self; ref_mut_pinned],
);
let new_unchecked =
self.arena.alloc(self.expr(span, new_unchecked_expr_kind, ThinVec::new()));
let unsafe_expr = self.expr_unsafe(new_unchecked);
self.expr_call_std_path(
let new_unchecked = self.expr(span, new_unchecked_expr_kind, ThinVec::new());
let call = self.expr_call_std_path(
gen_future_span,
&[sym::future, sym::poll_with_tls_context],
arena_vec![self; unsafe_expr],
)
&[sym::future, sym::poll_with_context],
arena_vec![self; new_unchecked, task_context],
);
self.arena.alloc(self.expr_unsafe(call))
};

// `::std::task::Poll::Ready(result) => break result`
Expand Down Expand Up @@ -622,14 +667,26 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.stmt_expr(span, match_expr)
};

// task_context = yield ();
let yield_stmt = {
let unit = self.expr_unit(span);
let yield_expr = self.expr(
span,
hir::ExprKind::Yield(unit, hir::YieldSource::Await),
ThinVec::new(),
);
self.stmt_expr(span, yield_expr)
let yield_expr = self.arena.alloc(yield_expr);

if let Some(task_context_hid) = self.task_context {
let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
let assign =
self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, span), AttrVec::new());
self.stmt_expr(span, assign)
} else {
// Use of `await` outside of an async context. Return `yield_expr` so that we can
// proceed with type checking.
self.stmt(span, hir::StmtKind::Semi(yield_expr))
}
};

let loop_block = self.block_all(span, arena_vec![self; inner_match_stmt, yield_stmt], None);
Expand Down
4 changes: 2 additions & 2 deletions src/librustc_ast_lowering/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
}

/// Construct `ExprKind::Err` for the given `span`.
fn expr_err(&mut self, span: Span) -> hir::Expr<'hir> {
crate fn expr_err(&mut self, span: Span) -> hir::Expr<'hir> {
self.expr(span, hir::ExprKind::Err, AttrVec::new())
}

Expand Down Expand Up @@ -960,7 +960,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
id
}

fn lower_body(
pub(super) fn lower_body(
&mut self,
f: impl FnOnce(&mut Self) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>),
) -> hir::BodyId {
Expand Down
5 changes: 5 additions & 0 deletions src/librustc_ast_lowering/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ struct LoweringContext<'a, 'hir: 'a> {

generator_kind: Option<hir::GeneratorKind>,

/// When inside an `async` context, this is the `HirId` of the
/// `task_context` local bound to the resume argument of the generator.
task_context: Option<hir::HirId>,

/// Used to get the current `fn`'s def span to point to when using `await`
/// outside of an `async fn`.
current_item: Option<Span>,
Expand Down Expand Up @@ -294,6 +298,7 @@ pub fn lower_crate<'a, 'hir>(
item_local_id_counters: Default::default(),
node_id_to_hir_id: IndexVec::new(),
generator_kind: None,
task_context: None,
current_item: None,
lifetimes_to_define: Vec::new(),
is_collecting_in_band_lifetimes: false,
Expand Down
7 changes: 5 additions & 2 deletions src/librustc_mir/borrow_check/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
};

debug!(
"equate_inputs_and_outputs: normalized_input_tys = {:?}, local_decls = {:?}",
normalized_input_tys, body.local_decls
);

// Equate expected input tys with those in the MIR.
for (&normalized_input_ty, argument_index) in normalized_input_tys.iter().zip(0..) {
// In MIR, argument N is stored in local N+1.
let local = Local::new(argument_index + 1);

debug!("equate_inputs_and_outputs: normalized_input_ty = {:?}", normalized_input_ty);

let mir_input_ty = body.local_decls[local].ty;
let mir_input_span = body.local_decls[local].source_info.span;
self.equate_normalized_input_or_output(
Expand Down
3 changes: 2 additions & 1 deletion src/librustc_span/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ symbols! {
plugin_registrar,
plugins,
Poll,
poll_with_tls_context,
poll_with_context,
powerpc_target_feature,
precise_pointer_size_matching,
pref_align_of,
Expand Down Expand Up @@ -720,6 +720,7 @@ symbols! {
target_has_atomic_load_store,
target_thread_local,
task,
_task_context,
tbm_target_feature,
termination_trait,
termination_trait_test,
Expand Down
25 changes: 18 additions & 7 deletions src/libstd/future.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Asynchronous values.
use core::cell::Cell;
use core::marker::Unpin;
use core::ops::{Drop, Generator, GeneratorState};
use core::option::Option;
use core::pin::Pin;
use core::ptr::NonNull;
use core::task::{Context, Poll};
#[cfg(bootstrap)]
use core::{
cell::Cell,
marker::Unpin,
ops::{Drop, Generator, GeneratorState},
pin::Pin,
ptr::NonNull,
task::{Context, Poll},
};

#[doc(inline)]
#[stable(feature = "futures_api", since = "1.36.0")]
Expand All @@ -17,22 +19,26 @@ pub use core::future::*;
/// This function returns a `GenFuture` underneath, but hides it in `impl Trait` to give
/// better error messages (`impl Future` rather than `GenFuture<[closure.....]>`).
// This is `const` to avoid extra errors after we recover from `const async fn`
#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
pub const fn from_generator<T: Generator<Yield = ()>>(x: T) -> impl Future<Output = T::Return> {
GenFuture(x)
}

/// A wrapper around generators used to implement `Future` for `async`/`await` code.
#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct GenFuture<T: Generator<Yield = ()>>(T);

// We rely on the fact that async/await futures are immovable in order to create
// self-referential borrows in the underlying generator.
#[cfg(bootstrap)]
impl<T: Generator<Yield = ()>> !Unpin for GenFuture<T> {}

#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
impl<T: Generator<Yield = ()>> Future for GenFuture<T> {
Expand All @@ -48,12 +54,15 @@ impl<T: Generator<Yield = ()>> Future for GenFuture<T> {
}
}

#[cfg(bootstrap)]
thread_local! {
static TLS_CX: Cell<Option<NonNull<Context<'static>>>> = Cell::new(None);
}

#[cfg(bootstrap)]
struct SetOnDrop(Option<NonNull<Context<'static>>>);

#[cfg(bootstrap)]
impl Drop for SetOnDrop {
fn drop(&mut self) {
TLS_CX.with(|tls_cx| {
Expand All @@ -64,13 +73,15 @@ impl Drop for SetOnDrop {

// Safety: the returned guard must drop before `cx` is dropped and before
// any previous guard is dropped.
#[cfg(bootstrap)]
unsafe fn set_task_context(cx: &mut Context<'_>) -> SetOnDrop {
// transmute the context's lifetime to 'static so we can store it.
let cx = core::mem::transmute::<&mut Context<'_>, &mut Context<'static>>(cx);
let old_cx = TLS_CX.with(|tls_cx| tls_cx.replace(Some(NonNull::from(cx))));
SetOnDrop(old_cx)
}

#[cfg(bootstrap)]
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
/// Polls a future in the current thread-local task waker.
Expand Down
Loading

0 comments on commit ef7c8a1

Please sign in to comment.