Skip to content

Commit

Permalink
[WIP][Core][Channel] Refactor Channel (#2)
Browse files Browse the repository at this point in the history
push_channel and some relevant things.
  • Loading branch information
MiaoLoud authored and ddmbr committed Jan 26, 2017
1 parent 50099e0 commit 0b24cf2
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 193 deletions.
26 changes: 12 additions & 14 deletions core/channel/channel_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,50 +40,48 @@ class ChannelBase {

void set_local_id(size_t local_id) { local_id_ = local_id; }
void set_global_id(size_t global_id) { global_id_ = global_id; }
void set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); }
virtual void set_worker_info(const WorkerInfo& worker_info) { worker_info_.reset(new WorkerInfo(worker_info)); }
void set_mailbox(LocalMailbox* mailbox) { mailbox_ = mailbox; }

// Top-level APIs

virtual void in() {
this->recv();
this->post_recv();
};
}

virtual void out() {
this->pre_send();
this->send();
this->post_send();
};
}

// Second-level APIs

virtual void recv() {
// A simple default synchronous implementation
if(mailbox_ == nullptr)
if (mailbox_ == nullptr)
throw base::HuskyException("Local mailbox not set, and thus cannot use the recv() method.");

while(mailbox_->poll(channel_id_, progress_)) {
while (mailbox_->poll(channel_id_, progress_)) {
base::BinStream bin_stream = mailbox_->recv(channel_id_, progress_);
if(bin_stream_processor_ != nullptr)
if (bin_stream_processor_ != nullptr)
bin_stream_processor_(&bin_stream);
}
};
}

virtual void post_recv() {};
virtual void pre_send() {};
virtual void send() {};
virtual void post_send() {};
virtual void post_recv(){}
virtual void pre_send(){}
virtual void send(){}
virtual void post_send(){}

// Third-level APIs (invoked by its upper level)

void set_bin_stream_processor(std::function<void(base::BinStream*)> bin_stream_processor) {
bin_stream_processor_ = bin_stream_processor;
}

std::function<void(base::BinStream*)> get_bin_stream_processor() {
return bin_stream_processor_;
}
std::function<void(base::BinStream*)> get_bin_stream_processor() { return bin_stream_processor_; }

void inc_progress();

Expand Down
44 changes: 38 additions & 6 deletions core/channel/channel_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#include <string>

#include "core/channel/channel_store_base.hpp"
#include "core/sync_shuffle_combiner.hpp"
#include "core/context.hpp"
#include "core/objlist.hpp"
#include "core/sync_shuffle_combiner.hpp"

namespace husky {

Expand All @@ -30,9 +30,38 @@ class ChannelStore : public ChannelStoreBase {
public:
// Create PushChannel
template <typename MsgT, typename DstObjT>
static PushChannel<MsgT, DstObjT>& create_push_channel(ChannelSource& src_list, ObjList<DstObjT>& dst_list,
const std::string& name = "") {
auto& ch = ChannelStoreBase::create_push_channel<MsgT>(src_list, dst_list, name);
static auto* create_push_channel(ObjList<DstObjT>* dst_list, const std::string& name = "") {
auto* ch = ChannelStoreBase::create_push_channel<MsgT>(*dst_list);
common_setup(ch);
ch->set_base_obj_addr_getter([=](){
// TODO(fan) should do &dst_list->get_data.get(0) in debug mode
return &dst_list->get_data[0];
});
ch->set_bin_stream_processor([=](base::BinStream* bin_stream) {
auto* recv_buffer = ch->get_recv_buffer();

while (bin_stream->size() != 0) {
typename DstObjT::KeyT key;
*bin_stream >> key;

MsgT msg;
*bin_stream >> msg;

DstObjT* recver_obj = dst_list->find(key);
int idx;
if (recver_obj == nullptr) {
DstObjT obj(key); // Construct obj using key only
idx = dst_list->add_object(std::move(obj));
} else {
idx = dst_list->index_of(recver_obj);
}
if (idx >= ch->get_recv_buffer()->size()) {
recv_buffer->resize(idx + 1);
}

(*recv_buffer)[idx].push_back(std::move(msg));
}
});
return ch;
}

Expand All @@ -41,9 +70,12 @@ class ChannelStore : public ChannelStoreBase {
static auto* create_push_combined_channel(ObjList<DstObjT>* dst_list, const std::string& name = "") {
auto* ch = ChannelStoreBase::create_push_combined_channel<MsgT, CombineT>(*dst_list);
common_setup(ch);
ch->set_obj_list(dst_list);
ch->set_base_obj_addr_getter([=](){
// TODO(fan) should do &dst_list->get_data.get(0) in debug mode
return &dst_list->get_data[0];
});
ch->set_combiner(new SyncShuffleCombiner<MsgT, typename DstObjT::KeyT, CombineT>(Context::get_zmq_context()));
ch->set_bin_stream_processor([=](base::BinStream* bin_stream){
ch->set_bin_stream_processor([=](base::BinStream* bin_stream) {
auto* recv_buffer = ch->get_recv_buffer();
auto* recv_flags = ch->get_recv_flags();

Expand Down
45 changes: 27 additions & 18 deletions core/channel/channel_store_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,40 @@ class ChannelStoreBase {
public:
// Create PushChannel
template <typename MsgT, typename DstObjT>
static PushChannel<MsgT, DstObjT>& create_push_channel(ChannelSource& src_list, ObjList<DstObjT>& dst_list,
const std::string& name = "") {
static auto* create_push_channel(ObjList<DstObjT>& dst_list, const std::string& name = "") {
std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name;
if(channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_channel = new PushChannel<MsgT, DstObjT>(&src_list, &dst_list);
if (channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_channel = new PushChannel<MsgT, DstObjT>();
channel_map.insert({channel_name, push_channel});
return *push_channel;
return push_channel;
}

// Create PushChannel
template <typename MsgT, typename DstObjT>
static auto* create_push_combined_channel(const std::string& name = "") {
std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name;
if (channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_channel = new PushChannel<MsgT, DstObjT>();
channel_map.insert({channel_name, push_channel});
return push_channel;
}

template <typename MsgT, typename DstObjT>
static PushChannel<MsgT, DstObjT>& get_push_channel(const std::string& name = "") {
if(channel_map.find(name) == channel_map.end())
throw base::HuskyException("ChannelStoreBase::get_channel: Channel name doesn't exist");
static auto& get_push_channel(const std::string& name = "") {
if (channel_map.find(name) == channel_map.end())
throw base::HuskyException("ChannelStoreBase::get_channel: Channel name doesn't exist");
auto* channel = channel_map[name];
return *dynamic_cast<PushChannel<MsgT, DstObjT>*>(channel);
return *static_cast<PushChannel<MsgT, DstObjT>*>(channel);
}

// Create PushCombinedChannel
template <typename MsgT, typename CombineT, typename DstObjT>
static auto* create_push_combined_channel(ObjList<DstObjT>& dst_list,
const std::string& name = "") {
static auto* create_push_combined_channel(ObjList<DstObjT>& dst_list, const std::string& name = "") {
std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name;
if(channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
if (channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_combined_channel = new PushCombinedChannel<MsgT, DstObjT, CombineT>();
channel_map.insert({channel_name, push_combined_channel});
return push_combined_channel;
Expand All @@ -68,17 +77,17 @@ class ChannelStoreBase {
template <typename MsgT, typename CombineT, typename DstObjT>
static auto* create_push_combined_channel(const std::string& name = "") {
std::string channel_name = name.empty() ? channel_name_prefix + std::to_string(default_channel_id++) : name;
if(channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
if (channel_map.find(name) != channel_map.end())
throw base::HuskyException("ChannelStoreBase::create_channel: Channel name already exists");
auto* push_combined_channel = new PushCombinedChannel<MsgT, DstObjT, CombineT>();
channel_map.insert({channel_name, push_combined_channel});
return push_combined_channel;
}

template <typename MsgT, typename CombineT, typename DstObjT>
static auto& get_push_combined_channel(const std::string& name = "") {
if(channel_map.find(name) == channel_map.end())
throw base::HuskyException("ChannelStoreBase::get_channel: Channel name doesn't exist");
if (channel_map.find(name) == channel_map.end())
throw base::HuskyException("ChannelStoreBase::get_channel: Channel name doesn't exist");
auto* channel = channel_map[name];
return *static_cast<PushCombinedChannel<MsgT, DstObjT, CombineT>*>(channel);
}
Expand Down
145 changes: 47 additions & 98 deletions core/channel/push_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,66 +21,18 @@
#include "core/channel/channel_impl.hpp"
#include "core/hash_ring.hpp"
#include "core/mailbox.hpp"
#include "core/objlist.hpp"
#include "core/worker_info.hpp"

namespace husky {

using base::BinStream;

template <typename MsgT, typename DstObjT>
class PushChannel : public Source2ObjListChannel<DstObjT> {
class PushChannel : public ChannelBase {
public:
PushChannel(ChannelSource* src, ObjList<DstObjT>* dst) : Source2ObjListChannel<DstObjT>(src, dst) {
this->src_ptr_->register_outchannel(this->channel_id_, this);
this->dst_ptr_->register_inchannel(this->channel_id_, this);

recv_comm_handler_ = [&](const MsgT& msg, DstObjT* recver_obj) {
size_t idx = this->dst_ptr_->index_of(recver_obj);
if (idx >= recv_buffer_.size())
recv_buffer_.resize(idx + 1);
recv_buffer_[idx].push_back(std::move(msg));
};
}

~PushChannel() override {
if (this->src_ptr_ != nullptr)
this->src_ptr_->deregister_outchannel(this->channel_id_);
if (this->dst_ptr_ != nullptr)
this->dst_ptr_->deregister_inchannel(this->channel_id_);
}
PushChannel(const PushChannel&) = delete;
PushChannel& operator=(const PushChannel&) = delete;

PushChannel(PushChannel&&) = default;
PushChannel& operator=(PushChannel&&) = default;

void customized_setup() override {
// use get_largest_tid() instead of get_num_workers()
// sine we may only use a subset of worker
send_buffer_.resize(this->worker_info_->get_largest_tid() + 1);
}

void push(const MsgT& msg, const typename DstObjT::KeyT& key) {
int dst_worker_id = this->worker_info_->get_hash_ring().hash_lookup(key);
send_buffer_[dst_worker_id] << key << msg;
}

const std::vector<MsgT>& get(const DstObjT& obj) {
auto idx = this->dst_ptr_->index_of(&obj);
if (idx >= recv_buffer_.size()) { // resize recv_buffer_ if it is not large enough
recv_buffer_.resize(this->dst_ptr_->get_size());
}
return recv_buffer_[idx];
}

void prepare() override { clear_recv_buffer_(); }
PushChannel() = default;

void in(BinStream& bin) override { process_bin(bin); }
// The following are virtual methods

void out() override { flush(); }

void send() {
void send() override {
int start = this->global_id_;
for (int i = 0; i < send_buffer_.size(); ++i) {
int dst = (start + i) % send_buffer_.size();
Expand All @@ -91,67 +43,64 @@ class PushChannel : public Source2ObjListChannel<DstObjT> {
}
}

void send_complete() {
void post_send() override {
this->inc_progress();
this->mailbox_->send_complete(this->channel_id_, this->progress_, this->worker_info_->get_local_tids(),
this->worker_info_->get_pids());
}

/// This method is only useful without list_execute
void flush() {
send();
send_complete();
void set_worker_info(const WorkerInfo& worker_info) override {
worker_info_.reset(new WorkerInfo(worker_info));
if (send_buffer_.size() != worker_info_->get_largest_tid() + 1)
send_buffer_.resize(worker_info_->get_largest_tid() + 1);
}

/// This method is only useful without list_execute
void prepare_messages() {
if (!this->is_flushed())
return;
clear_recv_buffer_();
while (this->mailbox_->poll(this->channel_id_, this->progress_)) {
auto bin_push = this->mailbox_->recv(this->channel_id_, this->progress_);
process_bin(bin_push);
}
this->reset_flushed();
}
// The following are specific to this channel type

/// \brief Set a customized handler to handle incoming communication
///
/// The handler takes the message and its destinating object as its two arguments.
/// Then the handler decides the operation to apply on this object.
///
/// @param comm_handler A handler that contains the operation to be applied on
/// the obejct, using the received message.
void set_recv_comm_handler(std::function<void(const MsgT&, DstObjT*)> recv_comm_handler) {
recv_comm_handler_ = recv_comm_handler;
inline void push(const MsgT& msg, const typename DstObjT::KeyT& key) {
int dst_worker_id = this->worker_info_->get_hash_ring().hash_lookup(key);
send_buffer_[dst_worker_id] << key << msg;
}

protected:
void clear_recv_buffer_() {
// TODO(yuzhen): What types of clear do we need?
for (auto& vec : recv_buffer_)
vec.clear();
inline const std::vector<MsgT>& get(const DstObjT& obj) {
if (this->base_obj_addr_getter_ == nullptr) {
throw base::HuskyException(
"Object Address Getter not set and thus cannot get message by providing an object. "
"Please use `set_base_obj_addr_getter` first.");
}
auto idx = &obj - this->base_obj_addr_getter_();
if (idx >= recv_buffer_.size()) { // resize recv_buffer_ if it is not large enough
recv_buffer_.resize(this->obj_list_ptr_->get_size());
}
return recv_buffer_[idx];
}
void process_bin(BinStream& bin_push) {
while (bin_push.size() != 0) {
typename DstObjT::KeyT key;
bin_push >> key;
MsgT msg;
bin_push >> msg;

DstObjT* recver_obj = this->dst_ptr_->find(key);
if (recver_obj == nullptr) {
DstObjT obj(key); // Construct obj using key only
size_t idx = this->dst_ptr_->add_object(std::move(obj));
recver_obj = &(this->dst_ptr_->get(idx));
}
recv_comm_handler_(msg, recver_obj);

inline const std::vector<MsgT>& get(int idx) { return recv_buffer_[idx]; }

inline bool has_msgs(const DstObjT& obj) {
if (this->base_obj_addr_getter_ == nullptr) {
throw base::HuskyException(
"Object Address Getter not set and thus cannot get message by providing an object. "
"Please use `set_base_obj_addr_getter` first.");
}
auto idx = &obj - this->base_obj_addr_getter_();
return has_msgs(idx);
}

inline bool has_msgs(int idx) {
if (idx >= recv_buffer_.size())
return false;
return recv_buffer_[idx].size() != 0;
}

std::function<void(const MsgT&, DstObjT*)> recv_comm_handler_;
std::vector<BinStream> send_buffer_;
void set_base_obj_addr_getter(std::function<DstObjT*()> base_obj_addr_getter) { base_obj_addr_getter_ = base_obj_addr_getter; }

std::vector<std::vector<MsgT>>* get_recv_buffer() { return &recv_buffer_; }

protected:
std::vector<base::BinStream> send_buffer_;
std::vector<std::vector<MsgT>> recv_buffer_;
std::function<DstObjT*()> base_obj_addr_getter_; // TODO(fan) cache the address?
};

} // namespace husky
Loading

0 comments on commit 0b24cf2

Please sign in to comment.