Skip to content

Commit

Permalink
Auto merge of #116891 - aliemjay:opaque-region-infer-rework-2, r=<try>
Browse files Browse the repository at this point in the history
rework opaque type region inference

fixes #113971 Pass -> Error

fixes #111906 ICE -> Pass
fixes #110623 ==
fixes #109059 ==

fixes #112841 Pass -> Error

fixes #110726 ICE->Error

fixes #111935 Pass -> Error
fixes #113916 ==

r? `@ghost`
  • Loading branch information
bors committed Oct 23, 2023
2 parents aec4741 + 350b33e commit 66abdf7
Show file tree
Hide file tree
Showing 26 changed files with 713 additions and 123 deletions.
31 changes: 19 additions & 12 deletions compiler/rustc_borrowck/src/region_infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ pub struct RegionInferenceContext<'tcx> {
/// visible from this index.
scc_universes: IndexVec<ConstraintSccIndex, ty::UniverseIndex>,

/// Contains a "representative" from each SCC. This will be the
/// minimal RegionVid belonging to that universe. It is used as a
/// kind of hacky way to manage checking outlives relationships,
/// Contains the "representative" region of each SCC.
/// It is defined as the one with the minimal RegionVid, favoring
/// free regions, then placeholders, then existential regions.
///
/// It is a hacky way to manage checking regions for equality,
/// since we can 'canonicalize' each region to the representative
/// of its SCC and be sure that -- if they have the same repr --
/// they *must* be equal (though not having the same repr does not
Expand Down Expand Up @@ -487,22 +489,27 @@ impl<'tcx> RegionInferenceContext<'tcx> {
scc_universes
}

/// For each SCC, we compute a unique `RegionVid` (in fact, the
/// minimal one that belongs to the SCC). See
/// For each SCC, we compute a unique `RegionVid`. See
/// `scc_representatives` field of `RegionInferenceContext` for
/// more details.
fn compute_scc_representatives(
constraints_scc: &Sccs<RegionVid, ConstraintSccIndex>,
definitions: &IndexSlice<RegionVid, RegionDefinition<'tcx>>,
) -> IndexVec<ConstraintSccIndex, ty::RegionVid> {
let num_sccs = constraints_scc.num_sccs();
let next_region_vid = definitions.next_index();
let mut scc_representatives = IndexVec::from_elem_n(next_region_vid, num_sccs);

for region_vid in definitions.indices() {
let scc = constraints_scc.scc(region_vid);
let prev_min = scc_representatives[scc];
scc_representatives[scc] = region_vid.min(prev_min);
let mut scc_representatives = IndexVec::from_elem_n(ty::RegionVid::MAX, num_sccs);

for (vid, def) in definitions.iter_enumerated() {
use NllRegionVariableOrigin as VarOrigin;
let scc = constraints_scc.scc(vid);
let repr = &mut scc_representatives[scc];
if *repr == ty::RegionVid::MAX {
*repr = vid;
} else if matches!(def.origin, VarOrigin::Placeholder(_))
&& matches!(definitions[*repr].origin, VarOrigin::Existential { .. })
{
*repr = vid;
}
}

scc_representatives
Expand Down
216 changes: 132 additions & 84 deletions compiler/rustc_borrowck/src/region_infer/opaque_types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_data_structures::fx::FxIndexMap;
use rustc_errors::ErrorGuaranteed;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::LocalDefId;
use rustc_hir::OpaqueTyOrigin;
use rustc_infer::infer::InferCtxt;
use rustc_infer::infer::TyCtxtInferExt as _;
use rustc_infer::infer::{InferCtxt, NllRegionVariableOrigin};
use rustc_infer::traits::{Obligation, ObligationCause};
use rustc_middle::traits::DefiningAnchor;
use rustc_middle::ty::visit::TypeVisitableExt;
Expand Down Expand Up @@ -66,85 +66,60 @@ impl<'tcx> RegionInferenceContext<'tcx> {
) -> FxIndexMap<LocalDefId, OpaqueHiddenType<'tcx>> {
let mut result: FxIndexMap<LocalDefId, OpaqueHiddenType<'tcx>> = FxIndexMap::default();

let member_constraints: FxIndexMap<_, _> = self
.member_constraints
.all_indices()
.map(|ci| (self.member_constraints[ci].key, ci))
.collect();
debug!(?member_constraints);

for (opaque_type_key, concrete_type) in opaque_ty_decls {
let args = opaque_type_key.args;
debug!(?concrete_type, ?args);
debug!(?opaque_type_key, ?concrete_type);

let mut subst_regions = vec![self.universal_regions.fr_static];
let mut arg_regions: Vec<(ty::RegionVid, ty::Region<'_>)> =
vec![(self.universal_regions.fr_static, infcx.tcx.lifetimes.re_static)];

let to_universal_region = |vid, subst_regions: &mut Vec<_>| {
trace!(?vid);
let scc = self.constraint_sccs.scc(vid);
trace!(?scc);
match self.scc_values.universal_regions_outlived_by(scc).find_map(|lb| {
self.eval_equal(vid, lb).then_some(self.definitions[lb].external_name?)
}) {
Some(region) => {
let vid = self.universal_regions.to_region_vid(region);
subst_regions.push(vid);
region
let opaque_type_key =
opaque_type_key.fold_captured_lifetime_args(infcx.tcx, |region| {
let scc = self.constraint_sccs.scc(self.to_region_vid(region));
let vid = self.scc_representatives[scc];
let named = match self.definitions[vid].origin {
NllRegionVariableOrigin::FreeRegion => self
.universal_regions
.universal_regions()
.filter(|&ur| {
use crate::universal_regions::RegionClassification as Class;
matches!(
self.universal_regions.region_classification(ur),
Some(Class::Global | Class::Local)
)
})
.filter(|&ur| ur != self.universal_regions.fr_fn_body)
.find(|&ur| self.universal_region_relations.equal(vid, ur))
.map(|ur| self.definitions[ur].external_name.unwrap()),
NllRegionVariableOrigin::Placeholder(placeholder) => {
Some(ty::Region::new_placeholder(infcx.tcx, placeholder))
}
NllRegionVariableOrigin::Existential { .. } => None,
}
None => {
subst_regions.push(vid);
.unwrap_or_else(|| {
ty::Region::new_error_with_message(
infcx.tcx,
concrete_type.span,
"opaque type with non-universal region args",
)
}
}
};
});

// Start by inserting universal regions from the member_constraint choice regions.
// This will ensure they get precedence when folding the regions in the concrete type.
if let Some(&ci) = member_constraints.get(&opaque_type_key) {
for &vid in self.member_constraints.choice_regions(ci) {
to_universal_region(vid, &mut subst_regions);
}
}
debug!(?subst_regions);

// Next, insert universal regions from args, so we can translate regions that appear
// in them but are not subject to member constraints, for instance closure args.
let universal_args = infcx.tcx.fold_regions(args, |region, _| {
if let ty::RePlaceholder(..) = region.kind() {
// Higher kinded regions don't need remapping, they don't refer to anything outside of this the args.
return region;
}
arg_regions.push((vid, named));
named
});
debug!(?opaque_type_key, ?arg_regions);

let concrete_type = infcx.tcx.fold_regions(concrete_type, |region, _| {
let vid = self.to_region_vid(region);
to_universal_region(vid, &mut subst_regions)
arg_regions
.iter()
.find(|&&(ur_vid, _)| self.eval_equal(vid, ur_vid))
.map(|&(_, ur_name)| ur_name)
.unwrap_or(infcx.tcx.lifetimes.re_erased)
});
debug!(?universal_args);
debug!(?subst_regions);

// Deduplicate the set of regions while keeping the chosen order.
let subst_regions = subst_regions.into_iter().collect::<FxIndexSet<_>>();
debug!(?subst_regions);

let universal_concrete_type =
infcx.tcx.fold_regions(concrete_type, |region, _| match *region {
ty::ReVar(vid) => subst_regions
.iter()
.find(|ur_vid| self.eval_equal(vid, **ur_vid))
.and_then(|ur_vid| self.definitions[*ur_vid].external_name)
.unwrap_or(infcx.tcx.lifetimes.re_erased),
_ => region,
});
debug!(?universal_concrete_type);
debug!(?concrete_type);

let opaque_type_key =
OpaqueTypeKey { def_id: opaque_type_key.def_id, args: universal_args };
let ty = infcx.infer_opaque_definition_from_instantiation(
opaque_type_key,
universal_concrete_type,
);
let ty =
infcx.infer_opaque_definition_from_instantiation(opaque_type_key, concrete_type);
// Sometimes two opaque types are the same only after we remap the generic parameters
// back to the opaque type definition. E.g. we may have `OpaqueType<X, Y>` mapped to `(X, Y)`
// and `OpaqueType<Y, X>` mapped to `(Y, X)`, and those are the same, but we only know that
Expand Down Expand Up @@ -365,38 +340,33 @@ fn check_opaque_type_well_formed<'tcx>(
}
}

fn check_opaque_type_parameter_valid(
tcx: TyCtxt<'_>,
opaque_type_key: OpaqueTypeKey<'_>,
fn check_opaque_type_parameter_valid<'tcx>(
tcx: TyCtxt<'tcx>,
opaque_type_key: OpaqueTypeKey<'tcx>,
span: Span,
) -> Result<(), ErrorGuaranteed> {
let opaque_ty_hir = tcx.hir().expect_item(opaque_type_key.def_id);
let is_ty_alias = match opaque_ty_hir.expect_opaque_ty().origin {
OpaqueTyOrigin::TyAlias { .. } => true,
OpaqueTyOrigin::AsyncFn(..) | OpaqueTyOrigin::FnReturn(..) => false,
};

let opaque_env = LazyOpaqueTyEnv::new(tcx, opaque_type_key.def_id);
let opaque_generics = tcx.generics_of(opaque_type_key.def_id);
let mut seen_params: FxIndexMap<_, Vec<_>> = FxIndexMap::default();
for (i, arg) in opaque_type_key.args.iter().enumerate() {
for (i, arg) in opaque_type_key.iter_captured_args(tcx) {
if let Err(guar) = arg.error_reported() {
return Err(guar);
}

let arg_is_param = match arg.unpack() {
GenericArgKind::Type(ty) => matches!(ty.kind(), ty::Param(_)),
GenericArgKind::Lifetime(lt) if is_ty_alias => {
GenericArgKind::Lifetime(lt) => {
matches!(*lt, ty::ReEarlyBound(_) | ty::ReFree(_))
|| (lt.is_static() && opaque_env.param_equal_static(i))
}
// FIXME(#113916): we can't currently check for unique lifetime params,
// see that issue for more. We will also have to ignore unused lifetime
// params for RPIT, but that's comparatively trivial ✨
GenericArgKind::Lifetime(_) => continue,
GenericArgKind::Const(ct) => matches!(ct.kind(), ty::ConstKind::Param(_)),
};

if arg_is_param {
seen_params.entry(arg).or_default().push(i);
let seen_where = seen_params.entry(arg).or_default();
if !seen_where.first().is_some_and(|&prev_i| opaque_env.params_equal(i, prev_i)) {
seen_where.push(i);
}
} else {
// Prevent `fn foo() -> Foo<u32>` from being defining.
let opaque_param = opaque_generics.param_at(i, tcx);
Expand Down Expand Up @@ -428,3 +398,81 @@ fn check_opaque_type_parameter_valid(

Ok(())
}

struct LazyOpaqueTyEnv<'tcx> {
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
canonical_args: std::cell::Cell<Option<ty::GenericArgsRef<'tcx>>>,
}

impl<'tcx> LazyOpaqueTyEnv<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> Self {
Self { tcx, def_id, canonical_args: std::cell::Cell::new(None) }
}

pub fn param_equal_static(&self, param_index: usize) -> bool {
self.get_canonical_args()[param_index].expect_region().is_static()
}

pub fn params_equal(&self, param1: usize, param2: usize) -> bool {
let canonical_args = self.get_canonical_args();
canonical_args[param1] == canonical_args[param2]
}

fn get_canonical_args(&self) -> ty::GenericArgsRef<'tcx> {
use rustc_hir as hir;
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt as _;

if let Some(canonical_args) = self.canonical_args.get() {
return canonical_args;
}

let &Self { tcx, def_id, .. } = self;
let origin = tcx.opaque_type_origin(def_id);
let defining_use_anchor = match origin {
hir::OpaqueTyOrigin::FnReturn(did) | hir::OpaqueTyOrigin::AsyncFn(did) => did,
hir::OpaqueTyOrigin::TyAlias { .. } => tcx.impl_trait_parent(def_id),
};
let param_env = tcx.param_env(defining_use_anchor);

let infcx = tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);

let args = match origin {
hir::OpaqueTyOrigin::FnReturn(parent) | hir::OpaqueTyOrigin::AsyncFn(parent) => {
GenericArgs::identity_for_item(tcx, parent).extend_to(
tcx,
def_id.to_def_id(),
|param, _| {
tcx.map_rpit_lifetime_to_fn_lifetime(param.def_id.expect_local()).into()
},
)
}
hir::OpaqueTyOrigin::TyAlias { .. } => GenericArgs::identity_for_item(tcx, def_id),
};

let wf_tys = ocx.assumed_wf_types(param_env, defining_use_anchor).unwrap_or_else(|_| {
tcx.sess.delay_span_bug(tcx.def_span(def_id), "error getting implied bounds");
Default::default()
});
let implied_bounds = infcx.implied_bounds_tys(param_env, defining_use_anchor, wf_tys);
let outlives_env = OutlivesEnvironment::with_bounds(param_env, implied_bounds);

let mut seen = vec![tcx.lifetimes.re_static];
let canonical_args = tcx.fold_regions(args, |r1, _| {
if let Some(&r2) = seen.iter().find(|&&r2| {
let free_regions = outlives_env.free_region_map();
free_regions.sub_free_regions(tcx, r1, r2)
&& free_regions.sub_free_regions(tcx, r2, r1)
}) {
r2
} else {
seen.push(r1);
r1
}
});
self.canonical_args.set(Some(canonical_args));
canonical_args
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ impl UniversalRegionRelations<'_> {
self.outlives.contains(fr1, fr2)
}

/// Returns `true` if fr1 is known to equal fr2.
///
/// This will only ever be true for universally quantified regions.
pub(crate) fn equal(&self, fr1: RegionVid, fr2: RegionVid) -> bool {
self.outlives.contains(fr1, fr2) && self.outlives.contains(fr2, fr1)
}

/// Returns a vector of free regions `x` such that `fr1: x` is
/// known to hold.
pub(crate) fn regions_outlived_by(&self, fr1: RegionVid) -> Vec<RegionVid> {
Expand Down
16 changes: 16 additions & 0 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,22 @@ pub(crate) fn type_check<'mir, 'tcx>(
hidden_type.ty = Ty::new_error(infcx.tcx, reported);
}

// Convert all regions to nll vars.
let (opaque_type_key, hidden_type) =
infcx.tcx.fold_regions((opaque_type_key, hidden_type), |region, _| {
match region.kind() {
ty::ReVar(_) => region,
ty::RePlaceholder(placeholder) => checker
.borrowck_context
.constraints
.placeholder_region(infcx, placeholder),
_ => ty::Region::new_var(
infcx.tcx,
checker.borrowck_context.universal_regions.to_region_vid(region),
),
}
});

(opaque_type_key, hidden_type)
})
.collect();
Expand Down
32 changes: 32 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,38 @@ pub struct OpaqueTypeKey<'tcx> {
pub args: GenericArgsRef<'tcx>,
}

impl<'tcx> OpaqueTypeKey<'tcx> {
pub fn fold_captured_lifetime_args(
self,
tcx: TyCtxt<'tcx>,
mut f: impl FnMut(Region<'tcx>) -> Region<'tcx>,
) -> Self {
let Self { def_id, args } = self;
let args = std::iter::zip(args, tcx.variances_of(def_id)).map(|(arg, v)| {
match (arg.unpack(), v) {
(ty::GenericArgKind::Lifetime(_), ty::Bivariant) => arg,
(ty::GenericArgKind::Lifetime(lt), _) => f(lt).into(),
_ => arg,
}
});
let args = tcx.mk_args_from_iter(args);
Self { def_id, args }
}

pub fn iter_captured_args(
self,
tcx: TyCtxt<'tcx>,
) -> impl Iterator<Item = (usize, GenericArg<'tcx>)> {
std::iter::zip(self.args, tcx.variances_of(self.def_id)).enumerate().filter_map(
|(i, (arg, v))| match (arg.unpack(), v) {
(_, ty::Invariant) => Some((i, arg)),
(ty::GenericArgKind::Lifetime(_), ty::Bivariant) => None,
_ => bug!("unexpected opaque type arg variance"),
},
)
}
}

#[derive(Copy, Clone, Debug, TypeFoldable, TypeVisitable, HashStable, TyEncodable, TyDecodable)]
pub struct OpaqueHiddenType<'tcx> {
/// The span of this particular definition of the opaque type. So
Expand Down
Loading

0 comments on commit 66abdf7

Please sign in to comment.