Skip to content

Commit

Permalink
Auto merge of #117703 - compiler-errors:recursive-async, r=<try>
Browse files Browse the repository at this point in the history
Support async recursive calls (as long as they have indirection)

TL;DR: This code should work

```
async fn foo() {
  Box::pin(foo()).await;
}
```

r? `@ghost` while I write up a description, etc.
  • Loading branch information
bors committed Nov 8, 2023
2 parents 7cc997d + d0af0e3 commit 4b93a18
Show file tree
Hide file tree
Showing 37 changed files with 281 additions and 175 deletions.
7 changes: 5 additions & 2 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
// debuggers and debugger extensions expect it to be called `__awaitee`. They use
// this name to identify what is being awaited by a suspended async functions.
let awaitee_ident = Ident::with_dummy_span(sym::__awaitee);
let (awaitee_pat, awaitee_pat_hid) =
self.pat_ident_binding_mode(span, awaitee_ident, hir::BindingAnnotation::MUT);
let (awaitee_pat, awaitee_pat_hid) = self.pat_ident_binding_mode(
gen_future_span,
awaitee_ident,
hir::BindingAnnotation::MUT,
);

let task_context_ident = Ident::with_dummy_span(sym::_task_context);

Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,15 @@ pub enum CoroutineKind {
Coroutine,
}

impl CoroutineKind {
pub fn is_fn_like(self) -> bool {
matches!(
self,
CoroutineKind::Async(CoroutineSource::Fn) | CoroutineKind::Gen(CoroutineSource::Fn)
)
}
}

impl fmt::Display for CoroutineKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down
25 changes: 7 additions & 18 deletions compiler/rustc_hir_analysis/src/check/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,12 @@ fn check_opaque(tcx: TyCtxt<'_>, id: hir::ItemId) {
return;
}

let args = GenericArgs::identity_for_item(tcx, item.owner_id);
let span = tcx.def_span(item.owner_id.def_id);

if tcx.type_of(item.owner_id.def_id).instantiate_identity().references_error() {
return;
}
if check_opaque_for_cycles(tcx, item.owner_id.def_id, args, span, &origin).is_err() {
if check_opaque_for_cycles(tcx, item.owner_id.def_id, span).is_err() {
return;
}

Expand All @@ -233,16 +232,16 @@ fn check_opaque(tcx: TyCtxt<'_>, id: hir::ItemId) {
pub(super) fn check_opaque_for_cycles<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
args: GenericArgsRef<'tcx>,
span: Span,
origin: &hir::OpaqueTyOrigin,
) -> Result<(), ErrorGuaranteed> {
let args = GenericArgs::identity_for_item(tcx, def_id);
if tcx.try_expand_impl_trait_type(def_id.to_def_id(), args).is_err() {
let reported = match origin {
hir::OpaqueTyOrigin::AsyncFn(..) => async_opaque_type_cycle_error(tcx, span),
_ => opaque_type_cycle_error(tcx, def_id, span),
};
let reported = opaque_type_cycle_error(tcx, def_id, span);
Err(reported)
} else if let Err(&LayoutError::Cycle(guar)) =
tcx.layout_of(tcx.param_env(def_id).and(Ty::new_opaque(tcx, def_id.to_def_id(), args)))
{
Err(guar)
} else {
Ok(())
}
Expand Down Expand Up @@ -1324,16 +1323,6 @@ pub(super) fn check_mod_item_types(tcx: TyCtxt<'_>, module_def_id: LocalModDefId
}
}

fn async_opaque_type_cycle_error(tcx: TyCtxt<'_>, span: Span) -> ErrorGuaranteed {
struct_span_err!(tcx.sess, span, E0733, "recursion in an `async fn` requires boxing")
.span_label(span, "recursive `async fn`")
.note("a recursive `async fn` must be rewritten to return a boxed `dyn Future`")
.note(
"consider using the `async_recursion` crate: https://crates.io/crates/async_recursion",
)
.emit()
}

/// Emit an error for recursive opaque types.
///
/// If this is a return `impl Trait`, find the item's return expressions and point at them. For
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_middle/src/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait Key: Sized {
None
}

fn ty_adt_id(&self) -> Option<DefId> {
fn ty_def_id(&self) -> Option<DefId> {
None
}
}
Expand Down Expand Up @@ -406,9 +406,10 @@ impl<'tcx> Key for Ty<'tcx> {
DUMMY_SP
}

fn ty_adt_id(&self) -> Option<DefId> {
match self.kind() {
fn ty_def_id(&self) -> Option<DefId> {
match *self.kind() {
ty::Adt(adt, _) => Some(adt.did()),
ty::Coroutine(def_id, ..) => Some(def_id),
_ => None,
}
}
Expand Down Expand Up @@ -452,6 +453,10 @@ impl<'tcx, T: Key> Key for ty::ParamEnvAnd<'tcx, T> {
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.value.default_span(tcx)
}

fn ty_def_id(&self) -> Option<DefId> {
self.value.ty_def_id()
}
}

impl Key for Symbol {
Expand Down Expand Up @@ -550,7 +555,7 @@ impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
DUMMY_SP
}

fn ty_adt_id(&self) -> Option<DefId> {
fn ty_def_id(&self) -> Option<DefId> {
match self.1.value.kind() {
ty::Adt(adt, _) => Some(adt.did()),
_ => None,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,8 @@ rustc_queries! {
) -> Result<ty::layout::TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
depth_limit
desc { "computing layout of `{}`", key.value }
// we emit our own error during query cycle handling
cycle_delay_bug
}

/// Compute a `FnAbi` suitable for indirect calls, i.e. to `fn` pointers.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct DynamicQuery<'tcx, C: QueryCache> {
fn(tcx: TyCtxt<'tcx>, key: &C::Key, index: SerializedDepNodeIndex) -> bool,
pub hash_result: HashResult<C::Value>,
pub value_from_cycle_error:
fn(tcx: TyCtxt<'tcx>, cycle: &[QueryInfo], guar: ErrorGuaranteed) -> C::Value,
fn(tcx: TyCtxt<'tcx>, cycle_error: &CycleError, guar: ErrorGuaranteed) -> C::Value,
pub format_value: fn(&C::Value) -> String,
}

Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ pub enum LayoutError<'tcx> {
SizeOverflow(Ty<'tcx>),
NormalizationFailure(Ty<'tcx>, NormalizationError<'tcx>),
ReferencesError(ErrorGuaranteed),
Cycle,
Cycle(ErrorGuaranteed),
}

impl<'tcx> LayoutError<'tcx> {
Expand All @@ -226,7 +226,7 @@ impl<'tcx> LayoutError<'tcx> {
Unknown(_) => middle_unknown_layout,
SizeOverflow(_) => middle_values_too_big,
NormalizationFailure(_, _) => middle_cannot_be_normalized,
Cycle => middle_cycle,
Cycle(_) => middle_cycle,
ReferencesError(_) => middle_layout_references_error,
}
}
Expand All @@ -240,7 +240,7 @@ impl<'tcx> LayoutError<'tcx> {
NormalizationFailure(ty, e) => {
E::NormalizationFailure { ty, failure_ty: e.get_type_for_failure() }
}
Cycle => E::Cycle,
Cycle(_) => E::Cycle,
ReferencesError(_) => E::ReferencesError,
}
}
Expand All @@ -261,7 +261,7 @@ impl<'tcx> fmt::Display for LayoutError<'tcx> {
t,
e.get_type_for_failure()
),
LayoutError::Cycle => write!(f, "a cycle occurred during layout computation"),
LayoutError::Cycle(_) => write!(f, "a cycle occurred during layout computation"),
LayoutError::ReferencesError(_) => write!(f, "the type has an unknown layout"),
}
}
Expand Down Expand Up @@ -333,7 +333,7 @@ impl<'tcx> SizeSkeleton<'tcx> {
Err(err @ LayoutError::Unknown(_)) => err,
// We can't extract SizeSkeleton info from other layout errors
Err(
e @ LayoutError::Cycle
e @ LayoutError::Cycle(_)
| e @ LayoutError::SizeOverflow(_)
| e @ LayoutError::NormalizationFailure(..)
| e @ LayoutError::ReferencesError(_),
Expand Down
29 changes: 11 additions & 18 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,9 +747,13 @@ impl<'tcx> TyCtxt<'tcx> {
match def_kind {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method",
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
rustc_hir::CoroutineKind::Async(..) => "async closure",
rustc_hir::CoroutineKind::Coroutine => "coroutine",
rustc_hir::CoroutineKind::Gen(..) => "gen closure",
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "async fn",
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "async block",
hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "async closure",
hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "gen fn",
hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "gen block",
hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "gen closure",
hir::CoroutineKind::Coroutine => "coroutine",
},
_ => def_kind.descr(def_id),
}
Expand All @@ -765,9 +769,9 @@ impl<'tcx> TyCtxt<'tcx> {
match def_kind {
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "a",
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
rustc_hir::CoroutineKind::Async(..) => "an",
rustc_hir::CoroutineKind::Coroutine => "a",
rustc_hir::CoroutineKind::Gen(..) => "a",
hir::CoroutineKind::Async(..) => "an",
hir::CoroutineKind::Coroutine => "a",
hir::CoroutineKind::Gen(..) => "a",
},
_ => def_kind.article(),
}
Expand Down Expand Up @@ -850,18 +854,7 @@ impl<'tcx> OpaqueTypeExpander<'tcx> {
}
let args = args.fold_with(self);
if !self.check_recursion || self.seen_opaque_tys.insert(def_id) {
let expanded_ty = match self.expanded_cache.get(&(def_id, args)) {
Some(expanded_ty) => *expanded_ty,
None => {
for bty in self.tcx.coroutine_hidden_types(def_id) {
let hidden_ty = bty.instantiate(self.tcx, args);
self.fold_ty(hidden_ty);
}
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
self.expanded_cache.insert((def_id, args), expanded_ty);
expanded_ty
}
};
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
if self.check_recursion {
self.seen_opaque_tys.remove(&def_id);
}
Expand Down
Loading

0 comments on commit 4b93a18

Please sign in to comment.