Skip to content

Commit

Permalink
Normalize opaques eagerly
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Feb 8, 2024
1 parent 1280928 commit c353df0
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 134 deletions.
14 changes: 3 additions & 11 deletions compiler/rustc_hir_typeck/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,7 @@ pub(super) fn check_fn<'a, 'tcx>(
let tcx = fcx.tcx;
let hir = tcx.hir();

let declared_ret_ty = fn_sig.output();

let ret_ty =
fcx.register_infer_ok_obligations(fcx.infcx.replace_opaque_types_with_inference_vars(
declared_ret_ty,
fn_def_id,
decl.output.span(),
fcx.param_env,
));
let ret_ty = fn_sig.output();

fcx.coroutine_types = coroutine_types;
fcx.ret_coercion = Some(RefCell::new(CoerceMany::new(ret_ty)));
Expand Down Expand Up @@ -109,7 +101,7 @@ pub(super) fn check_fn<'a, 'tcx>(
hir::FnRetTy::DefaultReturn(_) => body.value.span,
hir::FnRetTy::Return(ty) => ty.span,
};
fcx.require_type_is_sized(declared_ret_ty, return_or_body_span, traits::SizedReturnType);
fcx.require_type_is_sized(ret_ty, return_or_body_span, traits::SizedReturnType);
fcx.is_whole_body.set(true);
fcx.check_return_expr(body.value, false);

Expand All @@ -120,7 +112,7 @@ pub(super) fn check_fn<'a, 'tcx>(
let coercion = fcx.ret_coercion.take().unwrap().into_inner();
let mut actual_return_ty = coercion.complete(fcx);
debug!("actual_return_ty = {:?}", actual_return_ty);
if let ty::Dynamic(..) = declared_ret_ty.kind() {
if let ty::Dynamic(..) = ret_ty.kind() {
// We have special-cased the case where the function is declared
// `-> dyn Foo` and we don't actually relate it to the
// `fcx.ret_coercion`, so just substitute a type variable.
Expand Down
11 changes: 0 additions & 11 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -879,17 +879,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

let output_ty = self.normalize(closure_span, output_ty);

// async fn that have opaque types in their return type need to redo the conversion to inference variables
// as they fetch the still opaque version from the signature.
let InferOk { value: output_ty, obligations } = self
.replace_opaque_types_with_inference_vars(
output_ty,
body_def_id,
closure_span,
self.param_env,
);
self.register_predicates(obligations);

Some(output_ty)
}

Expand Down
57 changes: 1 addition & 56 deletions compiler/rustc_infer/src/infer/opaque_types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use super::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use super::{DefineOpaqueTypes, InferResult};
use crate::errors::OpaqueHiddenTypeDiag;
use crate::infer::{InferCtxt, InferOk};
Expand Down Expand Up @@ -36,60 +35,6 @@ pub struct OpaqueTypeDecl<'tcx> {
}

impl<'tcx> InferCtxt<'tcx> {
/// This is a backwards compatibility hack to prevent breaking changes from
/// lazy TAIT around RPIT handling.
pub fn replace_opaque_types_with_inference_vars<T: TypeFoldable<TyCtxt<'tcx>>>(
&self,
value: T,
body_id: LocalDefId,
span: Span,
param_env: ty::ParamEnv<'tcx>,
) -> InferOk<'tcx, T> {
// We handle opaque types differently in the new solver.
if self.next_trait_solver() {
return InferOk { value, obligations: vec![] };
}

if !value.has_opaque_types() {
return InferOk { value, obligations: vec![] };
}

let mut obligations = vec![];
let replace_opaque_type = |def_id: DefId| {
def_id.as_local().is_some_and(|def_id| self.opaque_type_origin(def_id).is_some())
};
let value = value.fold_with(&mut BottomUpFolder {
tcx: self.tcx,
lt_op: |lt| lt,
ct_op: |ct| ct,
ty_op: |ty| match *ty.kind() {
ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. })
if replace_opaque_type(def_id) && !ty.has_escaping_bound_vars() =>
{
let def_span = self.tcx.def_span(def_id);
let span = if span.contains(def_span) { def_span } else { span };
let code = traits::ObligationCauseCode::OpaqueReturnType(None);
let cause = ObligationCause::new(span, body_id, code);
// FIXME(compiler-errors): We probably should add a new TypeVariableOriginKind
// for opaque types, and then use that kind to fix the spans for type errors
// that we see later on.
let ty_var = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::OpaqueTypeInference(def_id),
span,
});
obligations.extend(
self.handle_opaque_type(ty, ty_var, true, &cause, param_env)
.unwrap()
.obligations,
);
ty_var
}
_ => ty,
},
});
InferOk { value, obligations }
}

pub fn handle_opaque_type(
&self,
a: Ty<'tcx>,
Expand Down Expand Up @@ -515,7 +460,7 @@ impl UseKind {

impl<'tcx> InferCtxt<'tcx> {
#[instrument(skip(self), level = "debug")]
fn register_hidden_type(
pub fn register_hidden_type(
&self,
opaque_type_key: OpaqueTypeKey<'tcx>,
cause: ObligationCause<'tcx>,
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1004,14 +1004,14 @@ pub enum CodegenObligationError {
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, HashStable, TypeFoldable, TypeVisitable)]
pub enum DefiningAnchor {
/// Define opaques which are in-scope of the `LocalDefId`. Also, eagerly
/// replace opaque types in `replace_opaque_types_with_inference_vars`.
/// replace opaque types in normalization.
Bind(LocalDefId),
/// In contexts where we don't currently know what opaques are allowed to be
/// defined, such as (old solver) canonical queries, we will simply allow
/// opaques to be defined, but "bubble" them up in the canonical response or
/// otherwise treat them to be handled later.
///
/// We do not eagerly replace opaque types in `replace_opaque_types_with_inference_vars`,
/// We do not eagerly replace opaque types in normalization,
/// which may affect what predicates pass and fail in the old trait solver.
Bubble,
/// Do not allow any opaques to be defined. This is used to catch type mismatch
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_trait_selection/src/solve/normalize.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::traits::error_reporting::TypeErrCtxtExt;
use crate::traits::query::evaluate_obligation::InferCtxtExt;
use crate::traits::{needs_normalization, BoundVarReplacer, PlaceholderReplacer};
use crate::traits::{BoundVarReplacer, PlaceholderReplacer};
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_infer::infer::at::At;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
Expand Down Expand Up @@ -205,10 +205,9 @@ impl<'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for NormalizationFolder<'_, 'tcx> {
}

fn try_fold_const(&mut self, ct: ty::Const<'tcx>) -> Result<ty::Const<'tcx>, Self::Error> {
let reveal = self.at.param_env.reveal();
let infcx = self.at.infcx;
debug_assert_eq!(ct, infcx.shallow_resolve(ct));
if !needs_normalization(&ct, reveal) {
if !ct.has_projections() {
return Ok(ct);
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use rustc_span::Span;
use std::fmt::Debug;
use std::ops::ControlFlow;

pub(crate) use self::project::{needs_normalization, BoundVarReplacer, PlaceholderReplacer};
pub(crate) use self::project::{BoundVarReplacer, PlaceholderReplacer};

pub use self::coherence::{add_placeholder_note, orphan_check, overlapping_impls};
pub use self::coherence::{OrphanCheckErr, OverlapResult};
Expand Down
80 changes: 35 additions & 45 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_errors::ErrorGuaranteed;
use rustc_hir::def::DefKind;
use rustc_hir::lang_items::LangItem;
use rustc_hir::OpaqueTyOrigin;
use rustc_infer::infer::at::At;
use rustc_infer::infer::resolve::OpportunisticRegionResolver;
use rustc_infer::infer::DefineOpaqueTypes;
Expand Down Expand Up @@ -318,17 +319,6 @@ fn project_and_unify_type<'cx, 'tcx>(
};
debug!(?normalized, ?obligations, "project_and_unify_type result");
let actual = obligation.predicate.term;
// For an example where this is necessary see tests/ui/impl-trait/nested-return-type2.rs
// This allows users to omit re-mentioning all bounds on an associated type and just use an
// `impl Trait` for the assoc type to add more bounds.
let InferOk { value: actual, obligations: new } =
selcx.infcx.replace_opaque_types_with_inference_vars(
actual,
obligation.cause.body_id,
obligation.cause.span,
obligation.param_env,
);
obligations.extend(new);

// Need to define opaque types to support nested opaque types like `impl Fn() -> impl Trait`
match infcx.at(&obligation.cause, obligation.param_env).eq(
Expand Down Expand Up @@ -409,25 +399,6 @@ where
result
}

pub(crate) fn needs_normalization<'tcx, T: TypeVisitable<TyCtxt<'tcx>>>(
value: &T,
reveal: Reveal,
) -> bool {
match reveal {
Reveal::UserFacing => value.has_type_flags(
ty::TypeFlags::HAS_TY_PROJECTION
| ty::TypeFlags::HAS_TY_INHERENT
| ty::TypeFlags::HAS_CT_PROJECTION,
),
Reveal::All => value.has_type_flags(
ty::TypeFlags::HAS_TY_PROJECTION
| ty::TypeFlags::HAS_TY_INHERENT
| ty::TypeFlags::HAS_TY_OPAQUE
| ty::TypeFlags::HAS_CT_PROJECTION,
),
}
}

struct AssocTypeNormalizer<'a, 'b, 'tcx> {
selcx: &'a mut SelectionContext<'b, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
Expand Down Expand Up @@ -488,11 +459,7 @@ impl<'a, 'b, 'tcx> AssocTypeNormalizer<'a, 'b, 'tcx> {
"Normalizing {value:?} without wrapping in a `Binder`"
);

if !needs_normalization(&value, self.param_env.reveal()) {
value
} else {
value.fold_with(self)
}
if !value.has_projections() { value } else { value.fold_with(self) }
}
}

Expand All @@ -512,7 +479,7 @@ impl<'a, 'b, 'tcx> TypeFolder<TyCtxt<'tcx>> for AssocTypeNormalizer<'a, 'b, 'tcx
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if !needs_normalization(&ty, self.param_env.reveal()) {
if !ty.has_projections() {
return ty;
}

Expand Down Expand Up @@ -548,7 +515,36 @@ impl<'a, 'b, 'tcx> TypeFolder<TyCtxt<'tcx>> for AssocTypeNormalizer<'a, 'b, 'tcx
ty::Opaque => {
// Only normalize `impl Trait` outside of type inference, usually in codegen.
match self.param_env.reveal() {
Reveal::UserFacing => ty.super_fold_with(self),
Reveal::UserFacing => {
if !data.has_escaping_bound_vars()
&& let Some(def_id) = data.def_id.as_local()
&& let Some(
OpaqueTyOrigin::TyAlias { in_assoc_ty: true }
| OpaqueTyOrigin::AsyncFn(_)
| OpaqueTyOrigin::FnReturn(_),
) = self.selcx.infcx.opaque_type_origin(def_id)
{
let infer = self.selcx.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::OpaqueTypeInference(data.def_id),
span: self.cause.span,
});
let InferOk { value: (), obligations } = self
.selcx
.infcx
.register_hidden_type(
ty::OpaqueTypeKey { def_id, args: data.args },
self.cause.clone(),
self.param_env,
infer,
true,
)
.expect("uwu");
self.obligations.extend(obligations);
self.selcx.infcx.resolve_vars_if_possible(infer)
} else {
ty.super_fold_with(self)
}
}

Reveal::All => {
let recursion_limit = self.interner().recursion_limit();
Expand Down Expand Up @@ -752,9 +748,7 @@ impl<'a, 'b, 'tcx> TypeFolder<TyCtxt<'tcx>> for AssocTypeNormalizer<'a, 'b, 'tcx
#[instrument(skip(self), level = "debug")]
fn fold_const(&mut self, constant: ty::Const<'tcx>) -> ty::Const<'tcx> {
let tcx = self.selcx.tcx();
if tcx.features().generic_const_exprs
|| !needs_normalization(&constant, self.param_env.reveal())
{
if tcx.features().generic_const_exprs || !constant.has_projections() {
constant
} else {
let constant = constant.super_fold_with(self);
Expand All @@ -770,11 +764,7 @@ impl<'a, 'b, 'tcx> TypeFolder<TyCtxt<'tcx>> for AssocTypeNormalizer<'a, 'b, 'tcx

#[inline]
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) {
p.super_fold_with(self)
} else {
p
}
if p.allow_normalization() && p.has_projections() { p.super_fold_with(self) } else { p }
}
}

Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_trait_selection/src/traits/query/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::infer::at::At;
use crate::infer::canonical::OriginalQueryValues;
use crate::infer::{InferCtxt, InferOk};
use crate::traits::error_reporting::TypeErrCtxtExt;
use crate::traits::project::{needs_normalization, BoundVarReplacer, PlaceholderReplacer};
use crate::traits::project::{BoundVarReplacer, PlaceholderReplacer};
use crate::traits::{ObligationCause, PredicateObligation, Reveal};
use rustc_data_structures::sso::SsoHashMap;
use rustc_data_structures::stack::ensure_sufficient_stack;
Expand Down Expand Up @@ -89,7 +89,7 @@ impl<'cx, 'tcx> QueryNormalizeExt<'tcx> for At<'cx, 'tcx> {
}
}

if !needs_normalization(&value, self.param_env.reveal()) {
if !value.has_projections() {
return Ok(Normalized { value, obligations: vec![] });
}

Expand Down Expand Up @@ -198,7 +198,7 @@ impl<'cx, 'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for QueryNormalizer<'cx, 'tcx>

#[instrument(level = "debug", skip(self))]
fn try_fold_ty(&mut self, ty: Ty<'tcx>) -> Result<Ty<'tcx>, Self::Error> {
if !needs_normalization(&ty, self.param_env.reveal()) {
if !ty.has_projections() {
return Ok(ty);
}

Expand Down Expand Up @@ -335,7 +335,7 @@ impl<'cx, 'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for QueryNormalizer<'cx, 'tcx>
&mut self,
constant: ty::Const<'tcx>,
) -> Result<ty::Const<'tcx>, Self::Error> {
if !needs_normalization(&constant, self.param_env.reveal()) {
if !constant.has_projections() {
return Ok(constant);
}

Expand All @@ -354,7 +354,7 @@ impl<'cx, 'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for QueryNormalizer<'cx, 'tcx>
&mut self,
p: ty::Predicate<'tcx>,
) -> Result<ty::Predicate<'tcx>, Self::Error> {
if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) {
if p.allow_normalization() && p.has_projections() {
p.try_super_fold_with(self)
} else {
Ok(p)
Expand Down

0 comments on commit c353df0

Please sign in to comment.