Skip to content

Commit

Permalink
Transfer data_frames in multiples of 64 bytes to lift maximum transfer size restriction
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Nov 8, 2022
1 parent 0a4fca5 commit 972682f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 12 deletions.
4 changes: 4 additions & 0 deletions include/buffer_transfer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ namespace detail {
bool complete = false;
};

buffer_transfer_manager();

std::shared_ptr<const transfer_handle> push(const command_pkg& pkg);
std::shared_ptr<const transfer_handle> await_push(const command_pkg& pkg);

Expand Down Expand Up @@ -75,6 +77,8 @@ namespace detail {
// - Still outstanding pushes that have been requested through ::await_push
std::unordered_map<command_id, std::shared_ptr<incoming_transfer_handle>> m_push_blackboard;

mpi_support::data_type m_send_recv_unit;

void poll_incoming_transfers();
void update_incoming_transfers();
void update_outgoing_transfers();
Expand Down
3 changes: 1 addition & 2 deletions include/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ namespace detail {
using payload_type = command_id;

command_pkg pkg;
size_t num_dependencies = 0; // This information is duplicated from unique_frame_ptr::get_payload_count() so that we can still use a
// `const command_frame &` without its owning pointer
size_t num_dependencies = 0;
payload_type dependencies[];

// variable-sized structure
Expand Down
1 change: 0 additions & 1 deletion include/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class unique_frame_ptr : private std::unique_ptr<Frame, unique_frame_delete<Fram
// Non-owning access to the managed frame, exactly as returned by impl::get(); const and non-const overloads.
// NOTE(review): `impl` is presumably an alias for the private std::unique_ptr base of this class - confirm.
Frame* get_pointer() { return impl::get(); }
const Frame* get_pointer() const { return impl::get(); }
// Returns the stored m_size_bytes value; senders use it to work out how much data to transmit for the frame.
size_t get_size_bytes() const { return m_size_bytes; }
size_t get_payload_count() const { return (m_size_bytes - sizeof(Frame)) / sizeof(payload_type); }

using impl::operator bool;
using impl::operator*;
Expand Down
15 changes: 15 additions & 0 deletions include/mpi_support.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
#pragma once

#include <mpi.h>

namespace celerity::detail::mpi_support {

// Tag values for MPI messages. TAG_DATA_TRANSFER is the tag passed to MPI_Isend when a data frame is sent.
// NOTE(review): TAG_CMD and TAG_TELEMETRY presumably tag command and telemetry messages respectively;
// their uses are not visible here - confirm against the call sites.
constexpr int TAG_CMD = 0;
constexpr int TAG_DATA_TRANSFER = 1;
constexpr int TAG_TELEMETRY = 2;

class data_type {
public:
explicit data_type(MPI_Datatype dt) : m_dt(dt) {}
data_type(const data_type&) = delete;
data_type& operator=(const data_type&) = delete;
~data_type() { MPI_Type_free(&m_dt); }

operator MPI_Datatype() const { return m_dt; }

private:
MPI_Datatype m_dt;
};

} // namespace celerity::detail::mpi_support
28 changes: 20 additions & 8 deletions src/buffer_transfer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
namespace celerity {
namespace detail {

// Size, in bytes, of one element of the MPI datatype used for data frame transfers. Frames are sent and received as
// a whole number of these units (byte sizes are rounded up to the next multiple), so the `int` element count handed
// to MPI covers send_recv_unit_bytes times as many bytes as a count of MPI_BYTE would.
inline constexpr size_t send_recv_unit_bytes = 64;

// Creates and commits the MPI datatype used as the transfer unit for data frames: a contiguous run of
// send_recv_unit_bytes elements of MPI_BYTE. Ownership of the committed type passes to the returned wrapper,
// which frees it on destruction.
mpi_support::data_type make_send_recv_unit() {
    // MPI_Type_contiguous takes its element count as an int.
    constexpr auto bytes_per_unit = static_cast<int>(send_recv_unit_bytes);

    MPI_Datatype contiguous_bytes;
    MPI_Type_contiguous(bytes_per_unit, MPI_BYTE, &contiguous_bytes);
    MPI_Type_commit(&contiguous_bytes);

    // data_type deletes its copy operations, so the wrapper is constructed directly in the return statement.
    return mpi_support::data_type{contiguous_bytes};
}

// Sets up the MPI datatype (m_send_recv_unit) that all data frame sends and receives of this manager are
// expressed in; it lives for as long as the manager does.
buffer_transfer_manager::buffer_transfer_manager()
    : m_send_recv_unit(make_send_recv_unit()) {
}

std::shared_ptr<const buffer_transfer_manager::transfer_handle> buffer_transfer_manager::push(const command_pkg& pkg) {
assert(pkg.get_command_type() == command_type::push);
auto t_handle = std::make_shared<transfer_handle>();
Expand All @@ -30,12 +41,13 @@ namespace detail {
frame->push_cid = pkg.cid;
bm.get_buffer_data(data.bid, data.sr, frame->data);

CELERITY_TRACE("Ready to send {} of buffer {} ({} B) to {}", data.sr, data.bid, frame.get_size_bytes(), data.target);
size_t frame_units = (frame.get_size_bytes() + send_recv_unit_bytes - 1) / send_recv_unit_bytes;
CELERITY_TRACE("Ready to send {} of buffer {} ({} * {}B) to {}", data.sr, data.bid, frame_units, send_recv_unit_bytes, data.target);

// Start transmitting data
MPI_Request req;
assert(frame.get_size_bytes() <= static_cast<size_t>(std::numeric_limits<int>::max()));
MPI_Isend(frame.get_pointer(), static_cast<int>(frame.get_size_bytes()), MPI_BYTE, static_cast<int>(data.target), mpi_support::TAG_DATA_TRANSFER,
assert(frame_units <= static_cast<size_t>(std::numeric_limits<int>::max()));
MPI_Isend(frame.get_pointer(), static_cast<int>(frame_units), m_send_recv_unit, static_cast<int>(data.target), mpi_support::TAG_DATA_TRANSFER,
MPI_COMM_WORLD, &req);

auto transfer = std::make_unique<transfer_out>();
Expand Down Expand Up @@ -86,18 +98,18 @@ namespace detail {
// No incoming transfers at the moment
return;
}
int frame_bytes;
MPI_Get_count(&status, MPI_BYTE, &frame_bytes);
int frame_units;
MPI_Get_count(&status, m_send_recv_unit, &frame_units);

auto transfer = std::make_unique<transfer_in>();
transfer->source_nid = static_cast<node_id>(status.MPI_SOURCE);
transfer->frame = unique_frame_ptr<data_frame>(from_size_bytes, static_cast<size_t>(frame_bytes));
transfer->frame = unique_frame_ptr<data_frame>(from_size_bytes, static_cast<size_t>(frame_units) * send_recv_unit_bytes);

// Start receiving data
MPI_Imrecv(transfer->frame.get_pointer(), frame_bytes, MPI_BYTE, &msg, &transfer->request);
MPI_Imrecv(transfer->frame.get_pointer(), frame_units, m_send_recv_unit, &msg, &transfer->request);
m_incoming_transfers.push_back(std::move(transfer));

CELERITY_TRACE("Receiving incoming data of size {} B from {}", frame_bytes, status.MPI_SOURCE);
CELERITY_TRACE("Receiving incoming data of size {} * {}B from {}", frame_units, send_recv_unit_bytes, status.MPI_SOURCE);
}

void buffer_transfer_manager::update_incoming_transfers() {
Expand Down
1 change: 0 additions & 1 deletion src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ namespace detail {
MPI_Get_count(&status, MPI_BYTE, &frame_bytes);
unique_frame_ptr<command_frame> frame(from_size_bytes, static_cast<size_t>(frame_bytes));
MPI_Mrecv(frame.get_pointer(), frame_bytes, MPI_BYTE, &msg, &status);
assert(frame->num_dependencies == frame.get_payload_count());
command_queue.push(std::move(frame));

if(!m_first_command_received) {
Expand Down

0 comments on commit 972682f

Please sign in to comment.