Skip to content

Commit

Permalink
Auto merge of rust-lang#127007 - krtab:improv_binary_search, r=<try>
Browse files Browse the repository at this point in the history
Improve slice::binary_search_by

This PR aims to improve the performances of std::slice::binary_search.

**EDIT: The proposed implementation changed so the rest of this comment is outdated. See rust-lang#127007 (comment) for an up to date presentation of the PR.**

It reduces the total instruction count for the `u32` monomorphization, but maybe more remarkably, removes 2 of the 12 instructions of the main loop (on x86).

It changes `test_binary_search_implementation_details()` so may warrant a crater run.

I will document it much more if this is shown to be interesting on benchmarks. Could we start with a timer run first?

**Before the PR**

```asm
        mov     eax, 1
        test    rsi, rsi
        je      .LBB0_1
        mov     rcx, rdx
        mov     rdx, rsi
        mov     ecx, dword ptr [rcx]
        xor     esi, esi
        mov     r8, rdx
.LBB0_3:
        shr     rdx
        add     rdx, rsi
        mov     r9d, dword ptr [rdi + 4*rdx]
        cmp     r9d, ecx
        je      .LBB0_4
        lea     r10, [rdx + 1]
        cmp     r9d, ecx
        cmova   r8, rdx
        cmovb   rsi, r10
        mov     rdx, r8
        sub     rdx, rsi
        ja      .LBB0_3
        mov     rdx, rsi
        ret
.LBB0_1:
        xor     edx, edx
        ret
.LBB0_4:
        xor     eax, eax
        ret
```

**After the PR**

```asm
        mov     ecx, dword ptr [rdx]
        xor     eax, eax
        xor     edx, edx
.LBB1_1:
        cmp     rsi, 1
        jbe     .LBB1_2
        mov     r9, rsi
        shr     r9
        lea     r8, [r9 + rdx]
        sub     rsi, r9
        cmp     dword ptr [rdi + 4*r8], ecx
        cmovb   rdx, r8
        cmova   rsi, r9
        jne     .LBB1_1
        mov     rdx, r8
        ret
.LBB1_2:
        test    rsi, rsi
        je      .LBB1_3
        xor     eax, eax
        cmp     dword ptr [rdi + 4*rdx], ecx
        setne   al
        adc     rdx, 0
        ret
.LBB1_3:
        mov     eax, 1
        ret
```
  • Loading branch information
bors committed Jun 29, 2024
2 parents d38cd22 + 2f5eec9 commit 3dc1b1e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
28 changes: 19 additions & 9 deletions library/core/src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2787,15 +2787,28 @@ impl<T> [T] {
where
F: FnMut(&'a T) -> Ordering,
{
if T::IS_ZST {
let res = if self.len() == 0 {
Err(0)
} else {
match f(&self[0]) {
Less => Err(self.len()),
Equal => Ok(0),
Greater => Err(0),
}
};
return res;
}
// INVARIANTS:
// - 0 <= left <= left + size = right <= self.len()
// - 0 <= left <= right <= self.len()
// - f returns Less for everything in self[..left]
// - f returns Greater for everything in self[right..]
let mut size = self.len();
let mut right = self.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;
// This is an okay way to compute the mean because left and right are
// <= isize::MAX so the addition won't overflow
let mid = (left + right) / 2;

// SAFETY: the while condition means `size` is strictly positive, so
// `size/2 < size`. Thus `left + size/2 < left + size`, which
Expand All @@ -2807,19 +2820,16 @@ impl<T> [T] {
// fewer branches and instructions than if/else or matching on
// cmp::Ordering.
// This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.

left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
// SAFETY: same as the `get_unchecked` above
unsafe { hint::assert_unchecked(mid < self.len()) };
return Ok(mid);
}

size = right - left;
}

// SAFETY: directly true from the overall invariant.
// Note that this is `<=`, unlike the assume in the `Ok` path.
// SAFETY: yolo
unsafe { hint::assert_unchecked(left <= self.len()) };
Err(left)
}
Expand Down
4 changes: 2 additions & 2 deletions library/core/tests/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ fn test_binary_search() {
assert_eq!(b.binary_search(&8), Err(5));

let b = [(); usize::MAX];
assert_eq!(b.binary_search(&()), Ok(usize::MAX / 2));
assert_eq!(b.binary_search(&()), Ok(0));
}

#[test]
fn test_binary_search_by_overflow() {
let b = [(); usize::MAX];
assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX / 2));
assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(0));
assert_eq!(b.binary_search_by(|_| Ordering::Greater), Err(0));
assert_eq!(b.binary_search_by(|_| Ordering::Less), Err(usize::MAX));
}
Expand Down

0 comments on commit 3dc1b1e

Please sign in to comment.