Skip to content

Commit

Permalink
print_graph: deduplicate requirements label generation
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Aug 29, 2022
1 parent 635053a commit 8122798
Showing 1 changed file with 24 additions and 38 deletions.
62 changes: 24 additions & 38 deletions src/print_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,13 @@ namespace detail {
}
}

std::string get_task_label(const task& tsk, const reduction_manager& rm) {
std::string label;
fmt::format_to(std::back_inserter(label), "T{}", tsk.get_id());
if(!tsk.get_debug_name().empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.get_debug_name()); }
void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm) {
for(auto rid : tsk.get_reductions()) {
auto reduction = rm.get_reduction(rid);

const auto execution_range = subrange<3>{tsk.get_global_offset(), tsk.get_global_size()};
auto rmode = cl::sycl::access::mode::discard_write;
if(reduction.initialize_from_buffer) { rmode = reduction_init_mode; }

fmt::format_to(std::back_inserter(label), "<br/><b>{}</b>", task_type_string(tsk.get_type()));
if(tsk.get_type() == task_type::host_compute || tsk.get_type() == task_type::device_compute) {
fmt::format_to(std::back_inserter(label), " {}", execution_range);
} else if(tsk.get_type() == task_type::collective) {
fmt::format_to(std::back_inserter(label), " in CG{}", tsk.get_collective_group_id());
}

for(auto rid : tsk.get_reductions()) {
const auto reduction = rm.get_reduction(rid);
const auto rmode = reduction.initialize_from_buffer ? access_mode::read_write : access_mode::discard_write;
const auto bid = reduction.output_buffer_id;
const auto req = GridRegion<3>{{1, 1, 1}};
fmt::format_to(std::back_inserter(label), "<br/>(R{}) <i>{}</i> B{} {}", rid, detail::access::mode_traits::name(rmode), bid, req);
Expand All @@ -60,13 +50,31 @@ namespace detail {
for(const auto bid : bam.get_accessed_buffers()) {
for(const auto mode : bam.get_access_modes(bid)) {
const auto req = bam.get_requirements_for_access(bid, mode, tsk.get_dimensions(), execution_range, tsk.get_global_size());
// While uncommon, we do support chunks that don't require access to a particular buffer at all.
if(!req.empty()) { fmt::format_to(std::back_inserter(label), "<br/><i>{}</i> B{} {}", detail::access::mode_traits::name(mode), bid, req); }
}
}

for(const auto& [hoid, order] : tsk.get_side_effect_map()) {
fmt::format_to(std::back_inserter(label), "<br/><i>affect</i> H{}", hoid);
}
}

std::string get_task_label(const task& tsk, const reduction_manager& rm) {
std::string label;
fmt::format_to(std::back_inserter(label), "T{}", tsk.get_id());
if(!tsk.get_debug_name().empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.get_debug_name()); }

const auto execution_range = subrange<3>{tsk.get_global_offset(), tsk.get_global_size()};

fmt::format_to(std::back_inserter(label), "<br/><b>{}</b>", task_type_string(tsk.get_type()));
if(tsk.get_type() == task_type::host_compute || tsk.get_type() == task_type::device_compute) {
fmt::format_to(std::back_inserter(label), " {}", execution_range);
} else if(tsk.get_type() == task_type::collective) {
fmt::format_to(std::back_inserter(label), " in CG{}", tsk.get_collective_group_id());
}

format_requirements(label, tsk, execution_range, access_mode::read_write, rm);

return label;
}
Expand Down Expand Up @@ -127,29 +135,7 @@ namespace detail {
execution_range = ecmd->get_execution_range();
}

for(auto rid : tsk.get_reductions()) {
auto reduction = rm.get_reduction(rid);

auto rmode = cl::sycl::access::mode::discard_write;
if(reduction.initialize_from_buffer) { rmode = reduction_init_mode; }

const auto bid = reduction.output_buffer_id;
const auto req = GridRegion<3>{{1, 1, 1}};
fmt::format_to(std::back_inserter(label), "<br/>(R{}) <i>{}</i> B{} {}", rid, detail::access::mode_traits::name(rmode), bid, req);
}

const auto& bam = tsk.get_buffer_access_map();
for(const auto bid : bam.get_accessed_buffers()) {
for(const auto mode : bam.get_access_modes(bid)) {
const auto req = bam.get_requirements_for_access(bid, mode, tsk.get_dimensions(), execution_range, tsk.get_global_size());
// While uncommon, we do support chunks that don't require access to a particular buffer at all.
if(!req.empty()) { fmt::format_to(std::back_inserter(label), "<br/><i>{}</i> B{} {}", detail::access::mode_traits::name(mode), bid, req); }
}
}

for(const auto& [hoid, order] : tsk.get_side_effect_map()) {
fmt::format_to(std::back_inserter(label), "<br/><i>affect</i> H{}", hoid);
}
format_requirements(label, tsk, execution_range, reduction_init_mode, rm);
}

return label;
Expand Down

0 comments on commit 8122798

Please sign in to comment.