diff --git a/core/channel/async_migrate_channel.hpp b/core/channel/async_migrate_channel.hpp index 377b8f8..7aedece 100644 --- a/core/channel/async_migrate_channel.hpp +++ b/core/channel/async_migrate_channel.hpp @@ -44,7 +44,7 @@ class AsyncMigrateChannel : public MigrateChannel { void send() override { // No increment progress id here int start = std::rand(); - auto shard_info_iter = ShardInfoIter(*this->destination_); + auto shard_info_iter = ShardInfoIter(*this->destination_, start); for (int i = 0; i < this->migrate_buffer_.size(); ++i) { int dst = (start + i) % this->migrate_buffer_.size(); auto pid_and_sid = shard_info_iter.next(); diff --git a/core/channel/broadcast_channel.hpp b/core/channel/broadcast_channel.hpp index 879a38f..ea32eb5 100644 --- a/core/channel/broadcast_channel.hpp +++ b/core/channel/broadcast_channel.hpp @@ -97,7 +97,7 @@ class BroadcastChannel : public ChannelBase, public Shard { void send() override { this->inc_progress(); int start = std::rand(); - auto shard_info_iter = ShardInfoIter(*this->destination_); + auto shard_info_iter = ShardInfoIter(*this->destination_, start); for (int i = 0; i < broadcast_buffer_.size(); ++i) { int dst = (start + i) % broadcast_buffer_.size(); auto pid_and_sid = shard_info_iter.next(); diff --git a/core/channel/migrate_channel.hpp b/core/channel/migrate_channel.hpp index f3c6d46..3df4505 100644 --- a/core/channel/migrate_channel.hpp +++ b/core/channel/migrate_channel.hpp @@ -53,7 +53,7 @@ class MigrateChannel : public ChannelBase { void send() override { this->inc_progress(); int start = std::rand(); - auto shard_info_iter = ShardInfoIter(*this->source_obj_list_); + auto shard_info_iter = ShardInfoIter(*this->source_obj_list_, start); for (int i = 0; i < migrate_buffer_.size(); ++i) { int dst = (start + i) % migrate_buffer_.size(); auto pid_and_sid = shard_info_iter.next(); diff --git a/core/channel/push_channel.hpp b/core/channel/push_channel.hpp index f7c776e..f5f64bb 100644 --- a/core/channel/push_channel.hpp +++ b/core/channel/push_channel.hpp @@ -34,7 +34,7 @@ class PushChannel : public ChannelBase { void send() override { int start = std::rand(); - auto shard_info_iter = ShardInfoIter(*this->destination_); + auto shard_info_iter = ShardInfoIter(*this->destination_, start); for (int i = 0; i < send_buffer_.size(); ++i) { int dst = (start + i) % send_buffer_.size(); auto pid_and_sid = shard_info_iter.next(); diff --git a/core/channel/push_combined_channel.hpp b/core/channel/push_combined_channel.hpp index a56fe5b..6a88877 100644 --- a/core/channel/push_combined_channel.hpp +++ b/core/channel/push_combined_channel.hpp @@ -48,7 +48,7 @@ class PushCombinedChannel : public ChannelBase { void send() override { int start = std::rand(); - auto shard_info_iter = ShardInfoIter(*this->destination_); + auto shard_info_iter = ShardInfoIter(*this->destination_, start); for (int i = 0; i < bin_stream_buffer_.size(); ++i) { int dst = (start + i) % bin_stream_buffer_.size(); auto pid_and_sid = shard_info_iter.next(); diff --git a/core/shard.cpp b/core/shard.cpp index 6209900..04f7c21 100644 --- a/core/shard.cpp +++ b/core/shard.cpp @@ -31,11 +31,15 @@ void Shard::init(const WorkerInfo& worker_info) { // this is important for back compatibility if (shard_info_.empty()) { // the default configuration of this ShardInfo comes from worker info. + num_shards_ = 0; for (int id : worker_info.get_pids()) { shard_info_.push_back({id, worker_info.get_num_local_workers(id)}); + if (id == self_pid_) { + global_shard_id_ = num_shards_ + local_shard_id_; + } + num_shards_ += worker_info.get_num_local_workers(id); } num_local_shards_ = worker_info.get_num_local_workers(); - num_shards_ = worker_info.get_num_workers(); hash_ring_ = worker_info.get_hash_ring(); } else { num_shards_ = 0; @@ -87,9 +91,19 @@ std::vector Shard::get_pids() { return pids; } -ShardInfoIter::ShardInfoIter(Shard& shard) +ShardInfoIter::ShardInfoIter(Shard& shard, int offset) : shard_(shard), shard_info_iter_(shard.get_shard_info().begin()) { + offset = offset % shard_.get_num_shards(); + for (int i = 0; i < offset;) { + if (shard_info_iter_->second <= offset - i) { + i += shard_info_iter_->second; + ++shard_info_iter_; + } else { + cur_local_shard_id_ += offset - i; + i = offset; + } + } } std::pair ShardInfoIter::next() { @@ -97,6 +111,9 @@ std::pair ShardInfoIter::next() { cur_local_shard_id_ = -1; // assert(shard_info_iter_ == shard_.get_shard_info().end()); ++shard_info_iter_; + if (shard_info_iter_ == shard_.get_shard_info().end()) { + shard_info_iter_ = shard_.get_shard_info().begin(); + } } return {shard_info_iter_->first, cur_local_shard_id_}; } diff --git a/core/shard.hpp b/core/shard.hpp index a269e90..f6aada2 100644 --- a/core/shard.hpp +++ b/core/shard.hpp @@ -81,7 +81,7 @@ class Shard { class ShardInfoIter { public: - ShardInfoIter(Shard& shard); + ShardInfoIter(Shard& shard, int offset = 0); inline int size() { return shard_.get_num_shards(); }