Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unpolished, experimental support for AFIDT (async fn in dyn trait) #133122

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_feature/src/unstable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ declare_features! (
(unstable, associated_type_defaults, "1.2.0", Some(29661)),
/// Allows `async || body` closures.
(unstable, async_closure, "1.37.0", Some(62290)),
/// Allows async functions to be called from `dyn Trait`.
(incomplete, async_fn_in_dyn_trait, "CURRENT_RUSTC_VERSION", Some(133119)),
/// Allows `#[track_caller]` on async functions.
(unstable, async_fn_track_caller, "1.73.0", Some(110011)),
/// Allows `for await` loops.
Expand Down
37 changes: 20 additions & 17 deletions compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,23 +677,26 @@ impl<'tcx> Instance<'tcx> {
//
// 1) The underlying method expects a caller location parameter
// in the ABI
if resolved.def.requires_caller_location(tcx)
// 2) The caller location parameter comes from having `#[track_caller]`
// on the implementation, and *not* on the trait method.
&& !tcx.should_inherit_track_caller(def)
// If the method implementation comes from the trait definition itself
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
// then we don't need to generate a shim. This check is needed because
// `should_inherit_track_caller` returns `false` if our method
// implementation comes from the trait block, and not an impl block
&& !matches!(
tcx.opt_associated_item(def),
Some(ty::AssocItem {
container: ty::AssocItemContainer::Trait,
..
})
)
{
let needs_track_caller_shim = resolved.def.requires_caller_location(tcx)
// 2) The caller location parameter comes from having `#[track_caller]`
// on the implementation, and *not* on the trait method.
&& !tcx.should_inherit_track_caller(def)
// If the method implementation comes from the trait definition itself
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
// then we don't need to generate a shim. This check is needed because
// `should_inherit_track_caller` returns `false` if our method
// implementation comes from the trait block, and not an impl block
&& !matches!(
tcx.opt_associated_item(def),
Some(ty::AssocItem {
container: ty::AssocItemContainer::Trait,
..
})
);
// We also need to generate a shim if this is an AFIT.
let needs_rpitit_shim =
tcx.return_position_impl_trait_in_trait_shim_data(def).is_some();
if needs_track_caller_shim || needs_rpitit_shim {
if tcx.is_closure_like(def) {
debug!(
" => vtable fn pointer created for closure with #[track_caller]: {:?} for method {:?} {:?}",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ mod opaque_types;
mod parameterized;
mod predicate;
mod region;
mod return_position_impl_trait_in_trait;
mod rvalue_scopes;
mod structural_impls;
#[allow(hidden_glob_reexports)]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use rustc_hir::def_id::DefId;

use crate::ty::{self, ExistentialPredicateStableCmpExt, TyCtxt};

impl<'tcx> TyCtxt<'tcx> {
/// Given a `def_id` of a trait or impl method, compute whether that method needs to
/// have an RPITIT shim applied to it for it to be object safe. If so, return the
/// `def_id` of the RPITIT, and also the args of trait method that returns the RPITIT.
///
/// NOTE that these args are not, in general, the same as than the RPITIT's args. They
/// are a subset of those args, since they do not include the late-bound lifetimes of
/// the RPITIT. Depending on the context, these will need to be dealt with in different
/// ways -- in codegen, it's okay to fill them with ReErased.
pub fn return_position_impl_trait_in_trait_shim_data(
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
self,
def_id: DefId,
) -> Option<(DefId, ty::EarlyBinder<'tcx, ty::GenericArgsRef<'tcx>>)> {
let assoc_item = self.opt_associated_item(def_id)?;

let (trait_item_def_id, opt_impl_def_id) = match assoc_item.container {
ty::AssocItemContainer::Impl => {
(assoc_item.trait_item_def_id?, Some(self.parent(def_id)))
}
ty::AssocItemContainer::Trait => (def_id, None),
};

let sig = self.fn_sig(trait_item_def_id);

// Check if the trait returns an RPITIT.
let ty::Alias(ty::Projection, ty::AliasTy { def_id, .. }) =
*sig.skip_binder().skip_binder().output().kind()
else {
return None;
};
if !self.is_impl_trait_in_trait(def_id) {
return None;
}

let args = if let Some(impl_def_id) = opt_impl_def_id {
// Rebase the args from the RPITIT onto the impl trait ref, so we can later
// substitute them with the method args of the *impl* method, since that's
// the instance we're building a vtable shim for.
ty::GenericArgs::identity_for_item(self, trait_item_def_id).rebase_onto(
self,
self.parent(trait_item_def_id),
self.impl_trait_ref(impl_def_id)
.expect("expected impl trait ref from parent of impl item")
.instantiate_identity()
.args,
)
} else {
// This is when we have a default trait implementation.
ty::GenericArgs::identity_for_item(self, trait_item_def_id)
};

Some((def_id, ty::EarlyBinder::bind(args)))
}

/// Given a `DefId` of an RPITIT and its args, return the existential predicates
/// that corresponds to the RPITIT's bounds with the self type erased.
pub fn item_bounds_to_existential_predicates(
self,
def_id: DefId,
args: ty::GenericArgsRef<'tcx>,
) -> &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
let mut bounds: Vec<_> = self
.item_super_predicates(def_id)
.iter_instantiated(self, args)
.filter_map(|clause| {
clause
.kind()
.map_bound(|clause| match clause {
ty::ClauseKind::Trait(trait_pred) => Some(ty::ExistentialPredicate::Trait(
ty::ExistentialTraitRef::erase_self_ty(self, trait_pred.trait_ref),
)),
ty::ClauseKind::Projection(projection_pred) => {
Some(ty::ExistentialPredicate::Projection(
ty::ExistentialProjection::erase_self_ty(self, projection_pred),
))
}
ty::ClauseKind::TypeOutlives(_) => {
// Type outlives bounds don't really turn into anything,
// since we must use an intersection region for the `dyn*`'s
// region anyways.
None
}
_ => unreachable!("unexpected clause in item bounds: {clause:?}"),
})
.transpose()
})
.collect();
bounds.sort_by(|a, b| a.skip_binder().stable_cmp(self, &b.skip_binder()));
self.mk_poly_existential_predicates(&bounds)
}
}
56 changes: 53 additions & 3 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::adjustment::PointerCoercion;
use rustc_middle::ty::{
self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
};
Expand Down Expand Up @@ -710,6 +711,13 @@ fn build_call_shim<'tcx>(
};

let def_id = instance.def_id();

let rpitit_shim = if let ty::InstanceKind::ReifyShim(..) = instance {
tcx.return_position_impl_trait_in_trait_shim_data(def_id)
} else {
None
};

let sig = tcx.fn_sig(def_id);
let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig));

Expand Down Expand Up @@ -765,9 +773,34 @@ fn build_call_shim<'tcx>(
let mut local_decls = local_decls_for_sig(&sig, span);
let source_info = SourceInfo::outermost(span);

let mut destination = Place::return_place();
if let Some((rpitit_def_id, fn_args)) = rpitit_shim {
let rpitit_args =
fn_args.instantiate_identity().extend_to(tcx, rpitit_def_id, |param, _| {
match param.kind {
ty::GenericParamDefKind::Lifetime => tcx.lifetimes.re_erased.into(),
ty::GenericParamDefKind::Type { .. }
| ty::GenericParamDefKind::Const { .. } => {
unreachable!("rpitit should have no addition ty/ct")
}
}
});
let dyn_star_ty = Ty::new_dynamic(
tcx,
tcx.item_bounds_to_existential_predicates(rpitit_def_id, rpitit_args),
tcx.lifetimes.re_erased,
ty::DynStar,
);
destination = local_decls.push(local_decls[RETURN_PLACE].clone()).into();
local_decls[RETURN_PLACE].ty = dyn_star_ty;
let mut inputs_and_output = sig.inputs_and_output.to_vec();
*inputs_and_output.last_mut().unwrap() = dyn_star_ty;
sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
}

let rcvr_place = || {
assert!(rcvr_adjustment.is_some());
Place::from(Local::new(1 + 0))
Place::from(Local::new(1))
};
let mut statements = vec![];

Expand Down Expand Up @@ -854,7 +887,7 @@ fn build_call_shim<'tcx>(
TerminatorKind::Call {
func: callee,
args,
destination: Place::return_place(),
destination,
target: Some(BasicBlock::new(1)),
unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment {
UnwindAction::Cleanup(BasicBlock::new(3))
Expand Down Expand Up @@ -882,7 +915,24 @@ fn build_call_shim<'tcx>(
);
}
// BB #1/#2 - return
block(&mut blocks, vec![], TerminatorKind::Return, false);
// NOTE: If this is an RPITIT in dyn, we also want to coerce
// the return type of the function into a `dyn*`.
let stmts = if rpitit_shim.is_some() {
vec![Statement {
source_info,
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Cast(
CastKind::PointerCoercion(PointerCoercion::DynStar, CoercionSource::Implicit),
Operand::Move(destination),
sig.output(),
),
))),
}]
} else {
vec![]
};
block(&mut blocks, stmts, TerminatorKind::Return, false);
if let Some(Adjustment::RefMut) = rcvr_adjustment {
// BB #3 - drop if closure panics
block(
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_monomorphize/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ fn custom_coerce_unsize_info<'tcx>(
..
})) => Ok(tcx.coerce_unsized_info(impl_def_id)?.custom_kind.unwrap()),
impl_source => {
bug!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
bug!(
"invalid `CoerceUnsized` from {source_ty} to {target_ty}: impl_source: {:?}",
impl_source
);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ symbols! {
async_drop_slice,
async_drop_surface_drop_in_place,
async_fn,
async_fn_in_dyn_trait,
async_fn_in_trait,
async_fn_kind_helper,
async_fn_kind_upvars,
Expand Down
59 changes: 48 additions & 11 deletions compiler/rustc_trait_selection/src/traits/dyn_compatibility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rustc_abi::BackendRepr;
use rustc_errors::FatalError;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_middle::bug;
use rustc_middle::query::Providers;
use rustc_middle::ty::{
self, EarlyBinder, ExistentialPredicateStableCmpExt as _, GenericArgs, Ty, TyCtxt,
Expand Down Expand Up @@ -901,23 +902,59 @@ fn contains_illegal_impl_trait_in_trait<'tcx>(
fn_def_id: DefId,
ty: ty::Binder<'tcx, Ty<'tcx>>,
) -> Option<MethodViolationCode> {
// This would be caught below, but rendering the error as a separate
// `async-specific` message is better.
let ty = tcx.liberate_late_bound_regions(fn_def_id, ty);

if tcx.asyncness(fn_def_id).is_async() {
return Some(MethodViolationCode::AsyncFn);
// FIXME(async_fn_in_dyn_trait): Think of a better way to unify these code paths
// to issue an appropriate feature suggestion when users try to use AFIDT.
// Obviously we must only do this once AFIDT is finished enough to actually be usable.
if tcx.features().async_fn_in_dyn_trait() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of rejecting if the feature is not enabled, we could always use code path for when the feature is enabled and just emit a feature gate error. Or are you trying to not impact stable code for now in diagnostics?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could. I'm gonna leave it as a follow-up, since this needs to be reworked once again to support -> impl Future style RPITITs.

let ty::Alias(ty::Projection, proj) = *ty.kind() else {
bug!("expected async fn in trait to return an RPITIT");
};
assert!(tcx.is_impl_trait_in_trait(proj.def_id));

// FIXME(async_fn_in_dyn_trait): We should check that this bound is legal too,
// and stop relying on `async fn` in the definition.
for bound in tcx.item_bounds(proj.def_id).instantiate(tcx, proj.args) {
if let Some(violation) = bound
.visit_with(&mut IllegalRpititVisitor { tcx, allowed: Some(proj) })
.break_value()
{
return Some(violation);
}
}

None
} else {
// Rendering the error as a separate `async-specific` message is better.
Some(MethodViolationCode::AsyncFn)
}
} else {
ty.visit_with(&mut IllegalRpititVisitor { tcx, allowed: None }).break_value()
}
}

struct IllegalRpititVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
allowed: Option<ty::AliasTy<'tcx>>,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for IllegalRpititVisitor<'tcx> {
type Result = ControlFlow<MethodViolationCode>;

// FIXME(RPITIT): Perhaps we should use a visitor here?
ty.skip_binder().walk().find_map(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, proj) = ty.kind()
&& tcx.is_impl_trait_in_trait(proj.def_id)
fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
if let ty::Alias(ty::Projection, proj) = *ty.kind()
&& Some(proj) != self.allowed
&& self.tcx.is_impl_trait_in_trait(proj.def_id)
{
Some(MethodViolationCode::ReferencesImplTraitInTrait(tcx.def_span(proj.def_id)))
ControlFlow::Break(MethodViolationCode::ReferencesImplTraitInTrait(
self.tcx.def_span(proj.def_id),
))
} else {
None
ty.super_visit_with(self)
}
})
}
}

pub(crate) fn provide(providers: &mut Providers) {
Expand Down
Loading
Loading