Skip to content

Commit

Permalink
Add Chain to allow iterator range concatenation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-e-whitney committed Jun 28, 2022
1 parent 6f43e8f commit 15d5710
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 0 deletions.
87 changes: 87 additions & 0 deletions include/utilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

namespace mgard {

Expand Down Expand Up @@ -545,6 +546,92 @@ class Bits::iterator {
unsigned char offset;
};

//! Concatenated iterator ranges.
//!
//! Approximate Python's `itertools.chain` generator.
template <typename It> class Chain {
public:
//! Constructor.
//!
//!\param segments Beginnings and lengths of iterator ranges.
Chain(const std::vector<std::pair<It, std::size_t>> &segments);

// Forward declaration.
class iterator;

//! Return an iterator to the beginning of the enumeration.
iterator begin() const;

//! Return an iterator to the end of the enumeration.
iterator end() const;

//! Beginnings and lengths of iterator ranges.
std::vector<std::pair<It, std::size_t>> segments;
};

//! Equality comparison.
template <typename It> bool operator==(const Chain<It> &a, const Chain<It> &b);

//! Inequality comparison.
template <typename It> bool operator!=(const Chain<It> &a, const Chain<It> &b);

//! Iterator over concatenated iterator ranges.
template <typename It> class Chain<It>::iterator {
public:
//! Category of the iterator.
using iterator_category = std::forward_iterator_tag;
//! Type iterated over.
using value_type = typename std::iterator_traits<It>::value_type;
//! Type for distance between iterators.
using difference_type = typename std::iterator_traits<It>::difference_type;
//! Pointer to `value_type`.
using pointer = typename std::iterator_traits<It>::pointer;
//! Type returned by the dereference operator.
using reference = typename std::iterator_traits<It>::reference;

//! Constructor.
//!
//!\param iterable Associated chain.
//!\param q Iterator to current segment.
iterator(
const Chain &iterable,
const typename std::vector<std::pair<It, std::size_t>>::const_iterator q);

//! Equality comparison.
bool operator==(const iterator &other) const;

//! Inequality comparison.
bool operator!=(const iterator &other) const;

//! Preincrement.
iterator &operator++();

//! Postincrement.
iterator operator++(int);

//! Dereference.
reference operator*() const;

private:
//! Associated bit range.
const Chain &iterable;

//! Iterator to current segment.
typename std::vector<std::pair<It, std::size_t>>::const_iterator q;

//! Position in the current segment.
It p;

//! Distance from the beginning of the current segment.
std::size_t i;

//! Length of the current segment.
std::size_t n;

//! Zero `i`; populate `p` and `n` from `q` if not at end.
void conditionally_start_segment();
};

} // namespace mgard

#include "utilities.tpp"
Expand Down
76 changes: 76 additions & 0 deletions include/utilities.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,80 @@ template <typename T>
MemoryBuffer<T>::MemoryBuffer(const std::size_t size)
: MemoryBuffer(new T[size], size) {}

template <typename It>
Chain<It>::Chain(const std::vector<std::pair<It, std::size_t>> &segments)
: segments(segments) {}

template <typename It> bool operator==(const Chain<It> &a, const Chain<It> &b) {
return a.segments == b.segments;
}

template <typename It> bool operator!=(const Chain<It> &a, const Chain<It> &b) {
return !operator==(a, b);
}

template <typename It> typename Chain<It>::iterator Chain<It>::begin() const {
return {*this, segments.begin()};
}

template <typename It> typename Chain<It>::iterator Chain<It>::end() const {
return {*this, segments.end()};
}

template <typename It>
Chain<It>::iterator::iterator(
const Chain<It> &iterable,
const typename std::vector<std::pair<It, std::size_t>>::const_iterator q)
: iterable(iterable), q(q) {
conditionally_start_segment();
}

template <typename It> void Chain<It>::iterator::conditionally_start_segment() {
i = 0;
if (q != iterable.segments.end()) {
const std::pair<It, std::size_t> pair = *q;
p = pair.first;
n = pair.second;
if (not n) {
++q;
conditionally_start_segment();
}
}
}

template <typename It>
bool Chain<It>::iterator::
operator==(const typename Chain<It>::iterator &other) const {
return i == other.i and q == other.q and iterable == other.iterable;
}

template <typename It>
bool Chain<It>::iterator::
operator!=(const typename Chain<It>::iterator &other) const {
return !operator==(other);
}

template <typename It>
typename Chain<It>::iterator &Chain<It>::iterator::operator++() {
++p;
++i;
if (i == n) {
++q;
conditionally_start_segment();
}
return *this;
}

template <typename It>
typename Chain<It>::iterator Chain<It>::iterator::operator++(int) {
const iterator tmp = *this;
operator++();
return tmp;
}

template <typename It>
typename Chain<It>::iterator::reference Chain<It>::iterator::operator*() const {
return *p;
}

} // namespace mgard
55 changes: 55 additions & 0 deletions tests/src/test_utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,58 @@ TEST_CASE("Bits iteration", "[utilities]") {
}
}
}

TEST_CASE("Chain iteration", "[utilities]") {
SECTION("reading") {
const std::size_t N = 5;
std::array<std::vector<unsigned char>, N> in;
in.at(0) = {0};
in.at(1) = {1, 2, 3};
in.at(2) = {};
in.at(3) = {4, 5, 6};
in.at(4) = {7, 8, 9, 10};
using It = std::vector<unsigned char>::const_iterator;
std::vector<std::pair<It, std::size_t>> segments;
for (const std::vector<unsigned char> &in_ : in) {
segments.push_back({in_.begin(), in_.size()});
}
unsigned char expected = 0;
TrialTracker tracker;
for (const unsigned char read : mgard::Chain(segments)) {
tracker += read == expected++;
}
REQUIRE(tracker);
REQUIRE(expected == 11);
}

SECTION("writing") {
const std::size_t N = 4;
std::array<std::vector<unsigned short int>, N> out;
const std::array<std::size_t, N> ns{3, 5, 0, 1};
using It = std::vector<unsigned short int>::iterator;
std::vector<std::pair<It, std::size_t>> segments;
segments.reserve(N);
for (std::size_t i = 0; i < N; ++i) {
std::vector<unsigned short int> &out_ = out.at(i);
const std::size_t n = ns.at(i);
out_.resize(n);
segments.push_back({out_.begin(), n});
}

unsigned short int a = 1;
unsigned short int b = 1;
for (unsigned short int &c : mgard::Chain(segments)) {
c = a;
const unsigned short int tmp = a + b;
a = b;
b = tmp;
}

std::array<std::vector<unsigned short int>, N> expected;
expected.at(0) = {1, 1, 2};
expected.at(1) = {3, 5, 8, 13, 21};
expected.at(2) = {};
expected.at(3) = {34};
REQUIRE(out == expected);
}
}

0 comments on commit 15d5710

Please sign in to comment.