Skip to content

Commit

Permalink
Add new CELERITY_GRAPH_PRINT_MAX_VERTS config option
Browse files Browse the repository at this point in the history
...to control how large task/command graphs can be before their GraphViz
output is omitted.
  • Loading branch information
psalz committed Feb 4, 2022
1 parent 64cd02d commit 6ecaab0
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 17 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ Celerity's runtime behavior:
automatically assign a unique device to each worker on a host.
- `CELERITY_PROFILE_KERNEL` controls whether SYCL queue profiling information
should be queried (currently not supported when using hipSYCL).
- `CELERITY_GRAPH_PRINT_MAX_VERTS` sets the maximum number of vertices a
task/command graph may have (default: 200); the GraphViz output of any larger
graph is omitted.

## Disclaimer

Expand Down
2 changes: 1 addition & 1 deletion include/command_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ namespace detail {

/// Looks up the entry stored in `by_task` for the given task id and hands it back by reference.
auto& task_commands(task_id tid) {
	auto& commands_of_task = by_task.at(tid);
	return commands_of_task;
}

std::optional<std::string> print_graph() const;
std::optional<std::string> print_graph(size_t max_nodes) const;

// TODO unify dependency terminology to this
void add_dependency(abstract_command* depender, abstract_command* dependee, dependency_kind kind = dependency_kind::TRUE_DEP) {
Expand Down
3 changes: 3 additions & 0 deletions include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ namespace detail {
/// Read-only access to the optional device configuration held by this object.
const std::optional<device_config>& get_device_config() const {
	return device_cfg;
}
/// Returns a copy of the optional device-profiling flag held by this object.
std::optional<bool> get_enable_device_profiling() const {
	return enable_device_profiling;
}

/// Largest number of vertices a task/command graph may have for its GraphViz output to still be printed.
size_t get_graph_print_max_verts() const {
	return graph_print_max_verts;
}

private:
log_level log_lvl;
host_config host_cfg;
std::optional<device_config> device_cfg;
std::optional<bool> enable_device_profiling;
// Vertex-count limit for graph printing. Stays at 200 unless CELERITY_GRAPH_PRINT_MAX_VERTS is set to a value accepted by parse_uint.
size_t graph_print_max_verts = 200;
};

} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion include/task_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ namespace detail {

const task* get_task(task_id tid) const;

std::optional<std::string> print_graph() const;
std::optional<std::string> print_graph(size_t max_nodes) const;

/**
* @brief Shuts down the task_manager, freeing all stored tasks.
Expand Down
6 changes: 2 additions & 4 deletions src/command_graph.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "command_graph.h"

#include "log.h"
#include "print_graph.h"

namespace celerity {
Expand All @@ -27,9 +26,8 @@ namespace detail {
}
}

std::optional<std::string> command_graph::print_graph() const {
if(command_count() < 200) { return detail::print_graph(*this); }
CELERITY_WARN("Command graph is very large ({} vertices). Skipping GraphViz output", command_count());
std::optional<std::string> command_graph::print_graph(size_t max_nodes) const {
if(command_count() <= max_nodes) { return detail::print_graph(*this); }
return std::nullopt;
}

Expand Down
11 changes: 11 additions & 0 deletions src/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ namespace detail {
spdlog::set_level(log_lvl);
}

// ------------------------- CELERITY_GRAPH_PRINT_MAX_VERTS ---------------------------

{
	// Optional override for the vertex-count limit above which GraphViz output of the task/command graphs is skipped.
	const auto [is_set, value] = get_env("CELERITY_GRAPH_PRINT_MAX_VERTS");
	if(is_set) {
		// The limit only matters when graphs are printed at all, which requires trace logging.
		if(log_lvl > log_level::trace) { CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace."); }
		const auto [is_valid, parsed] = parse_uint(value.c_str());
		if(is_valid) {
			graph_print_max_verts = parsed;
		} else {
			// A malformed value used to be dropped silently; keep the current limit but tell the user about it.
			CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS contains invalid value \"{}\". Keeping the limit of {} vertices.", value, graph_print_max_verts);
		}
	}
}

// --------------------------------- CELERITY_DEVICES ---------------------------------

{
Expand Down
19 changes: 15 additions & 4 deletions src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,24 @@ namespace detail {
h_queue->wait();

if(is_master_node() && cfg->get_log_level() == log_level::trace) {
const auto print_max_nodes = cfg->get_graph_print_max_verts();
{
const auto graph_str = task_mngr->print_graph();
if(graph_str.has_value()) { CELERITY_TRACE("Task graph:\n\n{}\n", *graph_str); }
const auto graph_str = task_mngr->print_graph(print_max_nodes);
if(graph_str.has_value()) {
CELERITY_TRACE("Task graph:\n\n{}\n", *graph_str);
} else {
CELERITY_WARN("Task graph with {} vertices exceeds CELERITY_GRAPH_PRINT_MAX_VERTS={}. Skipping GraphViz output",
task_mngr->get_current_task_count(), print_max_nodes);
}
}
{
const auto graph_str = cdag->print_graph();
if(graph_str.has_value()) { CELERITY_TRACE("Command graph:\n\n{}\n", *graph_str); }
const auto graph_str = cdag->print_graph(print_max_nodes);
if(graph_str.has_value()) {
CELERITY_TRACE("Command graph:\n\n{}\n", *graph_str);
} else {
CELERITY_WARN("Command graph with {} vertices exceeds CELERITY_GRAPH_PRINT_MAX_VERTS={}. Skipping GraphViz output", cdag->command_count(),
print_max_nodes);
}
}
}

Expand Down
6 changes: 2 additions & 4 deletions src/task_manager.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "task_manager.h"

#include "access_modes.h"
#include "log.h"
#include "print_graph.h"

namespace celerity {
Expand Down Expand Up @@ -33,10 +32,9 @@ namespace detail {
return task_map.at(tid).get();
}

std::optional<std::string> task_manager::print_graph() const {
std::optional<std::string> task_manager::print_graph(size_t max_nodes) const {
std::lock_guard<std::mutex> lock(task_mutex);
if(task_map.size() < 200) { return detail::print_graph(task_map); }
CELERITY_WARN("Task graph is very large ({} vertices). Skipping GraphViz output", task_map.size());
if(task_map.size() <= max_nodes) { return detail::print_graph(task_map); }
return std::nullopt;
}

Expand Down
2 changes: 1 addition & 1 deletion test/system/distr_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ namespace detail {
MPI_Comm test_communicator;
MPI_Comm_create(MPI_COMM_WORLD, world_group, &test_communicator);

const auto graph_str = runtime::get_instance().get_task_manager().print_graph();
const auto graph_str = runtime::get_instance().get_task_manager().print_graph(100);
REQUIRE(graph_str.has_value());
const int graph_str_length = graph_str->length();
REQUIRE(graph_str_length > 0);
Expand Down
5 changes: 3 additions & 2 deletions test/unit_test_suite_celerity.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "unit_test_suite_celerity.h"

#include <limits>
#include <sstream>

namespace detail {
Expand All @@ -16,15 +17,15 @@ void test_run_ended_callback() {

void maybe_print_graph(celerity::detail::task_manager& tm) {
if(print_graphs) {
const auto graph_str = tm.print_graph();
const auto graph_str = tm.print_graph(std::numeric_limits<size_t>::max());
assert(graph_str.has_value());
CELERITY_INFO("Task graph:\n\n{}\n", *graph_str);
}
}

void maybe_print_graph(celerity::detail::command_graph& cdag) {
if(print_graphs) {
const auto graph_str = cdag.print_graph();
const auto graph_str = cdag.print_graph(std::numeric_limits<size_t>::max());
assert(graph_str.has_value());
CELERITY_INFO("Command graph:\n\n{}\n", *graph_str);
}
Expand Down

0 comments on commit 6ecaab0

Please sign in to comment.