Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
Clean up code and add comments.
Use InlineConstant to wrap range patterns.
  • Loading branch information
matthewjasper committed Oct 13, 2023
1 parent 98b4c1e commit c28d195
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 62 deletions.
15 changes: 13 additions & 2 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,19 @@ pub enum PatKind<'tcx> {
value: mir::Const<'tcx>,
},

/// Inline constant found while lowering a pattern.
InlineConstant {
value: mir::UnevaluatedConst<'tcx>,
/// [LocalDefId] of the constant, we need this so that we have a
/// reference that can be used by unsafety checking to visit nested
/// unevaluated constants.
def: LocalDefId,
/// If the inline constant is used in a range pattern, this subpattern
/// represents the range (if both ends are inline constants, there will
/// be multiple InlineConstant wrappers).
///
/// Otherwise, the actual pattern that the constant lowered to. As with
/// other constants, inline constants are matched structurally where
/// possible.
subpattern: Box<Pat<'tcx>>,
},

Expand Down Expand Up @@ -910,7 +921,7 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
write!(f, "{subpattern}")
}
PatKind::Constant { value } => write!(f, "{value}"),
PatKind::InlineConstant { value: _, ref subpattern } => {
PatKind::InlineConstant { def: _, ref subpattern } => {
write!(f, "{} (from inline const)", subpattern)
}
PatKind::Range(box PatRange { lo, hi, end }) => {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/thir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ pub fn walk_pat<'a, 'tcx: 'a, V: Visitor<'a, 'tcx>>(visitor: &mut V, pat: &Pat<'
}
}
Constant { value: _ } => {}
InlineConstant { value: _, subpattern } => visitor.visit_pat(subpattern),
InlineConstant { def: _, subpattern } => visitor.visit_pat(subpattern),
Range(_) => {}
Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => {
for subpattern in prefix.iter() {
Expand Down
18 changes: 4 additions & 14 deletions compiler/rustc_mir_build/src/build/matches/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
Err(match_pair)
}

PatKind::InlineConstant { subpattern: ref pattern, value: _ } => {
PatKind::InlineConstant { subpattern: ref pattern, def: _ } => {
candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self));

Ok(())
Expand Down Expand Up @@ -236,20 +236,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// pattern/_match.rs for another pertinent example of this pattern).
//
// Also, for performance, it's important to only do the second
// `try_eval_scalar_int` if necessary.
let lo = lo
.try_eval_scalar_int(self.tcx, self.param_env)
.unwrap()
.to_bits(sz)
.unwrap()
^ bias;
// `try_to_bits` if necessary.
let lo = lo.try_to_bits(sz).unwrap() ^ bias;
if lo <= min {
let hi = hi
.try_eval_scalar_int(self.tcx, self.param_env)
.unwrap()
.to_bits(sz)
.unwrap()
^ bias;
let hi = hi.try_to_bits(sz).unwrap() ^ bias;
if hi > max || hi == max && end == RangeEnd::Included {
// Irrefutable pattern match.
return Ok(());
Expand Down
1 change: 0 additions & 1 deletion compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
};

tcx.ensure().check_match(def);
// this must run before MIR dump, because
// "not all control paths return a value" is reported here.
//
Expand Down
27 changes: 7 additions & 20 deletions compiler/rustc_mir_build/src/check_unsafety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::errors::*;
use rustc_middle::thir::visit::{self, Visitor};

use rustc_hir as hir;
use rustc_middle::mir::{BorrowKind, Const};
use rustc_middle::mir::BorrowKind;
use rustc_middle::thir::*;
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
Expand Down Expand Up @@ -124,7 +124,8 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
/// Handle closures/generators/inline-consts, which is unsafecked with their parent body.
fn visit_inner_body(&mut self, def: LocalDefId) {
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
let _ = self.tcx.ensure_with_value().mir_built(def);
// Runs all other queries that depend on THIR.
self.tcx.ensure_with_value().mir_built(def);
let inner_thir = &inner_thir.steal();
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
Expand Down Expand Up @@ -278,23 +279,8 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
visit::walk_pat(self, pat);
self.inside_adt = old_inside_adt;
}
PatKind::Range(range) => {
if let Const::Unevaluated(c, _) = range.lo {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
if let Const::Unevaluated(c, _) = range.hi {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
}
PatKind::InlineConstant { value, .. } => {
let def_id = value.def.expect_local();
self.visit_inner_body(def_id);
PatKind::InlineConstant { def, .. } => {
self.visit_inner_body(*def);
}
_ => {
visit::walk_pat(self, pat);
Expand Down Expand Up @@ -804,7 +790,8 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
}

let Ok((thir, expr)) = tcx.thir_body(def) else { return };
let _ = tcx.ensure_with_value().mir_built(def);
// Runs all other queries that depend on THIR.
tcx.ensure_with_value().mir_built(def);
let thir = &thir.steal();
// If `thir` is empty, a type error occurred, skip this body.
if thir.exprs.is_empty() {
Expand Down
47 changes: 27 additions & 20 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use rustc_middle::ty::{
self, AdtDef, CanonicalUserTypeAnnotation, GenericArg, GenericArgsRef, Region, Ty, TyCtxt,
TypeVisitableExt, UserType,
};
use rustc_span::def_id::LocalDefId;
use rustc_span::{ErrorGuaranteed, Span, Symbol};
use rustc_target::abi::{FieldIdx, Integer};

Expand Down Expand Up @@ -88,19 +89,21 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
fn lower_pattern_range_endpoint(
&mut self,
expr: Option<&'tcx hir::Expr<'tcx>>,
) -> Result<(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>), ErrorGuaranteed> {
) -> Result<
(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>, Option<LocalDefId>),
ErrorGuaranteed,
> {
match expr {
None => Ok((None, None)),
None => Ok((None, None, None)),
Some(expr) => {
let (kind, ascr) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, value } => (
PatKind::Constant { value: Const::Unevaluated(value, subpattern.ty) },
None,
),
let (kind, ascr, inline_const) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, def } => {
(subpattern.kind, None, Some(def))
}
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
(kind, Some(ascription))
(kind, Some(ascription), None)
}
kind => (kind, None),
kind => (kind, None, None),
};
let value = if let PatKind::Constant { value } = kind {
value
Expand All @@ -110,7 +113,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
);
return Err(self.tcx.sess.delay_span_bug(expr.span, msg));
};
Ok((Some(value), ascr))
Ok((Some(value), ascr, inline_const))
}
}
}
Expand Down Expand Up @@ -181,8 +184,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
return Err(self.tcx.sess.delay_span_bug(span, msg));
}

let (lo, lo_ascr) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr) = self.lower_pattern_range_endpoint(hi_expr)?;
let (lo, lo_ascr, lo_inline) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr, hi_inline) = self.lower_pattern_range_endpoint(hi_expr)?;

let lo = lo.unwrap_or_else(|| {
// Unwrap is ok because the type is known to be numeric.
Expand Down Expand Up @@ -241,6 +244,12 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
};
}
}
for inline_const in [lo_inline, hi_inline] {
if let Some(def) = inline_const {
kind =
PatKind::InlineConstant { def, subpattern: Box::new(Pat { span, ty, kind }) };
}
}
Ok(kind)
}

Expand Down Expand Up @@ -606,11 +615,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
// const eval path below.
// FIXME: investigate the performance impact of removing this.
let lit_input = match expr.kind {
hir::ExprKind::Lit(ref lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, ref expr) => match expr.kind {
hir::ExprKind::Lit(ref lit) => {
Some(LitToConstInput { lit: &lit.node, ty, neg: true })
}
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, expr) => match expr.kind {
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: true }),
_ => None,
},
_ => None,
Expand Down Expand Up @@ -646,7 +653,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
span,
None,
);
PatKind::InlineConstant { subpattern, value: uneval }
PatKind::InlineConstant { subpattern, def: def_id }
} else {
// If that fails, convert it to an opaque constant pattern.
match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) {
Expand Down Expand Up @@ -828,8 +835,8 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> {
PatKind::Deref { subpattern: subpattern.fold_with(folder) }
}
PatKind::Constant { value } => PatKind::Constant { value },
PatKind::InlineConstant { value, subpattern: ref pattern } => {
PatKind::InlineConstant { value, subpattern: pattern.fold_with(folder) }
PatKind::InlineConstant { def, subpattern: ref pattern } => {
PatKind::InlineConstant { def, subpattern: pattern.fold_with(folder) }
}
PatKind::Range(ref range) => PatKind::Range(range.clone()),
PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice {
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_mir_build/src/thir/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
}
PatKind::Deref { subpattern } => {
print_indented!(self, "Deref { ", depth_lvl + 1);
print_indented!(self, "subpattern: ", depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
Expand All @@ -701,10 +701,10 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::InlineConstant { value, subpattern } => {
PatKind::InlineConstant { def, subpattern } => {
print_indented!(self, "InlineConstant {", depth_lvl + 1);
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "subpattern: ", depth_lvl + 2);
print_indented!(self, format!("def: {:?}", def), depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
Expand Down

0 comments on commit c28d195

Please sign in to comment.