Skip to content

Commit

Permalink
Vec::resize for bytes should be a single memset
Browse files Browse the repository at this point in the history
  • Loading branch information
scottmcm committed Jul 4, 2024
1 parent c422581 commit f4bb10b
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 75 deletions.
78 changes: 25 additions & 53 deletions library/alloc/src/vec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2561,7 +2561,7 @@ impl<T: Clone, A: Allocator> Vec<T, A> {
let len = self.len();

if new_len > len {
self.extend_with(new_len - len, value)
self.extend_trusted(core::iter::repeat_n(value, new_len - len));
} else {
self.truncate(new_len);
}
Expand Down Expand Up @@ -2673,38 +2673,6 @@ impl<T, A: Allocator, const N: usize> Vec<[T; N], A> {
}
}

impl<T: Clone, A: Allocator> Vec<T, A> {
#[cfg(not(no_global_oom_handling))]
/// Extends the vector by `n` copies of `value`: `n - 1` clones followed by
/// `value` itself, which is moved into the final slot.
///
/// When `n == 0` nothing is written and `value` is simply dropped on return.
fn extend_with(&mut self, n: usize, value: T) {
// Make room for all `n` new elements before any raw write happens.
self.reserve(n);

// SAFETY: `reserve(n)` above provides capacity for `n` more elements, and
// the code below performs at most `n` writes starting at the current end.
unsafe {
let mut ptr = self.as_mut_ptr().add(self.len());
// Track the length through `SetLenOnDrop` instead of calling
// `self.set_len()` each iteration: this works around the compiler
// not always realizing that the stores through `ptr` and the length
// update don't alias.
let mut local_len = SetLenOnDrop::new(&mut self.len);

// Write all elements except the last one
for _ in 1..n {
ptr::write(ptr, value.clone());
ptr = ptr.add(1);
// Increment the length in every step in case clone() panics
local_len.increment_len(1);
}

if n > 0 {
// We can write the last element directly without cloning needlessly
ptr::write(ptr, value);
local_len.increment_len(1);
}

// len set by scope guard
}
}
}

impl<T: PartialEq, A: Allocator> Vec<T, A> {
/// Removes consecutive repeated elements in the vector according to the
/// [`PartialEq`] trait implementation.
Expand Down Expand Up @@ -3083,32 +3051,36 @@ impl<T, A: Allocator> Vec<T, A> {
#[cfg(not(no_global_oom_handling))]
fn extend_trusted(&mut self, iterator: impl iter::TrustedLen<Item = T>) {
let (low, high) = iterator.size_hint();
if let Some(additional) = high {
debug_assert_eq!(
low,
additional,
"TrustedLen iterator's size hint is not exact: {:?}",
(low, high)
);
self.reserve(additional);
unsafe {
let ptr = self.as_mut_ptr();
let mut local_len = SetLenOnDrop::new(&mut self.len);
iterator.for_each(move |element| {
ptr::write(ptr.add(local_len.current_len()), element);
// Since the loop executes user code which can panic we have to update
// the length every step to correctly drop what we've written.
// NB can't overflow since we would have had to alloc the address space
local_len.increment_len(1);
});
}
} else {
if high.is_none() {
// Per TrustedLen contract a `None` upper bound means that the iterator length
// truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
// Since the other branch already panics eagerly (via `reserve()`) we do the same here.
// This avoids additional codegen for a fallback code path which would eventually
// panic anyway.
panic!("capacity overflow");
};

debug_assert_eq!(
Some(low),
high,
"TrustedLen iterator's size hint is not exact: {:?}",
(low, high)
);
self.reserve(low);

// SAFETY: From TrustedLen we know exactly how many slots we'll need,
// and we just reserved them. Thus we can write each element as we generate
// it into its final location without needing any further safety checks.
unsafe {
let ptr = self.as_mut_ptr();
let mut local_len = SetLenOnDrop::new(&mut self.len);
iterator.for_each(move |element| {
ptr::write(ptr.add(local_len.current_len()), element);
// Since the loop executes user code which can panic we have to update
// the length every step to correctly drop what we've written.
// NB can't overflow since we would have had to alloc the address space
local_len.increment_len_unchecked(1);
});
}
}

Expand Down
7 changes: 5 additions & 2 deletions library/alloc/src/vec/set_len_on_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ impl<'a> SetLenOnDrop<'a> {
SetLenOnDrop { local_len: *len, len }
}

/// # Safety
/// `self.current_len() + increment` must not overflow.
#[inline]
pub(super) fn increment_len(&mut self, increment: usize) {
self.local_len += increment;
pub(super) unsafe fn increment_len_unchecked(&mut self, increment: usize) {
// SAFETY: This is our precondition
self.local_len = unsafe { self.local_len.unchecked_add(increment) };
}

#[inline]
Expand Down
6 changes: 2 additions & 4 deletions library/alloc/src/vec/spec_from_elem.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use core::ptr;

use crate::alloc::Allocator;
use crate::raw_vec::RawVec;

Expand All @@ -13,7 +11,7 @@ pub(super) trait SpecFromElem: Sized {
impl<T: Clone> SpecFromElem for T {
default fn from_elem<A: Allocator>(elem: Self, n: usize, alloc: A) -> Vec<Self, A> {
let mut v = Vec::with_capacity_in(n, alloc);
v.extend_with(n, elem);
v.extend_trusted(core::iter::repeat_n(elem, n));
v
}
}
Expand All @@ -25,7 +23,7 @@ impl<T: Clone + IsZero> SpecFromElem for T {
return Vec { buf: RawVec::with_capacity_zeroed_in(n, alloc), len: n };
}
let mut v = Vec::with_capacity_in(n, alloc);
v.extend_with(n, elem);
v.extend_trusted(core::iter::repeat_n(elem, n));
v
}
}
Expand Down
59 changes: 43 additions & 16 deletions library/core/src/iter/sources/repeat_n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,52 @@ impl<A> Drop for RepeatN<A> {
}
}

/// Specialization hook used by `RepeatN::next` so that the way an item is
/// produced can differ between element types that are only `Clone` and those
/// that are also `Copy`.
trait SpecRepeatN<A> {
/// Reads an item after `self.count` has been decreased.
///
/// # Safety
///
/// Must be called only once after lowering a count.
///
/// Will cause double-frees if used multiple times or without checking
/// that the iterator was originally non-empty beforehand.
unsafe fn spec_read_unchecked(&mut self) -> A;
}

// Default implementation for any `Clone` element: hand out clones while more
// items remain, and move the stored element out for the final item.
impl<A: Clone> SpecRepeatN<A> for RepeatN<A> {
default unsafe fn spec_read_unchecked(&mut self) -> A {
if self.count == 0 {
// SAFETY: the count was just lowered to zero (this method's
// precondition), so the element won't be dropped later, and thus
// it's okay to take it here.
unsafe { ManuallyDrop::take(&mut self.element) }
} else {
// More items are still owed, so keep the original and return a clone.
A::clone(&self.element)
}
}
}

// Specialized implementation for `Copy` elements, overriding the `default`
// method of the `Clone` implementation.
impl<A: Copy> SpecRepeatN<A> for RepeatN<A> {
unsafe fn spec_read_unchecked(&mut self) -> A {
// For `Copy` types, we can always just read the item directly,
// so skip having a branch that would need to be optimized out.
*self.element
}
}

#[unstable(feature = "iter_repeat_n", issue = "104434")]
impl<A: Clone> Iterator for RepeatN<A> {
type Item = A;

#[inline]
fn next(&mut self) -> Option<A> {
if self.count == 0 {
return None;
}

self.count -= 1;
Some(if self.count == 0 {
// SAFETY: the check above ensured that the count used to be non-zero,
// so element hasn't been dropped yet, and we just lowered the count to
// zero so it won't be dropped later, and thus it's okay to take it here.
unsafe { ManuallyDrop::take(&mut self.element) }
// Using checked_sub as a safe way to get unchecked_sub
if let Some(new_count) = self.count.checked_sub(1) {
self.count = new_count;
// SAFETY: Just decreased the count.
unsafe { Some(self.spec_read_unchecked()) }
} else {
A::clone(&self.element)
})
None
}
}

#[inline]
Expand All @@ -143,12 +170,12 @@ impl<A: Clone> Iterator for RepeatN<A> {
self.take_element();
}

if skip > len {
if let Some(new_count) = len.checked_sub(skip) {
self.count = new_count;
Ok(())
} else {
// SAFETY: we just checked that the difference is positive
Err(unsafe { NonZero::new_unchecked(skip - len) })
} else {
self.count = len - skip;
Ok(())
}
}

Expand Down
23 changes: 23 additions & 0 deletions tests/codegen/vec-of-bytes-memset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//@ compile-flags: -O
//@ only-64bit

#![crate_type = "lib"]

// CHECK-LABEL: @resize_bytes_is_one_memset
// Growing a `Vec<u8>` by a constant amount via `resize` is expected to lower
// to one `llvm.memset` that writes all 456789 new bytes with the value 123.
#[no_mangle]
pub fn resize_bytes_is_one_memset(x: &mut Vec<u8>) {
// CHECK: call void @llvm.memset.p0.i64({{.+}}, i8 123, i64 456789, i1 false)
let new_len = x.len() + 456789;
x.resize(new_len, 123);
}

// Wrapper around a single `i8`, used by the test below so that a
// non-primitive, byte-sized element type is covered as well.
#[derive(Copy, Clone)]
struct ByteNewtype(i8);

// CHECK-LABEL: @from_elem_is_one_memset
// Building a vector with `vec![elem; n]` from a byte-sized element is expected
// to allocate 123456 bytes (alignment 1) and then fill that same allocation
// with one `llvm.memset` of the byte value 42.
#[no_mangle]
pub fn from_elem_is_one_memset() -> Vec<ByteNewtype> {
// CHECK: %[[P:.+]] = tail call{{.+}}@__rust_alloc(i64 noundef 123456, i64 noundef 1)
// CHECK: call void @llvm.memset.p0.i64({{.+}} %[[P]], i8 42, i64 123456, i1 false)
vec![ByteNewtype(42); 123456]
}

0 comments on commit f4bb10b

Please sign in to comment.