Skip to content

Commit

Permalink
Ensure correctness of unique_frame_ptr ↔ unique_payload_ptr conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Jul 15, 2022
1 parent 2a7437d commit 738a789
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions include/mpi_support.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <cassert>
#include <functional>
#include <memory>
#include <utility>

Expand Down Expand Up @@ -104,15 +105,40 @@ class unique_payload_ptr : private std::unique_ptr<void, std::function<void(void
unique_payload_ptr() noexcept = default;

template <typename T>
explicit unique_payload_ptr(allocate_uninitialized_tag<T>, size_t count) : impl(operator new(count * sizeof(T)), [](void* p) { operator delete(p); }) {}
explicit unique_payload_ptr(allocate_uninitialized_tag<T>, size_t count) : impl(allocate_uninitialized_payload<T>(count)) {}

template <typename Frame>
explicit unique_payload_ptr(unique_frame_ptr<Frame> frame) : impl(frame.release() + 1, [](void* p) { delete(static_cast<Frame*>(p) - 1); }) {}
explicit unique_payload_ptr(unique_frame_ptr<Frame> frame) : impl(unique_frame_to_payload(std::move(frame))) {}

void* get_pointer() { return impl::get(); }
const void* get_pointer() const { return impl::get(); }

using impl::operator bool;

private:
template <typename Frame>
static void delete_frame_from_payload(void* const type_erased_payload) {
const auto payload = static_cast<typename Frame::payload_type*>(type_erased_payload);
const auto frame = reinterpret_cast<Frame*>(payload) - 1; // frame header is located at -sizeof(Frame) bytes (-1 Frame object)
delete frame;
}

template <typename Frame>
static impl unique_frame_to_payload(unique_frame_ptr<Frame> unique_frame) {
deleter_type deleter{delete_frame_from_payload<Frame>}; // allocate deleter (aka std::function) first so `impl` construction is noexcept
const auto frame = unique_frame.release();
const auto payload = reinterpret_cast<typename Frame::payload_type*>(frame + 1); // payload is located at +sizeof(Frame) bytes (+1 Frame object)
return impl{payload, std::move(deleter)};
}

static void delete_uninitialized_payload(void* const p) { operator delete(p); }

template <typename T>
static impl allocate_uninitialized_payload(size_t count) {
deleter_type deleter{delete_uninitialized_payload}; // allocate deleter (aka std::function) first so `impl` construction is noexcept
const auto payload = operator new(count * sizeof(T));
return impl{payload, std::move(deleter)};
}
};

} // namespace celerity::detail

0 comments on commit 738a789

Please sign in to comment.