From 948d73bb34d3566146ca862cb6ef067ab0927e57 Mon Sep 17 00:00:00 2001 From: FacuMH Date: Wed, 25 Jan 2023 11:30:39 +0100 Subject: [PATCH] Changes/fixes noted in review --- include/config.h | 5 +++-- src/config.cc | 41 +++++++++++++++++++------------------- test/runtime_tests.cc | 46 +++++++++++++++++++++++++------------------ 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/include/config.h b/include/config.h index 94bac110d..4a8d1206f 100644 --- a/include/config.h +++ b/include/config.h @@ -56,10 +56,11 @@ namespace detail { std::vector m_devices_list; // parsing functions - size_t parse_validate_gpmv(const std::string_view str); + size_t parse_validate_graph_print_max_verts(const std::string_view str); bool parse_validate_profile_kernel(const std::string_view str); - size_t parse_validate_drn(const std::string_view str); + size_t parse_validate_dry_run_nodes(const std::string_view str); std::vector parse_validate_devices(const std::string_view str); + int parse_validate_forge_wg(const std::string_view str); }; } // namespace detail diff --git a/src/config.cc b/src/config.cc index 792eb62bc..f98918a57 100644 --- a/src/config.cc +++ b/src/config.cc @@ -56,24 +56,22 @@ struct default_parser { namespace celerity { namespace detail { - size_t config::parse_validate_gpmv(const std::string_view str) { + size_t config::parse_validate_graph_print_max_verts(const std::string_view str) { const auto gmpv = env::default_parser{}(str); if(spdlog::should_log(celerity::detail::log_level::trace)) { CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace."); } CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS new value: {}.", gmpv); - - CELERITY_WARN("Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."); return gmpv; } bool config::parse_validate_profile_kernel(const std::string_view str) { - const auto pk = str == "1"; - CELERITY_WARN("CELERITY_PROFILE_KERNEL is on."); + const auto pk = env::default_parser{}(str); + CELERITY_WARN("CELERITY_PROFILE_KERNEL is {}.", pk ? "on" : "off"); return pk; } - size_t config::parse_validate_drn(const std::string_view str) { + size_t config::parse_validate_dry_run_nodes(const std::string_view str) { const size_t drn = env::default_parser{}(str); if(m_host_cfg.node_count != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.\n"); CELERITY_WARN("Performing a dry run with {} simulated nodes", drn); @@ -90,7 +88,7 @@ namespace detail { if(devices.size() < 2) throw env::validation_error{ "Found " + std::to_string(devices.size()) - + "\nExpected the following format:CELERITY_DEVICES=\" ... \""}; + + "\nExpected the following format: CELERITY_DEVICES=\" ... \""}; if(static_cast(m_host_cfg.local_rank) > static_cast(devices.size()) - 2) { throw env::validation_error{fmt::format( @@ -102,13 +100,16 @@ namespace detail { } return devices; - }; + } + int config::parse_validate_forge_wg(const std::string_view str) { + CELERITY_WARN("Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."); + return 0; + } config::config(int* argc, char** argv[]) { // TODO: At some point we might want to parse arguments from argv as well - // Determine the "host config", i.e., how many nodes are spawned on this host, // and what this node's local rank is. We do this by finding all world-ranks // that can use a shared-memory transport (if running on OpenMPI, use the @@ -133,20 +134,20 @@ namespace detail { MPI_Comm_free(&host_comm); - auto pv_gpmv = [this](const std::string_view str) { return parse_validate_gpmv(str); }; - auto pv_d = [this](const std::string_view str) { return parse_validate_devices(str); }; - auto pv_pk = [this](const std::string_view str) { return parse_validate_profile_kernel(str); }; - auto pv_drn = [this](const std::string_view str) { return parse_validate_drn(str); }; - auto pref = env::prefix("CELERITY"); const auto env_log_level = pref.register_option( "LOG_LEVEL", {log_level::trace, log_level::debug, log_level::info, log_level::warn, log_level::err, log_level::critical, log_level::off}); - const auto env_gpmv = pref.register_variable("GRAPH_PRINT_MAX_VERTS", pv_gpmv); - const auto env_devs = pref.register_variable>("DEVICES", pv_d); - const auto env_profile_kernel = pref.register_variable("PROFILE_KERNEL", pv_pk); - const auto env_dry_run_nodes = pref.register_variable("DRY_RUN_NODES", pv_drn); + const auto env_gpmv = + pref.register_variable("GRAPH_PRINT_MAX_VERTS", [this](const std::string_view str) { return parse_validate_graph_print_max_verts(str); }); + const auto env_devs = + pref.register_variable>("DEVICES", [this](const std::string_view str) { return parse_validate_devices(str); }); + const auto env_profile_kernel = + pref.register_variable("PROFILE_KERNEL", [this](const std::string_view str) { return parse_validate_profile_kernel(str); }); + const auto env_dry_run_nodes = + pref.register_variable("DRY_RUN_NODES", [this](const std::string_view str) { return parse_validate_dry_run_nodes(str); }); + [[maybe_unused]] const auto env_force_wg = + pref.register_variable("FORCE_WG", [this](const std::string_view str) { return parse_validate_forge_wg(str); }); - // const auto parsed_and_validated_envs = m_test_key_value.empty() ? pref.parse_and_validate() : pref.parse_and_validate(m_test_key_value); const auto parsed_and_validated_envs = pref.parse_and_validate(); if(parsed_and_validated_envs.ok()) { // ------------------------------- CELERITY_LOG_LEVEL --------------------------------- @@ -196,7 +197,7 @@ namespace detail { for(const auto& err : parsed_and_validated_envs.warnings()) { CELERITY_ERROR(err.what()); } - throw std::runtime_error("Please make sure to use the environment variables correctly."); + throw std::runtime_error("Failed to parse/validate environment variables."); } } } // namespace detail diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index 75b879d90..9f199eb0e 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -1015,13 +1015,14 @@ namespace detail { dry_run_with_nodes(nodes); } - TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads correctly environment varibles", "[env-vars]") { - std::unordered_map valid_test_env_vars{// - {"CELERITY_LOG_LEVEL", "debug"}, // - {"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"}, // - {"CELERITY_DEVICES", "0 0"}, // - {"CELERITY_PROFILE_KERNEL", "1"}, // - {"CELERITY_DRY_RUN_NODES", "4"}}; + TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars]") { + std::unordered_map valid_test_env_vars{ + {"CELERITY_LOG_LEVEL", "debug"}, + {"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"}, + {"CELERITY_DEVICES", "0 0"}, + {"CELERITY_PROFILE_KERNEL", "1"}, + {"CELERITY_DRY_RUN_NODES", "4"}, + }; const auto test_env = env::scoped_test_environment(valid_test_env_vars); auto cfg = config(nullptr, nullptr); @@ -1035,28 +1036,35 @@ namespace detail { } TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars]") { + const std::string error_string{"Failed to parse/validate environment variables."}; { - std::unordered_map valid_test_env_vars{{"CELERITY_LOG_LEVEL", "a"}}; - const auto test_env = env::scoped_test_environment(valid_test_env_vars); - CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly."); + std::unordered_map invalid_test_env_var{{"CELERITY_LOG_LEVEL", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); } { - std::unordered_map valid_test_env_vars{{"CELERITY_GRAPH_PRINT_MAX_VERTS", "a"}}; - const auto test_env = env::scoped_test_environment(valid_test_env_vars); - CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly."); + std::unordered_map invalid_test_env_var{{"CELERITY_GRAPH_PRINT_MAX_VERTS", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); } { - std::unordered_map valid_test_env_vars{{"CELERITY_DEVICES", "a"}}; - const auto test_env = env::scoped_test_environment(valid_test_env_vars); - CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly."); + std::unordered_map invalid_test_env_var{{"CELERITY_DEVICES", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); } { - std::unordered_map valid_test_env_vars{{"CELERITY_DRY_RUN_NODES", "a"}}; - const auto test_env = env::scoped_test_environment(valid_test_env_vars); - CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly."); + std::unordered_map invalid_test_env_var{{"CELERITY_DRY_RUN_NODES", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + + { + std::unordered_map invalid_test_env_var{{"CELERITY_PROFILE_KERNEL", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); } }