Merge branch 'JonasAlaif-checks-at-compile'
Merge #2.
taylordotfish committed Apr 27, 2023
2 parents 0ba42e4 + afa1229 commit b2a7511
Showing 11 changed files with 188 additions and 103 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -13,3 +13,6 @@ categories = ["data-structures", "no-std"]

[features]
fallback = []

[dev-dependencies]
compiletest_rs = "0.10"
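
The new compiletest_rs dev-dependency is what drives the compile-fail tests implied by the branch name: misuse of `BITS` should now be rejected by rustc rather than caught at run time. A typical harness for such tests, loosely following the compiletest_rs README (the commit's actual test files are among the changed files not shown on this page, so the names below are illustrative):

    // tests/compiletest.rs: illustrative harness, not necessarily the commit's exact file.
    extern crate compiletest_rs as compiletest;

    use std::path::PathBuf;

    #[test]
    fn compile_fail() {
        let mut config = compiletest::Config::default();
        // Every file under tests/compile-fail/ must fail to compile.
        config.mode = "compile-fail".parse().expect("invalid mode");
        config.src_base = PathBuf::from("tests/compile-fail");
        // Make the crate under test and its dependencies visible to the test sources.
        config.link_deps();
        config.clean_rmeta();
        compiletest::run_tests(&config);
    }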
20 changes: 12 additions & 8 deletions src/fallback.rs
@@ -16,21 +16,25 @@
* limitations under the License.
*/

use super::mask;
use core::cmp::Ordering;
use core::hash::{Hash, Hasher};
use core::ptr::NonNull;

pub struct PtrImpl<T, const BITS: usize> {
pub struct PtrImpl<T, const BITS: u32> {
ptr: NonNull<T>,
tag: usize,
}

impl<T, const BITS: usize> PtrImpl<T, BITS> {
impl<T, const BITS: u32> PtrImpl<T, BITS> {
pub fn new(ptr: NonNull<T>, tag: usize) -> Self {
// Even though these checks are not necessary here, we want to ensure
// that if the fallback compiles then so would the default.
let _ = Self::T_ALIGNED_PO2;
let _ = Self::T_SIZE_GE_ALIGNMENT;
let _ = Self::ENOUGH_ALIGNMENT_BITS;
Self {
ptr,
tag: tag & mask(BITS),
tag: tag & Self::MASK,
}
}

@@ -39,7 +43,7 @@ impl<T, const BITS: usize> PtrImpl<T, BITS> {
}
}

impl<T, const BITS: usize> Clone for PtrImpl<T, BITS> {
impl<T, const BITS: u32> Clone for PtrImpl<T, BITS> {
fn clone(&self) -> Self {
Self {
ptr: self.ptr,
@@ -48,19 +52,19 @@ impl<T, const BITS: usize> Clone for PtrImpl<T, BITS> {
}
}

impl<T, const BITS: usize> PartialEq for PtrImpl<T, BITS> {
impl<T, const BITS: u32> PartialEq for PtrImpl<T, BITS> {
fn eq(&self, other: &Self) -> bool {
(self.ptr, self.tag) == (other.ptr, other.tag)
}
}

impl<T, const BITS: usize> Ord for PtrImpl<T, BITS> {
impl<T, const BITS: u32> Ord for PtrImpl<T, BITS> {
fn cmp(&self, other: &Self) -> Ordering {
(self.ptr, self.tag).cmp(&(other.ptr, other.tag))
}
}

impl<T, const BITS: usize> Hash for PtrImpl<T, BITS> {
impl<T, const BITS: u32> Hash for PtrImpl<T, BITS> {
fn hash<H: Hasher>(&self, state: &mut H) {
(self.ptr, self.tag).hash(state);
}
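
The otherwise redundant `let _ = Self::...` lines above are what make the fallback reject the same programs as the default implementation: an associated `const` containing an `assert!` is only evaluated when it is used, so naming it inside `new` turns the assertion into a post-monomorphization compile error for any offending `T`/`BITS` combination. A minimal, self-contained sketch of that pattern (the names here are illustrative, not taken from the crate):

    use core::mem;

    struct Checked<T>(T);

    impl<T> Checked<T> {
        // Evaluated lazily, once per concrete `T` that actually uses it.
        const NOT_ZERO_SIZED: () =
            assert!(mem::size_of::<T>() != 0, "zero-sized `T` is not supported");

        fn new(value: T) -> Self {
            // Naming the constant forces the assert to run at compile time.
            let _ = Self::NOT_ZERO_SIZED;
            Self(value)
        }
    }

    fn main() {
        let _ok = Checked::new(42_u32); // compiles
        // let _err = Checked::new(()); // fails to compile: the assert fires
    }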
115 changes: 62 additions & 53 deletions src/lib.rs
@@ -16,7 +16,7 @@
* limitations under the License.
*/

#![no_std]
#![cfg_attr(not(test), no_std)]
#![cfg_attr(has_unsafe_op_in_unsafe_fn, deny(unsafe_op_in_unsafe_fn))]
#![warn(clippy::pedantic)]
#![allow(clippy::default_trait_access)]
@@ -114,24 +114,7 @@ use core::hash::{Hash, Hasher};
use core::mem;
use core::ptr::NonNull;

/// Calculates 2 to the power of `bits`, and panics if the result wouldn't fit
/// in a `usize`. This is the alignment required to store `bits` tag bits.
fn alignment(bits: usize) -> usize {
use core::convert::TryFrom;
u32::try_from(bits)
.ok()
.and_then(|n| 1_usize.checked_shl(n))
.expect("`bits` is too large")
}

/// Returns the bitmask that should be applied to the tag to ensure that it is
/// smaller than the alignment ([`alignment(bits)`](alignment)). Since the
/// alignment is always a power of 2, this function simply subtracts 1 from
/// the alignment.
fn mask(bits: usize) -> usize {
alignment(bits) - 1
}

#[cfg(not(feature = "fallback"))]
mod messages;
#[cfg_attr(feature = "fallback", path = "fallback.rs")]
mod ptr;
@@ -140,11 +123,47 @@ use ptr::PtrImpl;
#[cfg(test)]
mod tests;

impl<T, const BITS: usize> Copy for PtrImpl<T, BITS> {}
impl<T, const BITS: u32> PtrImpl<T, BITS> {
/// Compile-time check of our assumption about the alignment of `T`. This
/// should always succeed.
const T_ALIGNED_PO2: () = assert!(
mem::align_of::<T>().is_power_of_two(),
"unexpected alignment of `T`"
);
/// Compile-time check of our assumption about the size vs alignment of
/// `T`. This should always succeed.
const T_SIZE_GE_ALIGNMENT: () = assert!(
mem::size_of::<T>() == 0
|| mem::size_of::<T>() >= mem::align_of::<T>(),
"unexpected `size_of` vs `align_of` for `T`"
);
/// Compile-time check that the requested `BITS` is small enough.
const ENOUGH_ALIGNMENT_BITS: () = assert!(
mem::align_of::<T>().trailing_zeros() >= BITS,
"alignment of `T` must be greater or equal to `BITS`"
);

impl<T, const BITS: usize> Eq for PtrImpl<T, BITS> {}
/// Calculates 2 to the power of `BITS`, and panics if the result wouldn't
/// fit in a `usize`. This is the alignment required to store `BITS`
/// tag bits.
const ALIGNMENT: usize = if let Some(align) = 1_usize.checked_shl(BITS) {
align
} else {
panic!("2 to the power of `BITS` does not fit in a `usize`")
};

impl<T, const BITS: usize> PartialOrd for PtrImpl<T, BITS> {
/// The bitmask that should be applied to the tag to ensure that it
/// is smaller than [`Self::ALIGNMENT`].
/// Since the alignment is always a power of 2, this simply
/// subtracts 1 from the alignment.
const MASK: usize = Self::ALIGNMENT - 1;
}

impl<T, const BITS: u32> Copy for PtrImpl<T, BITS> {}

impl<T, const BITS: u32> Eq for PtrImpl<T, BITS> {}

impl<T, const BITS: u32> PartialOrd for PtrImpl<T, BITS> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
@@ -163,20 +182,18 @@ impl<T, const BITS: usize> PartialOrd for PtrImpl<T, BITS> {
/// `BITS` specifies how many bits are used for the tag. The alignment of `T`
/// must be large enough to store this many bits; see [`Self::new`].
#[repr(transparent)]
pub struct TaggedPtr<T, const BITS: usize>(PtrImpl<T, BITS>);
pub struct TaggedPtr<T, const BITS: u32>(PtrImpl<T, BITS>);

impl<T, const BITS: usize> TaggedPtr<T, BITS> {
impl<T, const BITS: u32> TaggedPtr<T, BITS> {
/// Creates a new tagged pointer. Only the lower `BITS` bits of `tag` are
/// stored.
/// stored. A check is performed at compile time that the alignment of `T`
/// is not less than 2<sup>`BITS`</sup> (`1 << BITS`). This ensures that
/// all properly aligned pointers to `T` will be aligned enough to store
/// the specified number of bits of the tag.
///
/// # Panics
///
/// This function panics if the alignment of `T` is less than
/// 2<sup>`BITS`</sup> (`1 << BITS`). This ensures that all properly
/// aligned pointers to `T` will be aligned enough to store the specified
/// number of bits of the tag.
///
/// `ptr` should be “dereferencable” in the sense defined by
/// `ptr` should be “dereferenceable” in the sense defined by
/// [`core::ptr`](core::ptr#safety).[^1] If it is not, this function or
/// methods of [`TaggedPtr`] may panic.
///
@@ -186,16 +203,8 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
/// 2<sup>`BITS`</sup>, this function may panic.
///
/// [^1]: It is permissible for only the first 2<sup>`BITS`</sup> bytes of
/// `ptr` to be dereferencable.
/// `ptr` to be dereferenceable.
pub fn new(ptr: NonNull<T>, tag: usize) -> Self {
// This should always be true.
assert!(mem::align_of::<T>().is_power_of_two());
assert!(
mem::align_of::<T>().trailing_zeros() as usize >= BITS,
"alignment of `T` must be at least 2 to the power of `BITS`",
);
// This should always be true.
assert!(mem::size_of::<T>().max(1) >= mem::align_of::<T>());
Self(PtrImpl::new(ptr, tag))
}

@@ -206,7 +215,7 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
/// # Panics
///
/// If the pointer provided to [`Self::new`] wasn't
/// [“dereferencable”](core::ptr#safety), this method may panic.
/// [“dereferenceable”](core::ptr#safety), this method may panic.
pub fn get(self) -> (NonNull<T>, usize) {
self.0.get()
}
@@ -217,7 +226,7 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
/// # Panics
///
/// If the pointer provided to [`Self::new`] wasn't
/// [“dereferencable”](core::ptr#safety), this method may panic.
/// [“dereferenceable”](core::ptr#safety), this method may panic.
pub fn ptr(self) -> NonNull<T> {
self.get().0
}
@@ -229,7 +238,7 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
/// ```
/// # use {core::ptr::NonNull, tagged_pointer::TaggedPtr};
/// # trait Ext<T> { fn set_ptr(&mut self, ptr: NonNull<T>); }
/// # impl<T, const BITS: usize> Ext<T> for TaggedPtr<T, BITS> {
/// # impl<T, const BITS: u32> Ext<T> for TaggedPtr<T, BITS> {
/// # fn set_ptr(&mut self, ptr: NonNull<T>) {
/// *self = Self::new(ptr, self.tag());
/// # }}
@@ -248,7 +257,7 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
/// # Panics
///
/// If the pointer provided to [`Self::new`] wasn't
/// [“dereferencable”](core::ptr#safety), this method may panic.
/// [“dereferenceable”](core::ptr#safety), this method may panic.
pub fn tag(self) -> usize {
self.get().1
}
@@ -260,7 +269,7 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
/// ```
/// # use tagged_pointer::TaggedPtr;
/// # trait Ext { fn set_tag(&mut self, tag: usize); }
/// # impl<T, const BITS: usize> Ext for TaggedPtr<T, BITS> {
/// # impl<T, const BITS: u32> Ext for TaggedPtr<T, BITS> {
/// # fn set_tag(&mut self, tag: usize) {
/// *self = Self::new(self.ptr(), tag);
/// # }}
@@ -274,41 +283,41 @@ impl<T, const BITS: usize> TaggedPtr<T, BITS> {
}
}

impl<T, const BITS: usize> Clone for TaggedPtr<T, BITS> {
impl<T, const BITS: u32> Clone for TaggedPtr<T, BITS> {
fn clone(&self) -> Self {
Self(self.0)
}
}

impl<T, const BITS: usize> Copy for TaggedPtr<T, BITS> {}
impl<T, const BITS: u32> Copy for TaggedPtr<T, BITS> {}

impl<T, const BITS: usize> PartialEq for TaggedPtr<T, BITS> {
impl<T, const BITS: u32> PartialEq for TaggedPtr<T, BITS> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

impl<T, const BITS: usize> Eq for TaggedPtr<T, BITS> {}
impl<T, const BITS: u32> Eq for TaggedPtr<T, BITS> {}

impl<T, const BITS: usize> PartialOrd for TaggedPtr<T, BITS> {
impl<T, const BITS: u32> PartialOrd for TaggedPtr<T, BITS> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl<T, const BITS: usize> Ord for TaggedPtr<T, BITS> {
impl<T, const BITS: u32> Ord for TaggedPtr<T, BITS> {
fn cmp(&self, other: &Self) -> Ordering {
self.0.cmp(&other.0)
}
}

impl<T, const BITS: usize> Hash for TaggedPtr<T, BITS> {
impl<T, const BITS: u32> Hash for TaggedPtr<T, BITS> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}

impl<T, const BITS: usize> fmt::Debug for TaggedPtr<T, BITS> {
impl<T, const BITS: u32> fmt::Debug for TaggedPtr<T, BITS> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (ptr, tag) = self.get();
f.debug_struct("TaggedPtr")
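
With this change the free `alignment` and `mask` functions are gone: `ALIGNMENT` and `MASK` are associated constants evaluated once per `(T, BITS)` instantiation, and the three `assert!` constants replace the old runtime panics in `TaggedPtr::new`. A quick sanity check of the arithmetic (the concrete values below assume `BITS = 2` and a 4-byte-aligned `T` such as `u32` on common targets, purely as an example):

    fn main() {
        const BITS: u32 = 2;
        // What `ALIGNMENT` evaluates to: 2 to the power of BITS.
        const ALIGNMENT: usize = 1 << BITS; // 4
        const MASK: usize = ALIGNMENT - 1;  // 0b11

        // On most targets u32 is 4-byte aligned, so BITS = 2 passes the checks.
        assert_eq!(core::mem::align_of::<u32>(), ALIGNMENT);
        // Only the low BITS bits of a tag survive:
        assert_eq!(0b101 & MASK, 0b01);
    }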
2 changes: 1 addition & 1 deletion src/messages/align-offset-failed
@@ -3,7 +3,7 @@ reasons (most likely one of the first two):

* The pointer passed to `TaggedPtr::new` wasn't aligned enough.

* The pointer passed to `TaggedPtr::new` wasn't "dereferencable" in
* The pointer passed to `TaggedPtr::new` wasn't "dereferenceable" in
the sense defined by the documentation for `std::ptr`.

* The current implementation of `align_offset` sometimes or always
2 changes: 1 addition & 1 deletion src/messages/wrapped-to-null
@@ -1,2 +1,2 @@
`ptr` became null after adding `tag`. This shouldn't happen if `ptr` is
"dereferencable" in the sense defined by `std::ptr`.
"dereferenceable" in the sense defined by `std::ptr`.
32 changes: 18 additions & 14 deletions src/ptr.rs
@@ -16,25 +16,29 @@
* limitations under the License.
*/

use super::{alignment, mask, messages};
use super::messages;
use core::cmp::Ordering;
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::ptr::NonNull;

#[repr(transparent)]
pub struct PtrImpl<T, const BITS: usize>(NonNull<u8>, PhantomData<NonNull<T>>);
pub struct PtrImpl<T, const BITS: u32>(NonNull<u8>, PhantomData<NonNull<T>>);

impl<T, const BITS: usize> PtrImpl<T, BITS> {
impl<T, const BITS: u32> PtrImpl<T, BITS> {
pub fn new(ptr: NonNull<T>, tag: usize) -> Self {
// Compile-time checks.
let _ = Self::T_ALIGNED_PO2;
let _ = Self::T_SIZE_GE_ALIGNMENT;
let _ = Self::ENOUGH_ALIGNMENT_BITS;
let ptr = ptr.as_ptr().cast::<u8>();
// Keep only the lower `BITS` bits of the tag.
let tag = tag & mask(BITS);
let offset = ptr.align_offset(alignment(BITS));
let tag = tag & Self::MASK;
let offset = ptr.align_offset(Self::ALIGNMENT);
assert!(offset != usize::MAX, "{}", messages::ALIGN_OFFSET_FAILED);
// Check that none of the bits we're about to use are already set.
// We expect that `offset <= mask(BITS)` but do the `&` just in case.
assert!(offset & mask(BITS) == 0, "`ptr` is not aligned enough");
// We expect that `offset <= Self::MASK` but do the `&` just in case.
assert!(offset & Self::MASK == 0, "`ptr` is not aligned enough");
Self(
NonNull::new(ptr.wrapping_add(tag))
.expect(messages::WRAPPED_TO_NULL),
@@ -44,10 +48,10 @@ impl<T, const BITS: usize> PtrImpl<T, BITS> {

pub fn get(self) -> (NonNull<T>, usize) {
let ptr = self.0.as_ptr();
let offset = ptr.align_offset(alignment(BITS));
let offset = ptr.align_offset(Self::ALIGNMENT);
assert!(offset != usize::MAX, "{}", messages::ALIGN_OFFSET_FAILED);
// We expect that `offset <= mask(BITS)` but do the `&` just in case.
let tag = (alignment(BITS) - offset) & mask(BITS);
// We expect that `offset <= Self::MASK` but do the `&` just in case.
let tag = (Self::ALIGNMENT - offset) & Self::MASK;
let ptr = ptr.wrapping_sub(tag).cast::<T>();
debug_assert!(!ptr.is_null());
// SAFETY: `self.0` was created by adding `tag` to the `ptr` parameter
@@ -58,25 +62,25 @@ impl<T, const BITS: usize> PtrImpl<T, BITS> {
}
}

impl<T, const BITS: usize> Clone for PtrImpl<T, BITS> {
impl<T, const BITS: u32> Clone for PtrImpl<T, BITS> {
fn clone(&self) -> Self {
Self(self.0, self.1)
}
}

impl<T, const BITS: usize> PartialEq for PtrImpl<T, BITS> {
impl<T, const BITS: u32> PartialEq for PtrImpl<T, BITS> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

impl<T, const BITS: usize> Ord for PtrImpl<T, BITS> {
impl<T, const BITS: u32> Ord for PtrImpl<T, BITS> {
fn cmp(&self, other: &Self) -> Ordering {
self.0.cmp(&other.0)
}
}

impl<T, const BITS: usize> Hash for PtrImpl<T, BITS> {
impl<T, const BITS: u32> Hash for PtrImpl<T, BITS> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
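
The tagging scheme in `src/ptr.rs` itself is unchanged: the tag still lives in the pointer's low bits via `wrapping_add`, and `get` still recovers it from `align_offset`; the code simply reads its bounds from the new constants. A standalone walk-through of that round trip, emulating `align_offset` with plain integer arithmetic (concrete numbers assume `ALIGNMENT = 4`, i.e. `BITS = 2`):

    fn main() {
        let alignment = 4_usize;  // stands in for Self::ALIGNMENT with BITS = 2
        let mask = alignment - 1; // stands in for Self::MASK

        for tag in 0..alignment {
            let base = 0x1000_usize; // a suitably aligned address
            let tagged = base + tag; // what `ptr.wrapping_add(tag)` produces
            // What `tagged.align_offset(alignment)` would return for this address:
            let offset = (alignment - tagged % alignment) % alignment;
            // The recovery step used in `get`:
            let recovered = (alignment - offset) & mask;
            assert_eq!(recovered, tag);
            assert_eq!(tagged - recovered, base);
        }
    }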
