diff --git a/src/scheduler.cc b/src/scheduler.cc index d9bce06f4..19980d20e 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -19,8 +19,9 @@ namespace detail { void abstract_scheduler::schedule() { graph_serializer serializer([this](command_pkg&& pkg) { - if(m_is_dry_run && pkg.get_command_type() != command_type::epoch && pkg.get_command_type() != command_type::fence) { - // in dry runs, skip everything except epochs and fences + if(m_is_dry_run && pkg.get_command_type() != command_type::epoch && pkg.get_command_type() != command_type::horizon + && pkg.get_command_type() != command_type::fence) { + // in dry runs, skip everything except epochs, horizons and fences return; } if(m_is_dry_run && pkg.get_command_type() == command_type::fence) { diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index 14fe4a47d..eda95a9f9 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -1218,12 +1218,12 @@ namespace detail { test_utils::maybe_print_graph(tm); } - TEST_CASE_METHOD(test_utils::runtime_fixture, "Dry run generates commands for an arbitrary number of simulated worker nodes", "[dryrun]") { + TEST_CASE_METHOD(test_utils::runtime_fixture, "dry run generates commands for an arbitrary number of simulated worker nodes", "[dryrun]") { const size_t num_nodes = GENERATE(values({4, 8, 16})); dry_run_with_nodes(num_nodes); } - TEST_CASE_METHOD(test_utils::runtime_fixture, "Dry run proceeds on fences", "[dryrun]") { + TEST_CASE_METHOD(test_utils::runtime_fixture, "dry run proceeds on fences", "[dryrun]") { env::scoped_test_environment ste(std::unordered_map{{dryrun_envvar_name, "1"}}); distr_queue q; @@ -1240,8 +1240,37 @@ namespace detail { }); auto ret = experimental::fence(q, buf); - bool val = *ret.get(); // this will hang if fences are not processed in the dry run - CHECK_FALSE(val); // extra check that the task was not actually executed + REQUIRE(ret.wait_for(std::chrono::seconds(1)) == std::future_status::ready); + CHECK_FALSE(*ret.get()); // extra check that the task was not actually executed + + // TODO: check that a warning is generated once the issues with log_capture are resolved + } + + TEST_CASE_METHOD(test_utils::runtime_fixture, "dry run processes horizons", "[dryrun]") { + env::scoped_test_environment ste(std::unordered_map{{dryrun_envvar_name, "1"}}); + + distr_queue q; + + auto& rt = runtime::get_instance(); + auto& tm = rt.get_task_manager(); + tm.set_horizon_step(1); // horizon step 1 to make testing easy and reproducable with config changes + + REQUIRE(rt.is_dry_run()); + + auto latest_hor = task_manager_testspy::get_latest_horizon_reached(tm); + CHECK_FALSE(latest_hor.has_value()); + + q.submit([&](handler& cgh) { cgh.host_task(on_master_node, [=] {}); }); + + // we can't slow_full_sync in this test, so we just try until the horizons have been processed + // 100*10ms is one second in total; if the horizon hasn't happened at that point, it's not happening + constexpr int max_num_tries = 100; + for(int i = 0; i < max_num_tries; ++i) { + latest_hor = task_manager_testspy::get_latest_horizon_reached(tm); + if(latest_hor.has_value()) break; + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + CHECK(latest_hor.has_value()); } TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars][config]") { @@ -1268,7 +1297,7 @@ namespace detail { CHECK(cfg.get_dry_run_nodes() == 4); } - TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars][config]") { + TEST_CASE_METHOD(test_utils::mpi_fixture, "config reports incorrect environment varibles", "[env-vars][config]") { const std::string error_string{"Failed to parse/validate environment variables."}; { std::unordered_map invalid_test_env_var{{"CELERITY_LOG_LEVEL", "a"}}; diff --git a/test/test_utils.h b/test/test_utils.h index 46267a1e2..a4b2f6edb 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -66,6 +66,8 @@ namespace detail { struct task_manager_testspy { static std::optional get_current_horizon(task_manager& tm) { return tm.m_current_horizon; } + static std::optional get_latest_horizon_reached(task_manager& tm) { return tm.m_latest_horizon_reached; } + static int get_num_horizons(task_manager& tm) { int horizon_counter = 0; for(auto task_ptr : tm.m_task_buffer) {