From b6f9d84a69f89b69d7c521ebc50f024996c2cc9b Mon Sep 17 00:00:00 2001 From: ddmbr Date: Tue, 24 Jan 2017 15:10:44 +0800 Subject: [PATCH] [Core][Channel] Refactor Channel ref #228 --- core/channel/aggregator_channel.cpp | 4 +- core/channel/aggregator_channel.hpp | 2 +- core/channel/channel_base.cpp | 28 +--- core/channel/channel_base.hpp | 85 ++++++----- core/channel/channel_manager.hpp | 17 +-- core/channel/channel_store.hpp | 50 ++++-- core/channel/channel_store_base.hpp | 25 ++- core/channel/push_combined_channel.hpp | 201 ++++++++----------------- core/executor.hpp | 109 ++++++-------- core/shuffle_combiner_base.hpp | 24 +++ core/sync_shuffle_combiner.hpp | 98 ++++++++++++ examples/pi.cpp | 13 +- lib/aggregator_factory.cpp | 8 +- 13 files changed, 357 insertions(+), 307 deletions(-) create mode 100644 core/shuffle_combiner_base.hpp create mode 100644 core/sync_shuffle_combiner.hpp diff --git a/core/channel/aggregator_channel.cpp b/core/channel/aggregator_channel.cpp index ba49207..c314013 100644 --- a/core/channel/aggregator_channel.cpp +++ b/core/channel/aggregator_channel.cpp @@ -31,7 +31,7 @@ AggregatorChannel::~AggregatorChannel() {} void AggregatorChannel::default_setup(std::function something = nullptr) { do_something = something; - setup(Context::get_local_tid(), Context::get_global_tid(), Context::get_worker_info(), Context::get_mailbox()); + // setup(Context::get_local_tid(), Context::get_global_tid(), Context::get_worker_info(), Context::get_mailbox()); } // Mark these member function private to avoid being used by users @@ -54,7 +54,7 @@ void AggregatorChannel::send(std::vector& bins) { bool AggregatorChannel::poll() { return this->mailbox_->poll(this->channel_id_, this->progress_); } -BinStream AggregatorChannel::recv() { return this->mailbox_->recv(this->channel_id_, this->progress_); } +BinStream AggregatorChannel::recv_() { return this->mailbox_->recv(this->channel_id_, this->progress_); } void AggregatorChannel::prepare() {} diff --git a/core/channel/aggregator_channel.hpp b/core/channel/aggregator_channel.hpp index 8c3c4d9..61e3c76 100644 --- a/core/channel/aggregator_channel.hpp +++ b/core/channel/aggregator_channel.hpp @@ -37,7 +37,7 @@ class AggregatorChannel : public ChannelBase { void default_setup(std::function something); void send(std::vector& bins); bool poll(); - BinStream recv(); + BinStream recv_(); private: std::function do_something = nullptr; diff --git a/core/channel/channel_base.cpp b/core/channel/channel_base.cpp index 474ff1c..033e3fe 100644 --- a/core/channel/channel_base.cpp +++ b/core/channel/channel_base.cpp @@ -20,36 +20,14 @@ namespace husky { -thread_local size_t ChannelBase::counter = 0; +thread_local int ChannelBase::max_channel_id_ = 0; -ChannelBase::ChannelBase() : channel_id_(counter), progress_(0) { - counter += 1; - set_as_sync_channel(); -} - -void ChannelBase::set_local_id(size_t local_id) { local_id_ = local_id; } - -void ChannelBase::set_global_id(size_t global_id) { global_id_ = global_id; } - -void ChannelBase::set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); } - -void ChannelBase::set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; } - -void ChannelBase::set_as_async_channel() { type_ = ChannelType::Async; } - -void ChannelBase::set_as_sync_channel() { type_ = ChannelType::Sync; } - -void ChannelBase::setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox) { - set_local_id(local_id); - set_global_id(global_id); - set_worker_info(worker_info); - set_mailbox(mailbox); - customized_setup(); +ChannelBase::ChannelBase() : channel_id_(max_channel_id_), progress_(0) { + max_channel_id_ += 1; } void ChannelBase::inc_progress() { progress_ += 1; - flushed_.resize(progress_ + 1, true); } } // namespace husky diff --git a/core/channel/channel_base.hpp b/core/channel/channel_base.hpp index 4db4c88..10f50d4 100644 --- a/core/channel/channel_base.hpp +++ b/core/channel/channel_base.hpp @@ -24,62 +24,68 @@ namespace husky { -using base::BinStream; - class ChannelBase { public: - enum class ChannelType { Sync, Async }; - virtual ~ChannelBase() = default; - /// Getter - inline static size_t get_num_channel() { return counter; } + // Getters of basic information + inline LocalMailbox* get_mailbox() const { return mailbox_; } inline size_t get_channel_id() const { return channel_id_; } inline size_t get_global_id() const { return global_id_; } inline size_t get_local_id() const { return local_id_; } inline size_t get_progress() const { return progress_; } - inline ChannelType get_channel_type() const { return type_; } - /// Setter - void set_local_id(size_t local_id); - void set_global_id(size_t global_id); - void set_worker_info(const WorkerInfo& worker_info); - void set_mailbox(LocalMailbox* mailbox); + // Setters of basic information - void set_as_async_channel(); - void set_as_sync_channel(); + void set_local_id(size_t local_id) { local_id_ = local_id; } + void set_global_id(size_t global_id) { global_id_ = global_id; } + void set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); } + void set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; } - void setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox); + // Top-level APIs - /// customized_setup() is used to do customized setup for subclass - virtual void customized_setup() = 0; + virtual void in() { + this->recv(); + this->post_recv(); + }; - /// prepare() needs to be invoked to do some preparation work (if any), such as clearing buffers, - /// before it can take new incoming communication using the in(BinStream&) method. - /// In list_execute (in core/executor.hpp), prepare() is usually used before any in(BinStream&). - virtual void prepare() {} + virtual void out() { + this->pre_send(); + this->send(); + this->post_send(); + }; - /// in(BinStream&) defines what the channel should do when receiving a binstream - virtual void in(BinStream& bin) {} + // Second-level APIs - /// out() defines what the channel should do after a list_execute, normally mailbox->send_complete() will be invoked - virtual void out() {} + virtual void recv() { + // A simple default synchronous implementation + if(mailbox_ == nullptr) + throw base::HuskyException("Local mailbox not set, and thus cannot use the recv() method."); - /// is_flushed() checks whether flush() is invoked - /// If yes, then list_execute will invoke prepare() and later in(BinStream& bin) will be invoked - /// If no, list_execute will just omit this channel - inline bool is_flushed() { return flushed_[progress_]; } + while(mailbox_->poll(channel_id_, progress_)) { + base::BinStream bin_stream = mailbox_->recv(channel_id_, progress_); + if(bin_stream_processor_ != nullptr) + bin_stream_processor_(&bin_stream); + } + }; - /// Invoked by prepare_messages or ChannelManager after receiving from mailbox - /// reset the flushed_ so that prepare/prepare_messages won't be invoked next time - inline void reset_flushed() { flushed_[progress_] = false; } + virtual void post_recv() {}; + virtual void pre_send() {}; + virtual void send() {}; + virtual void post_send() {}; - void inc_progress(); + // Third-level APIs (invoked by its upper level) - virtual void send() {}; + void set_bin_stream_processor(std::function bin_stream_processor) { + bin_stream_processor_ = bin_stream_processor; + } - virtual void send_complete() {}; + std::function get_bin_stream_processor() { + return bin_stream_processor_; + } + + void inc_progress(); protected: ChannelBase(); @@ -95,15 +101,12 @@ class ChannelBase { size_t local_id_; size_t progress_; - ChannelType type_; - - std::vector flushed_{0}; - std::unique_ptr worker_info_; LocalMailbox* mailbox_ = nullptr; - const HashRing* hash_ring_ = nullptr; - static thread_local size_t counter; + std::function bin_stream_processor_ = nullptr; + + static thread_local int max_channel_id_; }; } // namespace husky diff --git a/core/channel/channel_manager.hpp b/core/channel/channel_manager.hpp index b235410..9fe19da 100644 --- a/core/channel/channel_manager.hpp +++ b/core/channel/channel_manager.hpp @@ -40,13 +40,10 @@ class ChannelManager { std::vector selected_channels; std::vector> channel_progress_pairs; for (auto* channel : channels_) { - // Only consider the channels_ which are flushed and do preparation - if (channel->is_flushed()) { - channel->prepare(); - selected_channels.push_back(channel); - channel_progress_pairs.push_back({channel->get_channel_id(), channel->get_progress()}); - } + selected_channels.push_back(channel); + channel_progress_pairs.push_back({channel->get_channel_id(), channel->get_progress()}); } + // return if no channel is flushed if (channel_progress_pairs.empty()) return; @@ -54,18 +51,16 @@ class ChannelManager { // receive from mailbox_ and distbribute int idx = -1; std::pair pair; - if (channel_progress_pairs.empty()) { - return; - } + while (mailbox_->poll(channel_progress_pairs, &idx)) { ASSERT_MSG(idx != -1, "ChannelManager: Mailbox poll error"); auto bin = mailbox_->recv(channel_progress_pairs[idx].first, channel_progress_pairs[idx].second); - selected_channels[idx]->in(bin); + selected_channels[idx]->get_bin_stream_processor()(&bin); } // reset the flushed_ buffer for (auto* ch : selected_channels) { - ch->reset_flushed(); + ch->post_recv(); } } diff --git a/core/channel/channel_store.hpp b/core/channel/channel_store.hpp index 8fd2db4..ac46a98 100644 --- a/core/channel/channel_store.hpp +++ b/core/channel/channel_store.hpp @@ -17,6 +17,7 @@ #include #include "core/channel/channel_store_base.hpp" +#include "core/sync_shuffle_combiner.hpp" #include "core/context.hpp" #include "core/objlist.hpp" @@ -32,17 +33,46 @@ class ChannelStore : public ChannelStoreBase { static PushChannel& create_push_channel(ChannelSource& src_list, ObjList& dst_list, const std::string& name = "") { auto& ch = ChannelStoreBase::create_push_channel(src_list, dst_list, name); - setup(ch); return ch; } // Create PushCombinedChannel template - static PushCombinedChannel& create_push_combined_channel(ChannelSource& src_list, - ObjList& dst_list, - const std::string& name = "") { - auto& ch = ChannelStoreBase::create_push_combined_channel(src_list, dst_list, name); - setup(ch); + static auto* create_push_combined_channel(ObjList* dst_list, const std::string& name = "") { + auto* ch = ChannelStoreBase::create_push_combined_channel(*dst_list); + common_setup(ch); + ch->set_obj_list(dst_list); + ch->set_combiner(new SyncShuffleCombiner(Context::get_zmq_context())); + ch->set_bin_stream_processor([=](base::BinStream* bin_stream){ + auto* recv_buffer = ch->get_recv_buffer(); + auto* recv_flags = ch->get_recv_flags(); + + while (bin_stream->size() != 0) { + typename DstObjT::KeyT key; + *bin_stream >> key; + MsgT msg; + *bin_stream >> msg; + + DstObjT* recver_obj = dst_list->find(key); + int idx; + if (recver_obj == nullptr) { + DstObjT obj(key); // Construct obj using key only + idx = dst_list->add_object(std::move(obj)); + } else { + idx = dst_list->index_of(recver_obj); + } + if (idx >= ch->get_recv_buffer()->size()) { + recv_buffer->resize(idx + 1); + recv_flags->resize(idx + 1); + } + if ((*recv_flags)[idx] == true) { + CombineT::combine((*recv_buffer)[idx], msg); + } else { + (*recv_buffer)[idx] = std::move(msg); + (*recv_flags)[idx] = true; + } + } + }); return ch; } @@ -82,9 +112,11 @@ class ChannelStore : public ChannelStoreBase { return ch; } - static void setup(ChannelBase& ch) { - ch.setup(Context::get_local_tid(), Context::get_global_tid(), Context::get_worker_info(), - Context::get_mailbox()); + static void common_setup(ChannelBase* ch) { + ch->set_mailbox(Context::get_mailbox(Context::get_local_tid())); + ch->set_local_id(Context::get_local_tid()); + ch->set_global_id(Context::get_global_tid()); + ch->set_worker_info(Context::get_worker_info()); } }; diff --git a/core/channel/channel_store_base.hpp b/core/channel/channel_store_base.hpp index 6b79cc1..ae81af3 100644 --- a/core/channel/channel_store_base.hpp +++ b/core/channel/channel_store_base.hpp @@ -24,6 +24,7 @@ #include "core/channel/migrate_channel.hpp" #include "core/channel/push_channel.hpp" #include "core/channel/push_combined_channel.hpp" +#include "core/sync_shuffle_combiner.hpp" namespace husky { @@ -53,23 +54,33 @@ class ChannelStoreBase { // Create PushCombinedChannel template - static PushCombinedChannel& create_push_combined_channel(ChannelSource& src_list, - ObjList& dst_list, - const std::string& name = "") { + static auto* create_push_combined_channel(ObjList& dst_list, + const std::string& name = "") { std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name; if(channel_map.find(name) != channel_map.end()) throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists"); - auto* push_combined_channel = new PushCombinedChannel(&src_list, &dst_list); + auto* push_combined_channel = new PushCombinedChannel(); channel_map.insert({channel_name, push_combined_channel}); - return *push_combined_channel; + return push_combined_channel; + } + + // Create PushCombinedChannel + template + static auto* create_push_combined_channel(const std::string& name = "") { + std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name; + if(channel_map.find(name) != channel_map.end()) + throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists"); + auto* push_combined_channel = new PushCombinedChannel(); + channel_map.insert({channel_name, push_combined_channel}); + return push_combined_channel; } template - static PushCombinedChannel& get_push_combined_channel(const std::string& name = "") { + static auto& get_push_combined_channel(const std::string& name = "") { if(channel_map.find(name) == channel_map.end()) throw base::HuskyException("ChannelStoreBase::get_channel: Channel name doesn't exist"); auto* channel = channel_map[name]; - return *dynamic_cast*>(channel); + return *static_cast*>(channel); } // Create MigrateChannel diff --git a/core/channel/push_combined_channel.hpp b/core/channel/push_combined_channel.hpp index f586a06..96c55dc 100644 --- a/core/channel/push_combined_channel.hpp +++ b/core/channel/push_combined_channel.hpp @@ -27,27 +27,18 @@ #include "core/hash_ring.hpp" #include "core/mailbox.hpp" #include "core/objlist.hpp" -#include "core/shuffle_combiner_store.hpp" +#include "core/shuffle_combiner_base.hpp" #include "core/worker_info.hpp" +#include "core/zmq_helpers.hpp" namespace husky { -using base::BinStream; - template -class PushCombinedChannel : public Source2ObjListChannel { +class PushCombinedChannel : public ChannelBase { public: - PushCombinedChannel(ChannelSource* src, ObjList* dst) : Source2ObjListChannel(src, dst) { - this->src_ptr_->register_outchannel(this->channel_id_, this); - this->dst_ptr_->register_inchannel(this->channel_id_, this); - } - - ~PushCombinedChannel() override { - ShuffleCombinerStore::remove_shuffle_combiner(this->channel_id_); + PushCombinedChannel() = default; - this->src_ptr_->deregister_outchannel(this->channel_id_); - this->dst_ptr_->deregister_inchannel(this->channel_id_); - } + // Are the following necessary? PushCombinedChannel(const PushCombinedChannel&) = delete; PushCombinedChannel& operator=(const PushCombinedChannel&) = delete; @@ -55,52 +46,14 @@ class PushCombinedChannel : public Source2ObjListChannel { PushCombinedChannel(PushCombinedChannel&&) = default; PushCombinedChannel& operator=(PushCombinedChannel&&) = default; - void customized_setup() override { - // Initialize send_buffer_ - // use get_largest_tid() instead of get_num_workers() - // sine we may only use a subset of worker - send_buffer_.resize(this->worker_info_->get_largest_tid() + 1); - // Create shuffle_combiner_ - // TODO(yuzhen): Only support sortcombine, hashcombine can be added using enableif - shuffle_combiner_ = ShuffleCombinerStore::create_shuffle_combiner( - this->channel_id_, this->local_id_, this->worker_info_->get_num_local_workers(), - this->worker_info_->get_largest_tid() + 1); - } + // The following are virtual methods - void push(const MsgT& msg, const typename DstObjT::KeyT& key) { - // shuffle_combiner_.init(); // Already move init() to create_shuffle_combiner_() - int dst_worker_id = this->worker_info_->get_hash_ring().hash_lookup(key); - auto& buffer = (*shuffle_combiner_)[this->local_id_].storage(dst_worker_id); - back_combine(buffer, key, msg); - } - - const MsgT& get(const DstObjT& obj) { - auto idx = this->dst_ptr_->index_of(&obj); - if (idx >= recv_buffer_.size()) { // resize recv_buffer_ and recv_flag_ if it is not large enough - recv_buffer_.resize(this->dst_ptr_->get_size()); - recv_flag_.resize(this->dst_ptr_->get_size()); - } - if (recv_flag_[idx] == false) { - recv_buffer_[idx] = MsgT(); - } - return recv_buffer_[idx]; + void pre_send() override { + // shuffle and combine + this->shuffle_combiner_impl_->shuffle(); + this->shuffle_combiner_impl_->combine(&send_buffer_); } - bool has_msgs(const DstObjT& obj) { - auto idx = this->dst_ptr_->index_of(&obj); - if (idx >= recv_buffer_.size()) { // resize recv_buffer_ and recv_flag_ if it is not large enough - recv_buffer_.resize(this->dst_ptr_->get_size()); - recv_flag_.resize(this->dst_ptr_->get_size()); - } - return recv_flag_[idx]; - } - - void prepare() override { clear_recv_buffer_(); } - - void in(BinStream& bin) override { process_bin(bin); } - - void out() override { flush(); } - void send() { int start = this->global_id_; for (int i = 0; i < send_buffer_.size(); ++i) { @@ -112,107 +65,79 @@ class PushCombinedChannel : public Source2ObjListChannel { } } - void send_complete() { + void post_send() { this->inc_progress(); this->mailbox_->send_complete(this->channel_id_, this->progress_, this->worker_info_->get_local_tids(), this->worker_info_->get_pids()); + clear_recv_buffer_(); } - /// This method is only useful without list_execute - void flush() { - shuffle_combine(); - send(); - send_complete(); + // The following are specific to this channel type + + void set_combiner(ShuffleCombinerBase* combiner_base) { + shuffle_combiner_impl_.reset(combiner_base); + shuffle_combiner_impl_->set_local_id(local_id_); + shuffle_combiner_impl_->set_channel_id(channel_id_); + shuffle_combiner_impl_->set_worker_info(*(worker_info_.get())); // FIXME(fan) unsafe + + // FIXME(fan) This should not be here + if(send_buffer_.size() != worker_info_->get_largest_tid()+1) + send_buffer_.resize(worker_info_->get_largest_tid()+1); } - /// This method is only useful without list_execute - void prepare_messages() { - if (!this->is_flushed()) - return; - clear_recv_buffer_(); - while (this->mailbox_->poll(this->channel_id_, this->progress_)) { - auto bin_push = this->mailbox_->recv(this->channel_id_, this->progress_); - process_bin(bin_push); + inline void push(const MsgT& msg, const typename DstObjT::KeyT& key) { + int dst_worker_id = this->worker_info_->get_hash_ring().hash_lookup(key); + this->shuffle_combiner_impl_->push(msg, key, dst_worker_id); + } + + inline const MsgT& get(const DstObjT& obj) { + if(this->obj_list_ptr_ == nullptr) { + throw base::HuskyException("Object list not set and thus cannot get message by \ + providing an object. Please use `set_obj_list` first."); } - this->reset_flushed(); + auto idx = obj - &this->obj_list_ptr_->get_data()[0]; // FIXME(fan): unsafe + if(not has_msgs(idx)) return MsgT(); + return get(idx); } - ShuffleCombiner>& get_shuffle_combiner(int tid) { - return (*shuffle_combiner_)[tid]; - } + inline const MsgT& get(int idx) { + return recv_buffer_[idx]; + } - std::vector& get_send_buffer() { - return send_buffer_; + inline bool has_msgs(const DstObjT& obj) { + if(this->obj_start_ptr_ == nullptr) { + throw base::HuskyException("Object list not set and thus cannot get message by \ + providing an object. Please use `set_obj_list` first."); + } + auto idx = obj - &this->obj_list_ptr_->get_data()[0]; // FIXME(fan): unsafe + return has_msgs(idx); } - protected: - void clear_recv_buffer_() { std::fill(recv_flag_.begin(), recv_flag_.end(), false); } + inline bool has_msgs(int idx) { + if (idx >= recv_buffer_.size()) return false; + return recv_flag_[idx]; + } - void process_bin(BinStream& bin_push) { - while (bin_push.size() != 0) { - typename DstObjT::KeyT key; - bin_push >> key; - MsgT msg; - bin_push >> msg; - - DstObjT* recver_obj = this->dst_ptr_->find(key); - size_t idx; - if (recver_obj == nullptr) { - DstObjT obj(key); // Construct obj using key only - idx = this->dst_ptr_->add_object(std::move(obj)); - } else { - idx = this->dst_ptr_->index_of(recver_obj); - } - if (idx >= recv_buffer_.size()) { - recv_buffer_.resize(idx + 1); - recv_flag_.resize(idx + 1); - } - if (recv_flag_[idx] == true) { - CombineT::combine(recv_buffer_[idx], msg); - } else { - recv_buffer_[idx] = std::move(msg); - recv_flag_[idx] = true; - } - } + void set_obj_list(ObjList* obj_list_ptr) { + obj_list_ptr_ = obj_list_ptr; } - void shuffle_combine() { - // step 1: shuffle combine - auto& self_shuffle_combiner = (*shuffle_combiner_)[this->local_id_]; - self_shuffle_combiner.send_shuffler_buffer(); - for (int iter = 0; iter < this->worker_info_->get_num_local_workers() - 1; iter++) { - int next_worker = self_shuffle_combiner.access_next(); - auto& peer_shuffle_combiner = (*shuffle_combiner_)[next_worker]; - for (int i = this->local_id_; i < this->worker_info_->get_largest_tid() + 1; - i += this->worker_info_->get_num_local_workers()) { - // combining the i-th buffer - auto& self_buffer = self_shuffle_combiner.storage(i); - auto& peer_buffer = peer_shuffle_combiner.storage(i); - self_buffer.insert(self_buffer.end(), peer_buffer.begin(), peer_buffer.end()); - peer_buffer.clear(); - } - } - for (int i = this->local_id_; i < this->worker_info_->get_largest_tid() + 1; - i += this->worker_info_->get_num_local_workers()) { - auto& self_buffer = self_shuffle_combiner.storage(i); - combine_single(self_buffer); - } - // step 2: serialize combine buffer - for (int i = this->local_id_; i < this->worker_info_->get_largest_tid() + 1; - i += this->worker_info_->get_num_local_workers()) { - auto& combine_buffer = self_shuffle_combiner.storage(i); - for (int k = 0; k < combine_buffer.size(); k++) { - send_buffer_[i] << combine_buffer[k].first; - send_buffer_[i] << combine_buffer[k].second; - } - combine_buffer.clear(); - } + std::vector* get_recv_buffer() { + return &recv_buffer_; } - std::vector>>* shuffle_combiner_; - std::vector send_buffer_; + std::vector* get_recv_flags() { + return &recv_flag_; + } + + protected: + void clear_recv_buffer_() { std::fill(recv_flag_.begin(), recv_flag_.end(), false); } + + std::vector send_buffer_; std::vector recv_buffer_; std::vector recv_flag_; + ObjList* obj_list_ptr_; + std::unique_ptr> shuffle_combiner_impl_; }; } // namespace husky diff --git a/core/executor.hpp b/core/executor.hpp index ea73bf5..53ff89d 100644 --- a/core/executor.hpp +++ b/core/executor.hpp @@ -105,69 +105,52 @@ void globalize(ObjList& obj_list) { /// Channel must be AsyncPushChannel or AsyncMigrateChannel /// Only one channel is allowed so far /// TODO(Wei): Multiple channels should be allowed to bind to one object list -template -void list_execute_async(ObjList& obj_list, ExecT execute, int async_time, double timeout = 0.0) { - std::vector channels = obj_list.get_inchannels(); - if (channels.size() != 1) - throw base::HuskyException("list_execute_async currently only supports exactly one channel."); - ChannelBase* channel = channels[0]; - if (channel->get_channel_type() != ChannelBase::ChannelType::Async) - throw base::HuskyException("list_execute_async currently only supports one asynchronous channel."); - - auto start = std::chrono::steady_clock::now(); - auto duration = std::chrono::seconds(async_time); - auto* mailbox = channel->get_mailbox(); - while (std::chrono::steady_clock::now() - start < duration) { - // 1. receive messages if any - channel->prepare(); - if (timeout == 0.0) { - while (mailbox->poll_non_block(channel->get_channel_id(), channel->get_progress())) { - auto bin = mailbox->recv(channel->get_channel_id(), channel->get_progress()); - channel->in(bin); - } - } else { - while (mailbox->poll_with_timeout(channel->get_channel_id(), channel->get_progress(), timeout)) { - auto bin = mailbox->recv(channel->get_channel_id(), channel->get_progress()); - channel->in(bin); - } - } - - // 2. iterate over the list - for (size_t i = 0; i < obj_list.get_vector_size(); ++i) { - if (obj_list.get_del(i)) - continue; - execute(obj_list.get(i)); - } - - // 3. flush - channel->out(); - } - mailbox->send_complete(channel->get_channel_id(), channel->get_progress(), - Context::get_worker_info().get_local_tids(), Context::get_worker_info().get_pids()); - channel->prepare(); - while (mailbox->poll(channel->get_channel_id(), channel->get_progress())) { - auto bin = mailbox->recv(channel->get_channel_id(), channel->get_progress()); - channel->in(bin); - } - channel->inc_progress(); -} - -template -void list_execute(ObjList& obj_list, ExecT execute) { - // TODO(all): the order of invoking prefuncs may matter. - // e.g. MigrateChannel should be invoked before PushChannel - ChannelManager in_manager(obj_list.get_inchannels()); - in_manager.poll_and_distribute(); - - for (size_t i = 0; i < obj_list.get_vector_size(); ++i) { - if (obj_list.get_del(i)) - continue; - execute(obj_list.get(i)); - } - - ChannelManager out_manager(obj_list.get_outchannels()); - out_manager.flush(); -} +// template +// void list_execute_async(ObjList& obj_list, ExecT execute, int async_time, double timeout = 0.0) { +// std::vector channels = obj_list.get_inchannels(); +// if (channels.size() != 1) +// throw base::HuskyException("list_execute_async currently only supports exactly one channel."); +// ChannelBase* channel = channels[0]; +// if (channel->get_channel_type() != ChannelBase::ChannelType::Async) +// throw base::HuskyException("list_execute_async currently only supports one asynchronous channel."); +// +// auto start = std::chrono::steady_clock::now(); +// auto duration = std::chrono::seconds(async_time); +// auto* mailbox = channel->get_mailbox(); +// while (std::chrono::steady_clock::now() - start < duration) { +// // 1. receive messages if any +// channel->prepare(); +// if (timeout == 0.0) { +// while (mailbox->poll_non_block(channel->get_channel_id(), channel->get_progress())) { +// auto bin = mailbox->recv(channel->get_channel_id(), channel->get_progress()); +// channel->in(bin); +// } +// } else { +// while (mailbox->poll_with_timeout(channel->get_channel_id(), channel->get_progress(), timeout)) { +// auto bin = mailbox->recv(channel->get_channel_id(), channel->get_progress()); +// channel->in(bin); +// } +// } +// +// // 2. iterate over the list +// for (size_t i = 0; i < obj_list.get_vector_size(); ++i) { +// if (obj_list.get_del(i)) +// continue; +// execute(obj_list.get(i)); +// } +// +// // 3. flush +// channel->out(); +// } +// mailbox->send_complete(channel->get_channel_id(), channel->get_progress(), +// Context::get_worker_info().get_local_tids(), Context::get_worker_info().get_pids()); +// channel->prepare(); +// while (mailbox->poll(channel->get_channel_id(), channel->get_progress())) { +// auto bin = mailbox->recv(channel->get_channel_id(), channel->get_progress()); +// channel->in(bin); +// } +// channel->inc_progress(); +// } template void list_execute(ObjList& obj_list, const std::vector& in_channel, diff --git a/core/shuffle_combiner_base.hpp b/core/shuffle_combiner_base.hpp new file mode 100644 index 0000000..d4baa9e --- /dev/null +++ b/core/shuffle_combiner_base.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "base/serialization.hpp" +#include "core/worker_info.hpp" + +namespace husky { + +template +class ShuffleCombinerBase { + public: + virtual void set_local_id(int local_id) { local_id_ = local_id; } + virtual void set_channel_id(int channel_id) { channel_id_ = channel_id; } + virtual void set_worker_info(const WorkerInfo& worker_info) { worker_info_ = worker_info; } + + virtual void push(const MsgT& msg, const KeyT& key, int dst_worker_id) = 0; + virtual void shuffle() = 0; + virtual void combine(std::vector* send_buffers) = 0; + protected: + int local_id_; + int channel_id_; + WorkerInfo worker_info_; +}; + +} // namespace husky diff --git a/core/sync_shuffle_combiner.hpp b/core/sync_shuffle_combiner.hpp new file mode 100644 index 0000000..466a87a --- /dev/null +++ b/core/sync_shuffle_combiner.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include "zmq.hpp" + +#include "core/shuffle_combiner_base.hpp" +#include "core/worker_info.hpp" + +namespace husky { + +template +class SyncShuffleCombinerSharedStore { + public: + typedef std::vector>> BufferT; + + static SyncShuffleCombinerSharedStore* get_instance() { + static SyncShuffleCombinerSharedStore shared_store; + return &shared_store; + } + + BufferT* get(int tid) { + std::lock_guard lock(mutex_); + return &shared_buffers[tid]; + } + protected: + SyncShuffleCombinerSharedStore() = default; + std::mutex mutex_; + std::unordered_map shared_buffers; +}; + +template +class SyncShuffleCombiner : public ShuffleCombinerBase { + public: + typedef std::vector>> BufferT; + + SyncShuffleCombiner(zmq::context_t* zmq_context) : zmq_context_(zmq_context) {} + + void set_worker_info(const WorkerInfo& worker_info) override { + worker_info_ = worker_info; + send_buffers_.resize(worker_info_.get_num_workers()); + } + + void push(const MsgT& msg, const KeyT& key, int dst_worker_id) override { + send_buffers_[dst_worker_id].push_back({key, msg}); + } + + void shuffle() override { + // shuffle the buffers + // 1. declare that my buffer is ready + if(not is_shuffle_itc_ready_) + init_shuffle_itc(); + SyncShuffleCombinerSharedStore::get_instance()->get(local_id_)->clear(); + SyncShuffleCombinerSharedStore::get_instance()->get(local_id_)->swap(send_buffers_); + send_buffers_.resize(worker_info_.get_num_workers()); + zmq_send_int32(pub_sock_.get(), local_id_); + + // 2. stream in others' buffers + for(int i=0; i::get_instance()->get(tid); + for(int j=0; j* bin_stream_buffers) override { + for(int i : worker_info_.get_global_tids()) { + if(send_buffers_[i].size() != 0) { + combine_single(send_buffers_[i]); + for(auto& pair : send_buffers_[i]) + (*bin_stream_buffers)[i] << pair; + } + } + } + + protected: + void init_shuffle_itc() { + pub_sock_.reset(new zmq::socket_t(*zmq_context_, ZMQ_PUB)); + sub_sock_.reset(new zmq::socket_t(*zmq_context_, ZMQ_SUB)); + pub_sock_->bind("inproc://sync-shuffle-combine-"+std::to_string(channel_id_)+"-"+std::to_string(local_id_)); + for(int tid : worker_info_.get_local_tids()) { + sub_sock_->connect("inproc://sync-shuffle-combine-"+std::to_string(channel_id_)+"-"+std::to_string(tid)); + sub_sock_->setsockopt(ZMQ_SUBSCRIBE, "", 0); + } + } + + bool is_shuffle_itc_ready_ = false; + zmq::context_t* zmq_context_ = nullptr; + std::unique_ptr pub_sock_ = nullptr; + std::unique_ptr sub_sock_ = nullptr; + WorkerInfo worker_info_; + int local_id_; + int channel_id_; + BufferT send_buffers_; +}; + +} // namespace husky diff --git a/examples/pi.cpp b/examples/pi.cpp index 26f441d..f4f3e73 100644 --- a/examples/pi.cpp +++ b/examples/pi.cpp @@ -45,14 +45,15 @@ void pi() { // Aggregate statistics to object 0 auto& pi_list = husky::ObjListStore::create_objlist(); - auto& ch = husky::ChannelStore::create_push_combined_channel>(pi_list, pi_list); - ch.push(cnt, 0); - ch.flush(); - list_execute(pi_list, [&](PIObject& obj) { - int sum = ch.get(obj); + auto* ch = husky::ChannelStore::create_push_combined_channel>(&pi_list); + ch->push(cnt, 0); + ch->out(); + ch->in(); + if(ch->has_msgs(0)) { + int sum = ch->get(0); int total_pts = num_pts_per_thread * husky::Context::get_num_workers(); husky::LOG_I << (4.0 * sum / total_pts); - }); + } } int main(int argc, char** argv) { diff --git a/lib/aggregator_factory.cpp b/lib/aggregator_factory.cpp index fdd5cbf..100a4c3 100644 --- a/lib/aggregator_factory.cpp +++ b/lib/aggregator_factory.cpp @@ -112,10 +112,10 @@ void AggregatorFactory::send(AggregatorChannel& channel, std::vector& void AggregatorFactory::on_recv(AggregatorChannel& channel, const std::function& recv) { while (channel.poll()) { - BinStream bin = channel.recv(); - if (bin.size() != 0) { - recv(bin); - } + // BinStream bin = channel.recv(); + // if (bin.size() != 0) { + // recv(bin); + // } } }