diff --git a/include/buffer_transfer_manager.h b/include/buffer_transfer_manager.h index 419559219..fff50f963 100644 --- a/include/buffer_transfer_manager.h +++ b/include/buffer_transfer_manager.h @@ -24,6 +24,8 @@ namespace detail { bool complete = false; }; + buffer_transfer_manager(); + std::shared_ptr push(const command_pkg& pkg); std::shared_ptr await_push(const command_pkg& pkg); @@ -75,6 +77,8 @@ namespace detail { // - Still outstanding pushes that have been requested through ::await_push std::unordered_map> m_push_blackboard; + mpi_support::data_type m_send_recv_unit; + void poll_incoming_transfers(); void update_incoming_transfers(); void update_outgoing_transfers(); diff --git a/include/command.h b/include/command.h index 5c18ed571..479ad3616 100644 --- a/include/command.h +++ b/include/command.h @@ -223,8 +223,7 @@ namespace detail { using payload_type = command_id; command_pkg pkg; - size_t num_dependencies = 0; // This information is duplicated from unique_frame_ptr::get_payload_count() so that we can still use a - // `const command_frame &` without its owning pointer + size_t num_dependencies = 0; payload_type dependencies[]; // variable-sized structure diff --git a/include/frame.h b/include/frame.h index dd79f4ae7..c57b5e205 100644 --- a/include/frame.h +++ b/include/frame.h @@ -59,7 +59,6 @@ class unique_frame_ptr : private std::unique_ptr + namespace celerity::detail::mpi_support { constexpr int TAG_CMD = 0; constexpr int TAG_DATA_TRANSFER = 1; constexpr int TAG_TELEMETRY = 2; +class data_type { + public: + explicit data_type(MPI_Datatype dt) : m_dt(dt) {} + data_type(const data_type&) = delete; + data_type& operator=(const data_type&) = delete; + ~data_type() { MPI_Type_free(&m_dt); } + + operator MPI_Datatype() const { return m_dt; } + + private: + MPI_Datatype m_dt; +}; + } // namespace celerity::detail::mpi_support diff --git a/src/buffer_transfer_manager.cc b/src/buffer_transfer_manager.cc index 5b4c995ec..714a6a4b2 100644 --- a/src/buffer_transfer_manager.cc +++ b/src/buffer_transfer_manager.cc @@ -12,6 +12,17 @@ namespace celerity { namespace detail { + inline constexpr size_t send_recv_unit_bytes = 64; + + mpi_support::data_type make_send_recv_unit() { + MPI_Datatype unit; + MPI_Type_contiguous(send_recv_unit_bytes, MPI_BYTE, &unit); + MPI_Type_commit(&unit); + return mpi_support::data_type(unit); + } + + buffer_transfer_manager::buffer_transfer_manager() : m_send_recv_unit(make_send_recv_unit()) {} + std::shared_ptr buffer_transfer_manager::push(const command_pkg& pkg) { assert(pkg.get_command_type() == command_type::push); auto t_handle = std::make_shared(); @@ -30,12 +41,13 @@ namespace detail { frame->push_cid = pkg.cid; bm.get_buffer_data(data.bid, data.sr, frame->data); - CELERITY_TRACE("Ready to send {} of buffer {} ({} B) to {}", data.sr, data.bid, frame.get_size_bytes(), data.target); + size_t frame_units = (frame.get_size_bytes() + send_recv_unit_bytes - 1) / send_recv_unit_bytes; + CELERITY_TRACE("Ready to send {} of buffer {} ({} * {}B) to {}", data.sr, data.bid, frame_units, send_recv_unit_bytes, data.target); // Start transmitting data MPI_Request req; - assert(frame.get_size_bytes() <= static_cast(std::numeric_limits::max())); - MPI_Isend(frame.get_pointer(), static_cast(frame.get_size_bytes()), MPI_BYTE, static_cast(data.target), mpi_support::TAG_DATA_TRANSFER, + assert(frame_units <= static_cast(std::numeric_limits::max())); + MPI_Isend(frame.get_pointer(), static_cast(frame_units), m_send_recv_unit, static_cast(data.target), mpi_support::TAG_DATA_TRANSFER, MPI_COMM_WORLD, &req); auto transfer = std::make_unique(); @@ -86,18 +98,18 @@ namespace detail { // No incoming transfers at the moment return; } - int frame_bytes; - MPI_Get_count(&status, MPI_BYTE, &frame_bytes); + int frame_units; + MPI_Get_count(&status, m_send_recv_unit, &frame_units); auto transfer = std::make_unique(); transfer->source_nid = static_cast(status.MPI_SOURCE); - transfer->frame = unique_frame_ptr(from_size_bytes, static_cast(frame_bytes)); + transfer->frame = unique_frame_ptr(from_size_bytes, static_cast(frame_units) * send_recv_unit_bytes); // Start receiving data - MPI_Imrecv(transfer->frame.get_pointer(), frame_bytes, MPI_BYTE, &msg, &transfer->request); + MPI_Imrecv(transfer->frame.get_pointer(), frame_units, m_send_recv_unit, &msg, &transfer->request); m_incoming_transfers.push_back(std::move(transfer)); - CELERITY_TRACE("Receiving incoming data of size {} B from {}", frame_bytes, status.MPI_SOURCE); + CELERITY_TRACE("Receiving incoming data of size {} * {}B from {}", frame_units, send_recv_unit_bytes, status.MPI_SOURCE); } void buffer_transfer_manager::update_incoming_transfers() { diff --git a/src/executor.cc b/src/executor.cc index 0d5236e47..72fd6c67f 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -117,7 +117,6 @@ namespace detail { MPI_Get_count(&status, MPI_BYTE, &frame_bytes); unique_frame_ptr frame(from_size_bytes, static_cast(frame_bytes)); MPI_Mrecv(frame.get_pointer(), frame_bytes, MPI_BYTE, &msg, &status); - assert(frame->num_dependencies == frame.get_payload_count()); command_queue.push(std::move(frame)); if(!m_first_command_received) {