Skip to content

Commit

Permalink
Changes/fixes noted in review
Browse files Browse the repository at this point in the history
  • Loading branch information
facuMH committed Feb 22, 2023
1 parent dc415ad commit 948d73b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 41 deletions.
5 changes: 3 additions & 2 deletions include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ namespace detail {
std::vector<size_t> m_devices_list;

// parsing functions
size_t parse_validate_gpmv(const std::string_view str);
size_t parse_validate_graph_print_max_verts(const std::string_view str);
bool parse_validate_profile_kernel(const std::string_view str);
size_t parse_validate_drn(const std::string_view str);
size_t parse_validate_dry_run_nodes(const std::string_view str);
std::vector<size_t> parse_validate_devices(const std::string_view str);
int parse_validate_forge_wg(const std::string_view str);
};

} // namespace detail
Expand Down
41 changes: 21 additions & 20 deletions src/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,22 @@ struct default_parser<celerity::detail::log_level> {
namespace celerity {
namespace detail {

size_t config::parse_validate_gpmv(const std::string_view str) {
size_t config::parse_validate_graph_print_max_verts(const std::string_view str) {
const auto gmpv = env::default_parser<size_t>{}(str);
if(spdlog::should_log(celerity::detail::log_level::trace)) {
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace.");
}
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS new value: {}.", gmpv);

CELERITY_WARN("Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0.");
return gmpv;
}

bool config::parse_validate_profile_kernel(const std::string_view str) {
const auto pk = str == "1";
CELERITY_WARN("CELERITY_PROFILE_KERNEL is on.");
const auto pk = env::default_parser<bool>{}(str);
CELERITY_WARN("CELERITY_PROFILE_KERNEL is {}.", pk ? "on" : "off");
return pk;
}

size_t config::parse_validate_drn(const std::string_view str) {
size_t config::parse_validate_dry_run_nodes(const std::string_view str) {
const size_t drn = env::default_parser<size_t>{}(str);
if(m_host_cfg.node_count != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.\n");
CELERITY_WARN("Performing a dry run with {} simulated nodes", drn);
Expand All @@ -90,7 +88,7 @@ namespace detail {
if(devices.size() < 2)
throw env::validation_error{
"Found " + std::to_string(devices.size())
+ "\nExpected the following format:CELERITY_DEVICES=\"<platform_id> <first device_id> <second device_id> ... <nth device_id>\""};
+ "\nExpected the following format: CELERITY_DEVICES=\"<platform_id> <first device_id> <second device_id> ... <nth device_id>\""};

if(static_cast<long>(m_host_cfg.local_rank) > static_cast<long>(devices.size()) - 2) {
throw env::validation_error{fmt::format(
Expand All @@ -102,13 +100,16 @@ namespace detail {
}

return devices;
};
}

int config::parse_validate_forge_wg(const std::string_view str) {
CELERITY_WARN("Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0.");
return 0;
}

config::config(int* argc, char** argv[]) {
// TODO: At some point we might want to parse arguments from argv as well


// Determine the "host config", i.e., how many nodes are spawned on this host,
// and what this node's local rank is. We do this by finding all world-ranks
// that can use a shared-memory transport (if running on OpenMPI, use the
Expand All @@ -133,20 +134,20 @@ namespace detail {

MPI_Comm_free(&host_comm);

auto pv_gpmv = [this](const std::string_view str) { return parse_validate_gpmv(str); };
auto pv_d = [this](const std::string_view str) { return parse_validate_devices(str); };
auto pv_pk = [this](const std::string_view str) { return parse_validate_profile_kernel(str); };
auto pv_drn = [this](const std::string_view str) { return parse_validate_drn(str); };

auto pref = env::prefix("CELERITY");
const auto env_log_level = pref.register_option<log_level>(
"LOG_LEVEL", {log_level::trace, log_level::debug, log_level::info, log_level::warn, log_level::err, log_level::critical, log_level::off});
const auto env_gpmv = pref.register_variable<size_t>("GRAPH_PRINT_MAX_VERTS", pv_gpmv);
const auto env_devs = pref.register_variable<std::vector<size_t>>("DEVICES", pv_d);
const auto env_profile_kernel = pref.register_variable<bool>("PROFILE_KERNEL", pv_pk);
const auto env_dry_run_nodes = pref.register_variable<size_t>("DRY_RUN_NODES", pv_drn);
const auto env_gpmv =
pref.register_variable<size_t>("GRAPH_PRINT_MAX_VERTS", [this](const std::string_view str) { return parse_validate_graph_print_max_verts(str); });
const auto env_devs =
pref.register_variable<std::vector<size_t>>("DEVICES", [this](const std::string_view str) { return parse_validate_devices(str); });
const auto env_profile_kernel =
pref.register_variable<bool>("PROFILE_KERNEL", [this](const std::string_view str) { return parse_validate_profile_kernel(str); });
const auto env_dry_run_nodes =
pref.register_variable<size_t>("DRY_RUN_NODES", [this](const std::string_view str) { return parse_validate_dry_run_nodes(str); });
[[maybe_unused]] const auto env_force_wg =
pref.register_variable<int>("FORCE_WG", [this](const std::string_view str) { return parse_validate_forge_wg(str); });

// const auto parsed_and_validated_envs = m_test_key_value.empty() ? pref.parse_and_validate() : pref.parse_and_validate(m_test_key_value);
const auto parsed_and_validated_envs = pref.parse_and_validate();
if(parsed_and_validated_envs.ok()) {
// ------------------------------- CELERITY_LOG_LEVEL ---------------------------------
Expand Down Expand Up @@ -196,7 +197,7 @@ namespace detail {
for(const auto& err : parsed_and_validated_envs.warnings()) {
CELERITY_ERROR(err.what());
}
throw std::runtime_error("Please make sure to use the environment variables correctly.");
throw std::runtime_error("Failed to parse/validate environment variables.");
}
}
} // namespace detail
Expand Down
46 changes: 27 additions & 19 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1015,13 +1015,14 @@ namespace detail {
dry_run_with_nodes(nodes);
}

TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads correctly environment varibles", "[env-vars]") {
std::unordered_map<std::string, std::string> valid_test_env_vars{//
{"CELERITY_LOG_LEVEL", "debug"}, //
{"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"}, //
{"CELERITY_DEVICES", "0 0"}, //
{"CELERITY_PROFILE_KERNEL", "1"}, //
{"CELERITY_DRY_RUN_NODES", "4"}};
TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars]") {
std::unordered_map<std::string, std::string> valid_test_env_vars{
{"CELERITY_LOG_LEVEL", "debug"},
{"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"},
{"CELERITY_DEVICES", "0 0"},
{"CELERITY_PROFILE_KERNEL", "1"},
{"CELERITY_DRY_RUN_NODES", "4"},
};

const auto test_env = env::scoped_test_environment(valid_test_env_vars);
auto cfg = config(nullptr, nullptr);
Expand All @@ -1035,28 +1036,35 @@ namespace detail {
}

TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars]") {
const std::string error_string{"Failed to parse/validate environment variables."};
{
std::unordered_map<std::string, std::string> valid_test_env_vars{{"CELERITY_LOG_LEVEL", "a"}};
const auto test_env = env::scoped_test_environment(valid_test_env_vars);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly.");
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_LOG_LEVEL", "a"}};
const auto test_env = env::scoped_test_environment(invalid_test_env_var);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string);
}

{
std::unordered_map<std::string, std::string> valid_test_env_vars{{"CELERITY_GRAPH_PRINT_MAX_VERTS", "a"}};
const auto test_env = env::scoped_test_environment(valid_test_env_vars);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly.");
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_GRAPH_PRINT_MAX_VERTS", "a"}};
const auto test_env = env::scoped_test_environment(invalid_test_env_var);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string);
}

{
std::unordered_map<std::string, std::string> valid_test_env_vars{{"CELERITY_DEVICES", "a"}};
const auto test_env = env::scoped_test_environment(valid_test_env_vars);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly.");
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_DEVICES", "a"}};
const auto test_env = env::scoped_test_environment(invalid_test_env_var);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string);
}

{
std::unordered_map<std::string, std::string> valid_test_env_vars{{"CELERITY_DRY_RUN_NODES", "a"}};
const auto test_env = env::scoped_test_environment(valid_test_env_vars);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), "Please make sure to use the environment variables correctly.");
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_DRY_RUN_NODES", "a"}};
const auto test_env = env::scoped_test_environment(invalid_test_env_var);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string);
}

{
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_PROFILE_KERNEL", "a"}};
const auto test_env = env::scoped_test_environment(invalid_test_env_var);
CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string);
}
}

Expand Down

0 comments on commit 948d73b

Please sign in to comment.