diff --git a/library/std/src/thread/mod.rs b/library/std/src/thread/mod.rs index bcf2ec06022d9..fdf0e9faba48e 100644 --- a/library/std/src/thread/mod.rs +++ b/library/std/src/thread/mod.rs @@ -180,6 +180,12 @@ use crate::time::Duration; #[macro_use] mod local; +#[unstable(feature = "scoped_threads", issue = "93203")] +mod scoped; + +#[unstable(feature = "scoped_threads", issue = "93203")] +pub use scoped::{scope, Scope, ScopedJoinHandle}; + #[stable(feature = "rust1", since = "1.0.0")] pub use self::local::{AccessError, LocalKey}; @@ -446,6 +452,20 @@ impl Builder { F: FnOnce() -> T, F: Send + 'a, T: Send + 'a, + { + Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?)) + } + + unsafe fn spawn_unchecked_<'a, 'scope, F, T>( + self, + f: F, + scope_data: Option<&'scope scoped::ScopeData>, + ) -> io::Result> + where + F: FnOnce() -> T, + F: Send + 'a, + T: Send + 'a, + 'scope: 'a, { let Builder { name, stack_size } = self; @@ -456,7 +476,8 @@ impl Builder { })); let their_thread = my_thread.clone(); - let my_packet: Arc>>> = Arc::new(UnsafeCell::new(None)); + let my_packet: Arc> = + Arc::new(Packet { scope: scope_data, result: UnsafeCell::new(None) }); let their_packet = my_packet.clone(); let output_capture = crate::io::set_output_capture(None); @@ -480,10 +501,14 @@ impl Builder { // closure (it is an Arc<...>) and `my_packet` will be stored in the // same `JoinInner` as this closure meaning the mutation will be // safe (not modify it and affect a value far away). - unsafe { *their_packet.get() = Some(try_result) }; + unsafe { *their_packet.result.get() = Some(try_result) }; }; - Ok(JoinHandle(JoinInner { + if let Some(scope_data) = scope_data { + scope_data.increment_num_running_threads(); + } + + Ok(JoinInner { // SAFETY: // // `imp::Thread::new` takes a closure with a `'static` lifetime, since it's passed @@ -498,16 +523,16 @@ impl Builder { // exist after the thread has terminated, which is signaled by `Thread::join` // returning. native: unsafe { - Some(imp::Thread::new( + imp::Thread::new( stack_size, mem::transmute::, Box>( Box::new(main), ), - )?) + )? }, thread: my_thread, - packet: Packet(my_packet), - })) + packet: my_packet, + }) } } @@ -1239,34 +1264,48 @@ impl fmt::Debug for Thread { #[stable(feature = "rust1", since = "1.0.0")] pub type Result = crate::result::Result>; -// This packet is used to communicate the return value between the spawned thread -// and the rest of the program. Memory is shared through the `Arc` within and there's -// no need for a mutex here because synchronization happens with `join()` (the -// caller will never read this packet until the thread has exited). +// This packet is used to communicate the return value between the spawned +// thread and the rest of the program. It is shared through an `Arc` and +// there's no need for a mutex here because synchronization happens with `join()` +// (the caller will never read this packet until the thread has exited). // -// This packet itself is then stored into a `JoinInner` which in turns is placed -// in `JoinHandle` and `JoinGuard`. Due to the usage of `UnsafeCell` we need to -// manually worry about impls like Send and Sync. The type `T` should -// already always be Send (otherwise the thread could not have been created) and -// this type is inherently Sync because no methods take &self. Regardless, -// however, we add inheriting impls for Send/Sync to this type to ensure it's -// Send/Sync and that future modifications will still appropriately classify it. -struct Packet(Arc>>>); - -unsafe impl Send for Packet {} -unsafe impl Sync for Packet {} +// An Arc to the packet is stored into a `JoinInner` which in turns is placed +// in `JoinHandle`. +struct Packet<'scope, T> { + scope: Option<&'scope scoped::ScopeData>, + result: UnsafeCell>>, +} + +// Due to the usage of `UnsafeCell` we need to manually implement Sync. +// The type `T` should already always be Send (otherwise the thread could not +// have been created) and the Packet is Sync because all access to the +// `UnsafeCell` synchronized (by the `join()` boundary), and `ScopeData` is Sync. +unsafe impl<'scope, T: Sync> Sync for Packet<'scope, T> {} + +impl<'scope, T> Drop for Packet<'scope, T> { + fn drop(&mut self) { + // Book-keeping so the scope knows when it's done. + if let Some(scope) = self.scope { + // If this packet was for a thread that ran in a scope, the thread + // panicked, and nobody consumed the panic payload, we make sure + // the scope function will panic. + let unhandled_panic = matches!(self.result.get_mut(), Some(Err(_))); + scope.decrement_num_running_threads(unhandled_panic); + } + } +} /// Inner representation for JoinHandle -struct JoinInner { - native: Option, +struct JoinInner<'scope, T> { + native: imp::Thread, thread: Thread, - packet: Packet, + packet: Arc>, } -impl JoinInner { - fn join(&mut self) -> Result { - self.native.take().unwrap().join(); - unsafe { (*self.packet.0.get()).take().unwrap() } +impl<'scope, T> JoinInner<'scope, T> { + fn join(mut self) -> Result { + self.native.join(); + Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap() } } @@ -1333,7 +1372,7 @@ impl JoinInner { /// [`thread::Builder::spawn`]: Builder::spawn /// [`thread::spawn`]: spawn #[stable(feature = "rust1", since = "1.0.0")] -pub struct JoinHandle(JoinInner); +pub struct JoinHandle(JoinInner<'static, T>); #[stable(feature = "joinhandle_impl_send_sync", since = "1.29.0")] unsafe impl Send for JoinHandle {} @@ -1397,29 +1436,29 @@ impl JoinHandle { /// join_handle.join().expect("Couldn't join on the associated thread"); /// ``` #[stable(feature = "rust1", since = "1.0.0")] - pub fn join(mut self) -> Result { + pub fn join(self) -> Result { self.0.join() } - /// Checks if the the associated thread is still running its main function. + /// Checks if the associated thread is still running its main function. /// /// This might return `false` for a brief moment after the thread's main /// function has returned, but before the thread itself has stopped running. #[unstable(feature = "thread_is_running", issue = "90470")] pub fn is_running(&self) -> bool { - Arc::strong_count(&self.0.packet.0) > 1 + Arc::strong_count(&self.0.packet) > 1 } } impl AsInner for JoinHandle { fn as_inner(&self) -> &imp::Thread { - self.0.native.as_ref().unwrap() + &self.0.native } } impl IntoInner for JoinHandle { fn into_inner(self) -> imp::Thread { - self.0.native.unwrap() + self.0.native } } diff --git a/library/std/src/thread/scoped.rs b/library/std/src/thread/scoped.rs new file mode 100644 index 0000000000000..9dd7c15fc5922 --- /dev/null +++ b/library/std/src/thread/scoped.rs @@ -0,0 +1,316 @@ +use super::{current, park, Builder, JoinInner, Result, Thread}; +use crate::fmt; +use crate::io; +use crate::marker::PhantomData; +use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; +use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use crate::sync::Arc; + +/// A scope to spawn scoped threads in. +/// +/// See [`scope`] for details. +pub struct Scope<'env> { + data: ScopeData, + /// Invariance over 'env, to make sure 'env cannot shrink, + /// which is necessary for soundness. + /// + /// Without invariance, this would compile fine but be unsound: + /// + /// ```compile_fail + /// #![feature(scoped_threads)] + /// + /// std::thread::scope(|s| { + /// s.spawn(|s| { + /// let a = String::from("abcd"); + /// s.spawn(|_| println!("{:?}", a)); // might run after `a` is dropped + /// }); + /// }); + /// ``` + env: PhantomData<&'env mut &'env ()>, +} + +/// An owned permission to join on a scoped thread (block on its termination). +/// +/// See [`Scope::spawn`] for details. +pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>); + +pub(super) struct ScopeData { + num_running_threads: AtomicUsize, + a_thread_panicked: AtomicBool, + main_thread: Thread, +} + +impl ScopeData { + pub(super) fn increment_num_running_threads(&self) { + // We check for 'overflow' with usize::MAX / 2, to make sure there's no + // chance it overflows to 0, which would result in unsoundness. + if self.num_running_threads.fetch_add(1, Ordering::Relaxed) > usize::MAX / 2 { + // This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles. + self.decrement_num_running_threads(false); + panic!("too many running threads in thread scope"); + } + } + pub(super) fn decrement_num_running_threads(&self, panic: bool) { + if panic { + self.a_thread_panicked.store(true, Ordering::Relaxed); + } + if self.num_running_threads.fetch_sub(1, Ordering::Release) == 1 { + self.main_thread.unpark(); + } + } +} + +/// Create a scope for spawning scoped threads. +/// +/// The function passed to `scope` will be provided a [`Scope`] object, +/// through which scoped threads can be [spawned][`Scope::spawn`]. +/// +/// Unlike non-scoped threads, scoped threads can borrow non-`'static` data, +/// as the scope guarantees all threads will be joined at the end of the scope. +/// +/// All threads spawned within the scope that haven't been manually joined +/// will be automatically joined before this function returns. +/// +/// # Panics +/// +/// If any of the automatically joined threads panicked, this function will panic. +/// +/// If you want to handle panics from spawned threads, +/// [`join`][ScopedJoinHandle::join] them before the end of the scope. +/// +/// # Example +/// +/// ``` +/// #![feature(scoped_threads)] +/// use std::thread; +/// +/// let mut a = vec![1, 2, 3]; +/// let mut x = 0; +/// +/// thread::scope(|s| { +/// s.spawn(|_| { +/// println!("hello from the first scoped thread"); +/// // We can borrow `a` here. +/// dbg!(&a); +/// }); +/// s.spawn(|_| { +/// println!("hello from the second scoped thread"); +/// // We can even mutably borrow `x` here, +/// // because no other threads are using it. +/// x += a[0] + a[2]; +/// }); +/// println!("hello from the main thread"); +/// }); +/// +/// // After the scope, we can modify and access our variables again: +/// a.push(4); +/// assert_eq!(x, a.len()); +/// ``` +#[track_caller] +pub fn scope<'env, F, T>(f: F) -> T +where + F: FnOnce(&Scope<'env>) -> T, +{ + let scope = Scope { + data: ScopeData { + num_running_threads: AtomicUsize::new(0), + main_thread: current(), + a_thread_panicked: AtomicBool::new(false), + }, + env: PhantomData, + }; + + // Run `f`, but catch panics so we can make sure to wait for all the threads to join. + let result = catch_unwind(AssertUnwindSafe(|| f(&scope))); + + // Wait until all the threads are finished. + while scope.data.num_running_threads.load(Ordering::Acquire) != 0 { + park(); + } + + // Throw any panic from `f`, or the return value of `f` if no thread panicked. + match result { + Err(e) => resume_unwind(e), + Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => { + panic!("a scoped thread panicked") + } + Ok(result) => result, + } +} + +impl<'env> Scope<'env> { + /// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it. + /// + /// Unlike non-scoped threads, threads spawned with this function may + /// borrow non-`'static` data from the outside the scope. See [`scope`] for + /// details. + /// + /// The join handle provides a [`join`] method that can be used to join the spawned + /// thread. If the spawned thread panics, [`join`] will return an [`Err`] containing + /// the panic payload. + /// + /// If the join handle is dropped, the spawned thread will implicitly joined at the + /// end of the scope. In that case, if the spawned thread panics, [`scope`] will + /// panic after all threads are joined. + /// + /// This call will create a thread using default parameters of [`Builder`]. + /// If you want to specify the stack size or the name of the thread, use + /// [`Builder::spawn_scoped`] instead. + /// + /// # Panics + /// + /// Panics if the OS fails to create a thread; use [`Builder::spawn_scoped`] + /// to recover from such errors. + /// + /// [`join`]: ScopedJoinHandle::join + pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T> + where + F: FnOnce(&Scope<'env>) -> T + Send + 'env, + T: Send + 'env, + { + Builder::new().spawn_scoped(self, f).expect("failed to spawn thread") + } +} + +impl Builder { + /// Spawns a new scoped thread using the settings set through this `Builder`. + /// + /// Unlike [`Scope::spawn`], this method yields an [`io::Result`] to + /// capture any failure to create the thread at the OS level. + /// + /// [`io::Result`]: crate::io::Result + /// + /// # Panics + /// + /// Panics if a thread name was set and it contained null bytes. + /// + /// # Example + /// + /// ``` + /// #![feature(scoped_threads)] + /// use std::thread; + /// + /// let mut a = vec![1, 2, 3]; + /// let mut x = 0; + /// + /// thread::scope(|s| { + /// thread::Builder::new() + /// .name("first".to_string()) + /// .spawn_scoped(s, |_| + /// { + /// println!("hello from the {:?} scoped thread", thread::current().name()); + /// // We can borrow `a` here. + /// dbg!(&a); + /// }) + /// .unwrap(); + /// thread::Builder::new() + /// .name("second".to_string()) + /// .spawn_scoped(s, |_| + /// { + /// println!("hello from the {:?} scoped thread", thread::current().name()); + /// // We can even mutably borrow `x` here, + /// // because no other threads are using it. + /// x += a[0] + a[2]; + /// }) + /// .unwrap(); + /// println!("hello from the main thread"); + /// }); + /// + /// // After the scope, we can modify and access our variables again: + /// a.push(4); + /// assert_eq!(x, a.len()); + /// ``` + pub fn spawn_scoped<'scope, 'env, F, T>( + self, + scope: &'scope Scope<'env>, + f: F, + ) -> io::Result> + where + F: FnOnce(&Scope<'env>) -> T + Send + 'env, + T: Send + 'env, + { + Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(|| f(scope), Some(&scope.data)) }?)) + } +} + +impl<'scope, T> ScopedJoinHandle<'scope, T> { + /// Extracts a handle to the underlying thread. + /// + /// # Examples + /// + /// ``` + /// #![feature(scoped_threads)] + /// #![feature(thread_is_running)] + /// + /// use std::thread; + /// + /// thread::scope(|s| { + /// let t = s.spawn(|_| { + /// println!("hello"); + /// }); + /// println!("thread id: {:?}", t.thread().id()); + /// }); + /// ``` + #[must_use] + pub fn thread(&self) -> &Thread { + &self.0.thread + } + + /// Waits for the associated thread to finish. + /// + /// This function will return immediately if the associated thread has already finished. + /// + /// In terms of [atomic memory orderings], the completion of the associated + /// thread synchronizes with this function returning. + /// In other words, all operations performed by that thread + /// [happen before](https://doc.rust-lang.org/nomicon/atomics.html#data-accesses) + /// all operations that happen after `join` returns. + /// + /// If the associated thread panics, [`Err`] is returned with the panic payload. + /// + /// [atomic memory orderings]: crate::sync::atomic + /// + /// # Examples + /// + /// ``` + /// #![feature(scoped_threads)] + /// #![feature(thread_is_running)] + /// + /// use std::thread; + /// + /// thread::scope(|s| { + /// let t = s.spawn(|_| { + /// panic!("oh no"); + /// }); + /// assert!(t.join().is_err()); + /// }); + /// ``` + pub fn join(self) -> Result { + self.0.join() + } + + /// Checks if the associated thread is still running its main function. + /// + /// This might return `false` for a brief moment after the thread's main + /// function has returned, but before the thread itself has stopped running. + #[unstable(feature = "thread_is_running", issue = "90470")] + pub fn is_running(&self) -> bool { + Arc::strong_count(&self.0.packet) > 1 + } +} + +impl<'env> fmt::Debug for Scope<'env> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scope") + .field("num_running_threads", &self.data.num_running_threads.load(Ordering::Relaxed)) + .field("a_thread_panicked", &self.data.a_thread_panicked.load(Ordering::Relaxed)) + .field("main_thread", &self.data.main_thread) + .finish_non_exhaustive() + } +} + +impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ScopedJoinHandle").finish_non_exhaustive() + } +}