diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs index 15e820f2d1941..1a14cd79fa069 100644 --- a/compiler/rustc_const_eval/src/transform/validate.rs +++ b/compiler/rustc_const_eval/src/transform/validate.rs @@ -89,22 +89,20 @@ pub fn equal_up_to_regions<'tcx>( // Normalize lifetimes away on both sides, then compare. let normalize = |ty: Ty<'tcx>| { - tcx.normalize_erasing_regions( - param_env, - ty.fold_with(&mut BottomUpFolder { - tcx, - // FIXME: We erase all late-bound lifetimes, but this is not fully correct. - // If you have a type like ` fn(&'a u32) as SomeTrait>::Assoc`, - // this is not necessarily equivalent to `::Assoc`, - // since one may have an `impl SomeTrait for fn(&32)` and - // `impl SomeTrait for fn(&'static u32)` at the same time which - // specify distinct values for Assoc. (See also #56105) - lt_op: |_| tcx.lifetimes.re_erased, - // Leave consts and types unchanged. - ct_op: |ct| ct, - ty_op: |ty| ty, - }), - ) + let ty = ty.fold_with(&mut BottomUpFolder { + tcx, + // FIXME: We erase all late-bound lifetimes, but this is not fully correct. + // If you have a type like ` fn(&'a u32) as SomeTrait>::Assoc`, + // this is not necessarily equivalent to `::Assoc`, + // since one may have an `impl SomeTrait for fn(&32)` and + // `impl SomeTrait for fn(&'static u32)` at the same time which + // specify distinct values for Assoc. (See also #56105) + lt_op: |_| tcx.lifetimes.re_erased, + // Leave consts and types unchanged. + ct_op: |ct| ct, + ty_op: |ty| ty, + }); + tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty) }; tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok()) } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index c478c17be76c6..d7d2984018804 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -12,6 +12,7 @@ use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyC use rustc_session::config::OptLevel; use rustc_span::def_id::DefId; use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span}; +use rustc_target::abi::VariantIdx; use rustc_target::spec::abi::Abi; use super::simplify::{remove_dead_blocks, CfgSimplifier}; @@ -414,118 +415,60 @@ impl<'tcx> Inliner<'tcx> { debug!(" final inline threshold = {}", threshold); // FIXME: Give a bonus to functions with only a single caller - let mut first_block = true; - let mut cost = 0; + let diverges = matches!( + callee_body.basic_blocks()[START_BLOCK].terminator().kind, + TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. } + ); + if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) { + return Err("callee diverges unconditionally"); + } + + let mut checker = CostChecker { + tcx: self.tcx, + param_env: self.param_env, + instance: callsite.callee, + callee_body, + cost: 0, + validation: Ok(()), + }; - // Traverse the MIR manually so we can account for the effects of - // inlining on the CFG. + // Traverse the MIR manually so we can account for the effects of inlining on the CFG. let mut work_list = vec![START_BLOCK]; let mut visited = BitSet::new_empty(callee_body.basic_blocks().len()); while let Some(bb) = work_list.pop() { if !visited.insert(bb.index()) { continue; } + let blk = &callee_body.basic_blocks()[bb]; + checker.visit_basic_block_data(bb, blk); - for stmt in &blk.statements { - // Don't count StorageLive/StorageDead in the inlining cost. - match stmt.kind { - StatementKind::StorageLive(_) - | StatementKind::StorageDead(_) - | StatementKind::Deinit(_) - | StatementKind::Nop => {} - _ => cost += INSTR_COST, - } - } let term = blk.terminator(); - let mut is_drop = false; - match term.kind { - TerminatorKind::Drop { ref place, target, unwind } - | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => { - is_drop = true; - work_list.push(target); - // If the place doesn't actually need dropping, treat it like - // a regular goto. - let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty); - if ty.needs_drop(tcx, self.param_env) { - cost += CALL_PENALTY; - if let Some(unwind) = unwind { - cost += LANDINGPAD_PENALTY; - work_list.push(unwind); - } - } else { - cost += INSTR_COST; - } - } - - TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. } - if first_block => - { - // If the function always diverges, don't inline - // unless the cost is zero - threshold = 0; - } - - TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => { - if let ty::FnDef(def_id, _) = - *callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind() - { - // Don't give intrinsics the extra penalty for calls - if tcx.is_intrinsic(def_id) { - cost += INSTR_COST; - } else { - cost += CALL_PENALTY; - } - } else { - cost += CALL_PENALTY; - } - if cleanup.is_some() { - cost += LANDINGPAD_PENALTY; - } - } - TerminatorKind::Assert { cleanup, .. } => { - cost += CALL_PENALTY; - - if cleanup.is_some() { - cost += LANDINGPAD_PENALTY; - } - } - TerminatorKind::Resume => cost += RESUME_PENALTY, - TerminatorKind::InlineAsm { cleanup, .. } => { - cost += INSTR_COST; + if let TerminatorKind::Drop { ref place, target, unwind } + | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind + { + work_list.push(target); - if cleanup.is_some() { - cost += LANDINGPAD_PENALTY; + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty); + if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind { + work_list.push(unwind); } - } - _ => cost += INSTR_COST, - } - - if !is_drop { - for succ in term.successors() { - work_list.push(succ); - } + } else { + work_list.extend(term.successors()) } - - first_block = false; } // Count up the cost of local variables and temps, if we know the size // use that, otherwise we use a moderately-large dummy cost. - - let ptr_size = tcx.data_layout.pointer_size.bytes(); - for v in callee_body.vars_and_temps_iter() { - let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty); - // Cost of the var is the size in machine-words, if we know - // it. - if let Some(size) = type_size_of(tcx, self.param_env, ty) { - cost += ((size + ptr_size - 1) / ptr_size) as usize; - } else { - cost += UNKNOWN_SIZE_COST; - } + checker.visit_local_decl(v, &callee_body.local_decls[v]); } + // Abort if type validation found anything fishy. + checker.validation?; + + let cost = checker.cost; if let InlineAttr::Always = callee_attrs.inline { debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost); Ok(()) @@ -799,6 +742,193 @@ fn type_size_of<'tcx>( tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes()) } +/// Verify that the callee body is compatible with the caller. +/// +/// This visitor mostly computes the inlining cost, +/// but also needs to verify that types match because of normalization failure. +struct CostChecker<'b, 'tcx> { + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + cost: usize, + callee_body: &'b Body<'tcx>, + instance: ty::Instance<'tcx>, + validation: Result<(), &'static str>, +} + +impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + // Don't count StorageLive/StorageDead in the inlining cost. + match statement.kind { + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Deinit(_) + | StatementKind::Nop => {} + _ => self.cost += INSTR_COST, + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + let tcx = self.tcx; + match terminator.kind { + TerminatorKind::Drop { ref place, unwind, .. } + | TerminatorKind::DropAndReplace { ref place, unwind, .. } => { + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty); + if ty.needs_drop(tcx, self.param_env) { + self.cost += CALL_PENALTY; + if unwind.is_some() { + self.cost += LANDINGPAD_PENALTY; + } + } else { + self.cost += INSTR_COST; + } + } + TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => { + let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty()); + self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) { + // Don't give intrinsics the extra penalty for calls + INSTR_COST + } else { + CALL_PENALTY + }; + if cleanup.is_some() { + self.cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Assert { cleanup, .. } => { + self.cost += CALL_PENALTY; + if cleanup.is_some() { + self.cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Resume => self.cost += RESUME_PENALTY, + TerminatorKind::InlineAsm { cleanup, .. } => { + self.cost += INSTR_COST; + if cleanup.is_some() { + self.cost += LANDINGPAD_PENALTY; + } + } + _ => self.cost += INSTR_COST, + } + + self.super_terminator(terminator, location); + } + + /// Count up the cost of local variables and temps, if we know the size + /// use that, otherwise we use a moderately-large dummy cost. + fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) { + let tcx = self.tcx; + let ptr_size = tcx.data_layout.pointer_size.bytes(); + + let ty = self.instance.subst_mir(tcx, &local_decl.ty); + // Cost of the var is the size in machine-words, if we know + // it. + if let Some(size) = type_size_of(tcx, self.param_env, ty) { + self.cost += ((size + ptr_size - 1) / ptr_size) as usize; + } else { + self.cost += UNKNOWN_SIZE_COST; + } + + self.super_local_decl(local, local_decl) + } + + /// This method duplicates code from MIR validation in an attempt to detect type mismatches due + /// to normalization failure. + fn visit_projection_elem( + &mut self, + local: Local, + proj_base: &[PlaceElem<'tcx>], + elem: PlaceElem<'tcx>, + context: PlaceContext, + location: Location, + ) { + if let ProjectionElem::Field(f, ty) = elem { + let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) }; + let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx); + let check_equal = |this: &mut Self, f_ty| { + if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) { + trace!(?ty, ?f_ty); + this.validation = Err("failed to normalize projection type"); + return; + } + }; + + let kind = match parent_ty.ty.kind() { + &ty::Opaque(def_id, substs) => { + self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind() + } + kind => kind, + }; + + match kind { + ty::Tuple(fields) => { + let Some(f_ty) = fields.get(f.as_usize()) else { + self.validation = Err("malformed MIR"); + return; + }; + check_equal(self, *f_ty); + } + ty::Adt(adt_def, substs) => { + let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0)); + let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else { + self.validation = Err("malformed MIR"); + return; + }; + check_equal(self, field.ty(self.tcx, substs)); + } + ty::Closure(_, substs) => { + let substs = substs.as_closure(); + let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else { + self.validation = Err("malformed MIR"); + return; + }; + check_equal(self, f_ty); + } + &ty::Generator(def_id, substs, _) => { + let f_ty = if let Some(var) = parent_ty.variant_index { + let gen_body = if def_id == self.callee_body.source.def_id() { + self.callee_body + } else { + self.tcx.optimized_mir(def_id) + }; + + let Some(layout) = gen_body.generator_layout() else { + self.validation = Err("malformed MIR"); + return; + }; + + let Some(&local) = layout.variant_fields[var].get(f) else { + self.validation = Err("malformed MIR"); + return; + }; + + let Some(&f_ty) = layout.field_tys.get(local) else { + self.validation = Err("malformed MIR"); + return; + }; + + f_ty + } else { + let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else { + self.validation = Err("malformed MIR"); + return; + }; + + f_ty + }; + + check_equal(self, f_ty); + } + _ => self.validation = Err("malformed MIR"), + } + } + + self.super_projection_elem(local, proj_base, elem, context, location); + } +} + /** * Integrator. * diff --git a/src/test/ui/mir/mir-inlining/ice-issue-100550-unnormalized-projection.rs b/src/test/ui/mir/mir-inlining/ice-issue-100550-unnormalized-projection.rs new file mode 100644 index 0000000000000..f67b073548167 --- /dev/null +++ b/src/test/ui/mir/mir-inlining/ice-issue-100550-unnormalized-projection.rs @@ -0,0 +1,30 @@ +// This test verifies that we do not ICE due to MIR inlining in case of normalization failure +// in a projection. +// +// compile-flags: --crate-type lib -C opt-level=3 +// build-pass + +pub trait Trait { + type Associated; +} +impl Trait for T { + type Associated = T; +} + +pub struct Struct(::Associated); + +pub fn foo() -> Struct +where + T: Trait, +{ + bar() +} + +#[inline] +fn bar() -> Struct { + Struct(baz()) +} + +fn baz() -> T { + unimplemented!() +}