Skip to content

Commit

Permalink
[Core][Channel] Refactor Channel
Browse files Browse the repository at this point in the history
ref #228
  • Loading branch information
ddmbr committed Jan 24, 2017
1 parent 62a6947 commit b6f9d84
Show file tree
Hide file tree
Showing 13 changed files with 357 additions and 307 deletions.
4 changes: 2 additions & 2 deletions core/channel/aggregator_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ AggregatorChannel::~AggregatorChannel() {}

void AggregatorChannel::default_setup(std::function<void()> something = nullptr) {
do_something = something;
setup(Context::get_local_tid(), Context::get_global_tid(), Context::get_worker_info(), Context::get_mailbox());
// setup(Context::get_local_tid(), Context::get_global_tid(), Context::get_worker_info(), Context::get_mailbox());
}

// Mark these member function private to avoid being used by users
Expand All @@ -54,7 +54,7 @@ void AggregatorChannel::send(std::vector<BinStream>& bins) {

bool AggregatorChannel::poll() { return this->mailbox_->poll(this->channel_id_, this->progress_); }

BinStream AggregatorChannel::recv() { return this->mailbox_->recv(this->channel_id_, this->progress_); }
BinStream AggregatorChannel::recv_() { return this->mailbox_->recv(this->channel_id_, this->progress_); }

void AggregatorChannel::prepare() {}

Expand Down
2 changes: 1 addition & 1 deletion core/channel/aggregator_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AggregatorChannel : public ChannelBase {
void default_setup(std::function<void()> something);
void send(std::vector<BinStream>& bins);
bool poll();
BinStream recv();
BinStream recv_();

private:
std::function<void()> do_something = nullptr;
Expand Down
28 changes: 3 additions & 25 deletions core/channel/channel_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,14 @@

namespace husky {

thread_local size_t ChannelBase::counter = 0;
thread_local int ChannelBase::max_channel_id_ = 0;

ChannelBase::ChannelBase() : channel_id_(counter), progress_(0) {
counter += 1;
set_as_sync_channel();
}

void ChannelBase::set_local_id(size_t local_id) { local_id_ = local_id; }

void ChannelBase::set_global_id(size_t global_id) { global_id_ = global_id; }

void ChannelBase::set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); }

void ChannelBase::set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; }

void ChannelBase::set_as_async_channel() { type_ = ChannelType::Async; }

void ChannelBase::set_as_sync_channel() { type_ = ChannelType::Sync; }

void ChannelBase::setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox) {
set_local_id(local_id);
set_global_id(global_id);
set_worker_info(worker_info);
set_mailbox(mailbox);
customized_setup();
ChannelBase::ChannelBase() : channel_id_(max_channel_id_), progress_(0) {
max_channel_id_ += 1;
}

void ChannelBase::inc_progress() {
progress_ += 1;
flushed_.resize(progress_ + 1, true);
}

} // namespace husky
85 changes: 44 additions & 41 deletions core/channel/channel_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,62 +24,68 @@

namespace husky {

using base::BinStream;

class ChannelBase {
public:
enum class ChannelType { Sync, Async };

virtual ~ChannelBase() = default;

/// Getter
inline static size_t get_num_channel() { return counter; }
// Getters of basic information

inline LocalMailbox* get_mailbox() const { return mailbox_; }
inline size_t get_channel_id() const { return channel_id_; }
inline size_t get_global_id() const { return global_id_; }
inline size_t get_local_id() const { return local_id_; }
inline size_t get_progress() const { return progress_; }
inline ChannelType get_channel_type() const { return type_; }

/// Setter
void set_local_id(size_t local_id);
void set_global_id(size_t global_id);
void set_worker_info(const WorkerInfo& worker_info);
void set_mailbox(LocalMailbox* mailbox);
// Setters of basic information

void set_as_async_channel();
void set_as_sync_channel();
void set_local_id(size_t local_id) { local_id_ = local_id; }
void set_global_id(size_t global_id) { global_id_ = global_id; }
void set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); }
void set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; }

void setup(size_t local_id, size_t global_id, const WorkerInfo& worker_info, LocalMailbox* mailbox);
// Top-level APIs

/// customized_setup() is used to do customized setup for subclass
virtual void customized_setup() = 0;
virtual void in() {
this->recv();
this->post_recv();
};

/// prepare() needs to be invoked to do some preparation work (if any), such as clearing buffers,
/// before it can take new incoming communication using the in(BinStream&) method.
/// In list_execute (in core/executor.hpp), prepare() is usually used before any in(BinStream&).
virtual void prepare() {}
virtual void out() {
this->pre_send();
this->send();
this->post_send();
};

/// in(BinStream&) defines what the channel should do when receiving a binstream
virtual void in(BinStream& bin) {}
// Second-level APIs

/// out() defines what the channel should do after a list_execute, normally mailbox->send_complete() will be invoked
virtual void out() {}
virtual void recv() {
// A simple default synchronous implementation
if(mailbox_ == nullptr)
throw base::HuskyException("Local mailbox not set, and thus cannot use the recv() method.");

/// is_flushed() checks whether flush() is invoked
/// If yes, then list_execute will invoke prepare() and later in(BinStream& bin) will be invoked
/// If no, list_execute will just omit this channel
inline bool is_flushed() { return flushed_[progress_]; }
while(mailbox_->poll(channel_id_, progress_)) {
base::BinStream bin_stream = mailbox_->recv(channel_id_, progress_);
if(bin_stream_processor_ != nullptr)
bin_stream_processor_(&bin_stream);
}
};

/// Invoked by prepare_messages or ChannelManager after receiving from mailbox
/// reset the flushed_ so that prepare/prepare_messages won't be invoked next time
inline void reset_flushed() { flushed_[progress_] = false; }
virtual void post_recv() {};
virtual void pre_send() {};
virtual void send() {};
virtual void post_send() {};

void inc_progress();
// Third-level APIs (invoked by its upper level)

virtual void send() {};
void set_bin_stream_processor(std::function<void(base::BinStream*)> bin_stream_processor) {
bin_stream_processor_ = bin_stream_processor;
}

virtual void send_complete() {};
std::function<void(base::BinStream*)> get_bin_stream_processor() {
return bin_stream_processor_;
}

void inc_progress();

protected:
ChannelBase();
Expand All @@ -95,15 +101,12 @@ class ChannelBase {
size_t local_id_;
size_t progress_;

ChannelType type_;

std::vector<bool> flushed_{0};

std::unique_ptr<WorkerInfo> worker_info_;
LocalMailbox* mailbox_ = nullptr;
const HashRing* hash_ring_ = nullptr;

static thread_local size_t counter;
std::function<void(base::BinStream*)> bin_stream_processor_ = nullptr;

static thread_local int max_channel_id_;
};

} // namespace husky
17 changes: 6 additions & 11 deletions core/channel/channel_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,27 @@ class ChannelManager {
std::vector<ChannelBase*> selected_channels;
std::vector<std::pair<int, int>> channel_progress_pairs;
for (auto* channel : channels_) {
// Only consider the channels_ which are flushed and do preparation
if (channel->is_flushed()) {
channel->prepare();
selected_channels.push_back(channel);
channel_progress_pairs.push_back({channel->get_channel_id(), channel->get_progress()});
}
selected_channels.push_back(channel);
channel_progress_pairs.push_back({channel->get_channel_id(), channel->get_progress()});
}

// return if no channel is flushed
if (channel_progress_pairs.empty())
return;

// receive from mailbox_ and distbribute
int idx = -1;
std::pair<int, int> pair;
if (channel_progress_pairs.empty()) {
return;
}

while (mailbox_->poll(channel_progress_pairs, &idx)) {
ASSERT_MSG(idx != -1, "ChannelManager: Mailbox poll error");
auto bin = mailbox_->recv(channel_progress_pairs[idx].first, channel_progress_pairs[idx].second);
selected_channels[idx]->in(bin);
selected_channels[idx]->get_bin_stream_processor()(&bin);
}

// reset the flushed_ buffer
for (auto* ch : selected_channels) {
ch->reset_flushed();
ch->post_recv();
}
}

Expand Down
50 changes: 41 additions & 9 deletions core/channel/channel_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <string>

#include "core/channel/channel_store_base.hpp"
#include "core/sync_shuffle_combiner.hpp"
#include "core/context.hpp"
#include "core/objlist.hpp"

Expand All @@ -32,17 +33,46 @@ class ChannelStore : public ChannelStoreBase {
static PushChannel<MsgT, DstObjT>& create_push_channel(ChannelSource& src_list, ObjList<DstObjT>& dst_list,
const std::string& name = "") {
auto& ch = ChannelStoreBase::create_push_channel<MsgT>(src_list, dst_list, name);
setup(ch);
return ch;
}

// Create PushCombinedChannel
template <typename MsgT, typename CombineT, typename DstObjT>
static PushCombinedChannel<MsgT, DstObjT, CombineT>& create_push_combined_channel(ChannelSource& src_list,
ObjList<DstObjT>& dst_list,
const std::string& name = "") {
auto& ch = ChannelStoreBase::create_push_combined_channel<MsgT, CombineT>(src_list, dst_list, name);
setup(ch);
static auto* create_push_combined_channel(ObjList<DstObjT>* dst_list, const std::string& name = "") {
auto* ch = ChannelStoreBase::create_push_combined_channel<MsgT, CombineT>(*dst_list);
common_setup(ch);
ch->set_obj_list(dst_list);
ch->set_combiner(new SyncShuffleCombiner<MsgT, typename DstObjT::KeyT, CombineT>(Context::get_zmq_context()));
ch->set_bin_stream_processor([=](base::BinStream* bin_stream){
auto* recv_buffer = ch->get_recv_buffer();
auto* recv_flags = ch->get_recv_flags();

while (bin_stream->size() != 0) {
typename DstObjT::KeyT key;
*bin_stream >> key;
MsgT msg;
*bin_stream >> msg;

DstObjT* recver_obj = dst_list->find(key);
int idx;
if (recver_obj == nullptr) {
DstObjT obj(key); // Construct obj using key only
idx = dst_list->add_object(std::move(obj));
} else {
idx = dst_list->index_of(recver_obj);
}
if (idx >= ch->get_recv_buffer()->size()) {
recv_buffer->resize(idx + 1);
recv_flags->resize(idx + 1);
}
if ((*recv_flags)[idx] == true) {
CombineT::combine((*recv_buffer)[idx], msg);
} else {
(*recv_buffer)[idx] = std::move(msg);
(*recv_flags)[idx] = true;
}
}
});
return ch;
}

Expand Down Expand Up @@ -82,9 +112,11 @@ class ChannelStore : public ChannelStoreBase {
return ch;
}

static void setup(ChannelBase& ch) {
ch.setup(Context::get_local_tid(), Context::get_global_tid(), Context::get_worker_info(),
Context::get_mailbox());
static void common_setup(ChannelBase* ch) {
ch->set_mailbox(Context::get_mailbox(Context::get_local_tid()));
ch->set_local_id(Context::get_local_tid());
ch->set_global_id(Context::get_global_tid());
ch->set_worker_info(Context::get_worker_info());
}
};

Expand Down
25 changes: 18 additions & 7 deletions core/channel/channel_store_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "core/channel/migrate_channel.hpp"
#include "core/channel/push_channel.hpp"
#include "core/channel/push_combined_channel.hpp"
#include "core/sync_shuffle_combiner.hpp"

namespace husky {

Expand Down Expand Up @@ -53,23 +54,33 @@ class ChannelStoreBase {

// Create PushCombinedChannel
template <typename MsgT, typename CombineT, typename DstObjT>
static PushCombinedChannel<MsgT, DstObjT, CombineT>& create_push_combined_channel(ChannelSource& src_list,
ObjList<DstObjT>& dst_list,
const std::string& name = "") {
static auto* create_push_combined_channel(ObjList<DstObjT>& dst_list,
const std::string& name = "") {
std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name;
if(channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_combined_channel = new PushCombinedChannel<MsgT, DstObjT, CombineT>(&src_list, &dst_list);
auto* push_combined_channel = new PushCombinedChannel<MsgT, DstObjT, CombineT>();
channel_map.insert({channel_name, push_combined_channel});
return *push_combined_channel;
return push_combined_channel;
}

// Create PushCombinedChannel
template <typename MsgT, typename CombineT, typename DstObjT>
static auto* create_push_combined_channel(const std::string& name = "") {
std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name;
if(channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_combined_channel = new PushCombinedChannel<MsgT, DstObjT, CombineT>();
channel_map.insert({channel_name, push_combined_channel});
return push_combined_channel;
}

template <typename MsgT, typename CombineT, typename DstObjT>
static PushCombinedChannel<MsgT, DstObjT, CombineT>& get_push_combined_channel(const std::string& name = "") {
static auto& get_push_combined_channel(const std::string& name = "") {
if(channel_map.find(name) == channel_map.end())
throw base::HuskyException("ChannelStoreBase::get_channel: Channel name doesn't exist");
auto* channel = channel_map[name];
return *dynamic_cast<PushCombinedChannel<MsgT, DstObjT, CombineT>*>(channel);
return *static_cast<PushCombinedChannel<MsgT, DstObjT, CombineT>*>(channel);
}

// Create MigrateChannel
Expand Down
Loading

0 comments on commit b6f9d84

Please sign in to comment.