Skip to content

Commit

Permalink
Handle tuples of captures in experimental::fence()
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Nov 7, 2022
1 parent 77f8f41 commit 69820f1
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 46 deletions.
156 changes: 111 additions & 45 deletions include/distr_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace detail {
template <typename CGF>
constexpr bool is_safe_cgf = std::is_standard_layout<CGF>::value;

struct fence_builder;

} // namespace detail

struct allow_by_ref_t {};
Expand Down Expand Up @@ -142,65 +144,129 @@ namespace experimental {
friend bool operator!=(const buffer_snapshot& lhs, const buffer_snapshot& rhs) { return !operator==(lhs, rhs); }

private:
template <typename U, int Dims2>
friend buffer_snapshot<U, Dims2> fence(distr_queue& q, const buffer_subrange<U, Dims2>& buf);
friend struct detail::fence_builder;

subrange<Dims> m_sr;
std::vector<T> m_data;

explicit buffer_snapshot(subrange<Dims> sr, std::vector<T> data) : m_sr(sr), m_data(std::move(data)) { assert(m_data.size() == m_sr.range.size()); }
};

template <typename T>
T fence(distr_queue& q, const host_object<T>& ho) {
auto& tm = detail::runtime::get_instance().get_task_manager();
detail::side_effect_map side_effects;
side_effects.add_side_effect(ho.get_id(), side_effect_order::sequential);
const auto fence_guard = tm.get_task(tm.generate_fence_task({}, std::move(side_effects)))->get_fence().await_arrived_and_acquire();
return *ho.get_object();
}
} // namespace experimental

template <typename T, int Dims>
buffer_snapshot<T, Dims> fence(distr_queue& q, const buffer_subrange<T, Dims>& bsr) {
auto& tm = detail::runtime::get_instance().get_task_manager();
const auto range = detail::range_cast<3>(bsr.subrange.range);
const auto offset = detail::id_cast<3>(bsr.subrange.offset);

detail::buffer_capture_map buffer_captures;
buffer_captures.add_read_access(detail::get_buffer_id(bsr.buffer), detail::subrange_cast<3>(bsr.subrange));
const auto fence_guard = tm.get_task(tm.generate_fence_task(std::move(buffer_captures), {}))->get_fence().await_arrived_and_acquire();

auto& bm = detail::runtime::get_instance().get_buffer_manager();
const auto access_info = bm.get_host_buffer<T, Dims>(detail::get_buffer_id(bsr.buffer), access_mode::read, range, offset);

// TODO this should be able to use host_buffer_storage::get_data
const auto allocation_window = buffer_allocation_window<T, Dims>{
access_info.buffer.get_pointer(),
bsr.buffer.get_range(),
access_info.buffer.get_range(),
bsr.subrange.range,
access_info.offset,
bsr.subrange.offset,
};
const auto allocation_range_3 = detail::range_cast<3>(allocation_window.get_allocation_range());
const auto window_range_3 = detail::range_cast<3>(allocation_window.get_window_range());
const auto read_offset_3 = detail::id_cast<3>(allocation_window.get_window_offset_in_allocation());
std::vector<T> data(allocation_window.get_window_range().size());
for(id<3> item{0, 0, 0}; item[0] < window_range_3[0]; ++item[0]) {
for(item[1] = 0; item[1] < window_range_3[1]; ++item[1]) {
for(item[2] = 0; item[2] < window_range_3[2]; ++item[2]) {
data[detail::get_linear_index(window_range_3, item)] =
allocation_window.get_allocation()[detail::get_linear_index(allocation_range_3, item + read_offset_3)];
namespace detail {

struct fence_builder {
buffer_capture_map buffer_captures;
side_effect_map side_effects;

template <typename T>
void add(const experimental::host_object<T>& ho) {
side_effects.add_side_effect(detail::get_host_object_id(ho), experimental::side_effect_order::sequential);
}

template <typename T, int Dims>
void add(const experimental::buffer_subrange<T, Dims>& bsr) {
buffer_captures.add_read_access(detail::get_buffer_id(bsr.buffer), detail::subrange_cast<3>(bsr.subrange));
}

template <typename T, int Dims>
void add(const buffer<T, Dims>& buf) {
buffer_captures.add_read_access(detail::get_buffer_id(buf), subrange<3>({}, range_cast<3>(buf.get_range())));
}

fence_guard await_and_acquire() {
auto& tm = detail::runtime::get_instance().get_task_manager();
const auto tid = tm.generate_fence_task(std::move(buffer_captures), std::move(side_effects));
return tm.get_task(tid)->get_fence().await_arrived_and_acquire();
}

template <typename T>
T extract(const experimental::host_object<T>& ho) const {
return detail::get_host_object_instance(ho);
}

template <typename T, int Dims>
experimental::buffer_snapshot<T, Dims> extract(const experimental::buffer_subrange<T, Dims>& bsr) const {
auto& bm = detail::runtime::get_instance().get_buffer_manager();
const auto access_info = bm.get_host_buffer<T, Dims>(
detail::get_buffer_id(bsr.buffer), access_mode::read, detail::range_cast<3>(bsr.subrange.range), detail::id_cast<3>(bsr.subrange.offset));

// TODO this should be able to use host_buffer_storage::get_data
const auto allocation_window = buffer_allocation_window<T, Dims>{
access_info.buffer.get_pointer(),
bsr.buffer.get_range(),
access_info.buffer.get_range(),
bsr.subrange.range,
access_info.offset,
bsr.subrange.offset,
};
const auto allocation_range_3 = detail::range_cast<3>(allocation_window.get_allocation_range());
const auto window_range_3 = detail::range_cast<3>(allocation_window.get_window_range());
const auto read_offset_3 = detail::id_cast<3>(allocation_window.get_window_offset_in_allocation());
std::vector<T> data(allocation_window.get_window_range().size());
for(id<3> item{0, 0, 0}; item[0] < window_range_3[0]; ++item[0]) {
for(item[1] = 0; item[1] < window_range_3[1]; ++item[1]) {
for(item[2] = 0; item[2] < window_range_3[2]; ++item[2]) {
data[detail::get_linear_index(window_range_3, item)] =
allocation_window.get_allocation()[detail::get_linear_index(allocation_range_3, item + read_offset_3)];
}
}
}

return experimental::buffer_snapshot<T, Dims>(bsr.subrange, std::move(data));
}

return buffer_snapshot<T, Dims>(bsr.subrange, std::move(data));
}
template <typename T, int Dims>
experimental::buffer_snapshot<T, Dims> extract(const buffer<T, Dims>& buf) const {
return extract(buffer_subrange(buffer_subrange(buf, subrange({}, buf.get_range()))));
}
};

template <typename>
struct captured_data;

template <typename T>
struct captured_data<experimental::host_object<T>> {
using type = T;
};

template <typename T, int Dims>
buffer_snapshot<T, Dims> fence(distr_queue& q, const buffer<T, Dims>& buf) {
return fence(buffer_subrange(q, buffer_subrange(buf, subrange({}, buf.get_range()))));
struct captured_data<buffer<T, Dims>> {
using type = experimental::buffer_snapshot<T, Dims>;
};

template <typename T, int Dims>
struct captured_data<experimental::buffer_subrange<T, Dims>> {
using type = experimental::buffer_snapshot<T, Dims>;
};

template <typename T>
using captured_data_t = typename captured_data<T>::type;

template <typename... Captures, size_t... Indices>
std::tuple<detail::captured_data_t<Captures>...> fence_internal(const std::tuple<Captures...>& cap, std::index_sequence<Indices...>) {
detail::fence_builder builder;
(builder.add(std::get<Indices>(cap)), ...);
const auto guard = builder.await_and_acquire();
return std::tuple(builder.extract(std::get<Indices>(cap))...);
}

} // namespace detail

namespace experimental {

template <typename Capture>
detail::captured_data_t<Capture> fence(distr_queue&, const Capture& cap) {
detail::fence_builder builder;
builder.add(cap);
const auto guard = builder.await_and_acquire();
return builder.extract(cap);
}

template <typename... Captures>
std::tuple<detail::captured_data_t<Captures>...> fence(distr_queue&, const std::tuple<Captures...>& cap) {
return detail::fence_internal(cap, std::index_sequence_for<Captures...>());
}

} // namespace experimental
Expand Down
2 changes: 1 addition & 1 deletion test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ namespace detail {

{
const auto sr = subrange<1>(0, 3);
const auto snapshot = experimental::fence(q, experimental::buffer_subrange(buf, sr));
const auto [snapshot] = experimental::fence(q, std::tuple(experimental::buffer_subrange(buf, sr)));
CHECK(snapshot.get_subrange() == sr);
CHECK(snapshot.get_data().size() == 3);
CHECK(snapshot.get_data() == std::vector{1, 2, 3});
Expand Down

0 comments on commit 69820f1

Please sign in to comment.