From c449867236d7023a46b80e41608b5bde8ece0cb0 Mon Sep 17 00:00:00 2001 From: Jay Huh Date: Fri, 22 Mar 2024 14:51:16 -0700 Subject: [PATCH] MultiCfIterator Impl Follow up (#12465) Summary: As a follow up for https://github.com/facebook/rocksdb/issues/12422 , this PR includes the following two changes. - Removal of `direction_` in the MultiCfIterator - Use of Member Func Template instead of `std::function` Pull Request resolved: https://github.com/facebook/rocksdb/pull/12465 Test Plan: ``` ./multi_cf_iterator_test ``` Reviewed By: pdillinger, ltamasi Differential Revision: D55208448 Pulled By: jaykorean fbshipit-source-id: 8b3167c1d59839d076afc29097b5ad21a453460a --- db/multi_cf_iterator.cc | 55 ++++++++++++++++++++--------------------- db/multi_cf_iterator.h | 48 ++++++++++------------------------- 2 files changed, 40 insertions(+), 63 deletions(-) diff --git a/db/multi_cf_iterator.cc b/db/multi_cf_iterator.cc index 4398edca72c..80e4171d54d 100644 --- a/db/multi_cf_iterator.cc +++ b/db/multi_cf_iterator.cc @@ -9,11 +9,10 @@ namespace ROCKSDB_NAMESPACE { -void MultiCfIterator::SeekCommon( - const std::function& child_seek_func, - Direction direction) { - direction_ = direction; - Reset(); +template +void MultiCfIterator::SeekCommon(BinaryHeap& heap, + ChildSeekFuncType child_seek_func) { + heap.clear(); int i = 0; for (auto& cfh_iter_pair : cfh_iter_pairs_) { auto& cfh = cfh_iter_pair.first; @@ -21,13 +20,7 @@ void MultiCfIterator::SeekCommon( child_seek_func(iter.get()); if (iter->Valid()) { assert(iter->status().ok()); - if (direction_ == kReverse) { - auto& max_heap = std::get(heap_); - max_heap.push(MultiCfIteratorInfo{iter.get(), cfh, i}); - } else { - auto& min_heap = std::get(heap_); - min_heap.push(MultiCfIteratorInfo{iter.get(), cfh, i}); - } + heap.push(MultiCfIteratorInfo{iter.get(), cfh, i}); } else { considerStatus(iter->status()); } @@ -35,9 +28,9 @@ void MultiCfIterator::SeekCommon( } } -template -void MultiCfIterator::AdvanceIterator( - BinaryHeap& heap, const std::function& advance_func) { +template +void MultiCfIterator::AdvanceIterator(BinaryHeap& heap, + AdvanceFuncType advance_func) { // 1. Keep the top iterator (by popping it from the heap) // 2. Make sure all others have iterated past the top iterator key slice // 3. Advance the top iterator, and add it back to the heap if valid @@ -70,33 +63,39 @@ void MultiCfIterator::AdvanceIterator( } void MultiCfIterator::SeekToFirst() { - SeekCommon([](Iterator* iter) { iter->SeekToFirst(); }, kForward); + auto& min_heap = GetHeap([this]() { InitMinHeap(); }); + SeekCommon(min_heap, [](Iterator* iter) { iter->SeekToFirst(); }); } void MultiCfIterator::Seek(const Slice& target) { - SeekCommon([&target](Iterator* iter) { iter->Seek(target); }, kForward); + auto& min_heap = GetHeap([this]() { InitMinHeap(); }); + SeekCommon(min_heap, [&target](Iterator* iter) { iter->Seek(target); }); } void MultiCfIterator::SeekToLast() { - SeekCommon([](Iterator* iter) { iter->SeekToLast(); }, kReverse); + auto& max_heap = GetHeap([this]() { InitMaxHeap(); }); + SeekCommon(max_heap, [](Iterator* iter) { iter->SeekToLast(); }); } void MultiCfIterator::SeekForPrev(const Slice& target) { - SeekCommon([&target](Iterator* iter) { iter->SeekForPrev(target); }, - kReverse); + auto& max_heap = GetHeap([this]() { InitMaxHeap(); }); + SeekCommon(max_heap, + [&target](Iterator* iter) { iter->SeekForPrev(target); }); } void MultiCfIterator::Next() { assert(Valid()); - if (direction_ != kForward) { - SwitchToDirection(kForward); - } - auto& min_heap = std::get(heap_); + auto& min_heap = GetHeap([this]() { + Slice target = key(); + InitMinHeap(); + Seek(target); + }); AdvanceIterator(min_heap, [](Iterator* iter) { iter->Next(); }); } void MultiCfIterator::Prev() { assert(Valid()); - if (direction_ != kReverse) { - SwitchToDirection(kReverse); - } - auto& max_heap = std::get(heap_); + auto& max_heap = GetHeap([this]() { + Slice target = key(); + InitMaxHeap(); + SeekForPrev(target); + }); AdvanceIterator(max_heap, [](Iterator* iter) { iter->Prev(); }); } diff --git a/db/multi_cf_iterator.h b/db/multi_cf_iterator.h index 4269422b3c1..cdd09c16df0 100644 --- a/db/multi_cf_iterator.h +++ b/db/multi_cf_iterator.h @@ -86,13 +86,10 @@ class MultiCfIterator : public Iterator { MultiCfIterHeap heap_; - enum Direction : uint8_t { kForward, kReverse }; - Direction direction_ = kForward; - // TODO: Lower and Upper bounds Iterator* current() const { - if (direction_ == kReverse) { + if (std::holds_alternative(heap_)) { auto& max_heap = std::get(heap_); return max_heap.top().iterator; } @@ -114,7 +111,7 @@ class MultiCfIterator : public Iterator { } bool Valid() const override { - if (direction_ == kReverse) { + if (std::holds_alternative(heap_)) { auto& max_heap = std::get(heap_); return !max_heap.empty() && status_.ok(); } @@ -128,21 +125,13 @@ class MultiCfIterator : public Iterator { status_ = std::move(s); } } - void Reset() { - std::visit(overload{[&](MultiCfMinHeap& min_heap) -> void { - min_heap.clear(); - if (direction_ == kReverse) { - InitMaxHeap(); - } - }, - [&](MultiCfMaxHeap& max_heap) -> void { - max_heap.clear(); - if (direction_ == kForward) { - InitMinHeap(); - } - }}, - heap_); - status_ = Status::OK(); + + template + HeapType& GetHeap(InitFunc initFunc) { + if (!std::holds_alternative(heap_)) { + initFunc(); + } + return std::get(heap_); } void InitMinHeap() { @@ -154,21 +143,10 @@ class MultiCfIterator : public Iterator { MultiCfHeapItemComparator>(comparator_)); } - void SwitchToDirection(Direction new_direction) { - assert(direction_ != new_direction); - Slice target = key(); - if (new_direction == kForward) { - Seek(target); - } else { - SeekForPrev(target); - } - } - - void SeekCommon(const std::function& child_seek_func, - Direction direction); - template - void AdvanceIterator(BinaryHeap& heap, - const std::function& advance_func); + template + void SeekCommon(BinaryHeap& heap, ChildSeekFuncType child_seek_func); + template + void AdvanceIterator(BinaryHeap& heap, AdvanceFuncType advance_func); void SeekToFirst() override; void SeekToLast() override;