Skip to content

Commit

Permalink
Minor changes requested on review.
Browse files Browse the repository at this point in the history
  • Loading branch information
facuMH committed Feb 20, 2023
1 parent 976ea2f commit cecd852
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 38 deletions.
6 changes: 1 addition & 5 deletions include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>
#include <string_view>
#include <vector>

namespace celerity {
Expand Down Expand Up @@ -52,9 +51,6 @@ namespace detail {
size_t m_graph_print_max_verts = 200;
int m_dry_run_nodes = 0;

// required for test
std::vector<size_t> m_devices_list;

// parsing functions
size_t parse_validate_graph_print_max_verts(const std::string_view str);
bool parse_validate_profile_kernel(const std::string_view str);
Expand Down
56 changes: 31 additions & 25 deletions src/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,32 @@ namespace env {
template <>
struct default_parser<celerity::detail::log_level> {
celerity::detail::log_level operator()(const std::string_view str) const {
auto opt = celerity::detail::log_level::info;

if(str == "trace") {
opt = celerity::detail::log_level::trace;
} else if(str == "debug") {
opt = celerity::detail::log_level::debug;
} else if(str == "off") {
opt = celerity::detail::log_level::off;
} else if(str == "warn") {
opt = celerity::detail::log_level::warn;
} else if(str == "err") {
opt = celerity::detail::log_level::err;
} else if(str == "critical") {
opt = celerity::detail::log_level::critical;
} else if(str == "info") {
opt = celerity::detail::log_level::info;
} else {
throw parser_error{std::string("Unable to parse '") + std::string(str) + "'"};
const std::vector<std::pair<celerity::detail::log_level, std::string>> possible_values = {
{celerity::detail::log_level::trace, "trace"},
{celerity::detail::log_level::debug, "debug"},
{celerity::detail::log_level::info, "info"},
{celerity::detail::log_level::warn, "warn"},
{celerity::detail::log_level::err, "err"},
{celerity::detail::log_level::critical, "critical"},
{celerity::detail::log_level::off, "off"},
};

auto lvl = celerity::detail::log_level::info;
bool valid = false;
for(auto& pv : possible_values) {
if(str == pv.second) {
lvl = pv.first;
valid = true;
break;
}
}
auto err_msg = fmt::format("Unable to parse '{}'. Possible values are:", str);
for(size_t i = 0; i < possible_values.size(); ++i) {
err_msg += fmt::format(" {}{}", possible_values[i].second, (i < possible_values.size() - 1 ? ", " : "."));
}
if(!valid) throw parser_error{err_msg};

return opt;
return lvl;
}
};
} // namespace env
Expand All @@ -58,10 +63,10 @@ namespace detail {

size_t config::parse_validate_graph_print_max_verts(const std::string_view str) {
const auto gmpv = env::default_parser<size_t>{}(str);
if(spdlog::should_log(celerity::detail::log_level::trace)) {
if(!spdlog::should_log(celerity::detail::log_level::trace)) {
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace.");
}
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS new value: {}.", gmpv);
if(spdlog::should_log(celerity::detail::log_level::debug)) { CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS new value: {}.", gmpv); }
return gmpv;
}

Expand All @@ -73,7 +78,9 @@ namespace detail {

size_t config::parse_validate_dry_run_nodes(const std::string_view str) {
const size_t drn = env::default_parser<size_t>{}(str);
if(m_host_cfg.node_count != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.");
int world_size = 0;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
if(world_size != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.");
CELERITY_WARN("Performing a dry run with {} simulated nodes", drn);
return drn;
}
Expand Down Expand Up @@ -182,7 +189,6 @@ namespace detail {

const auto has_devs = parsed_and_validated_envs.get(env_devs);
if(has_devs) {
m_devices_list = *has_devs;
const auto pid_parsed = (*has_devs)[0];
const auto did_parsed = (*has_devs)[m_host_cfg.local_rank + 1];
m_device_cfg = device_config{pid_parsed, did_parsed};
Expand All @@ -199,10 +205,10 @@ namespace detail {
if(has_dry_run_nodes) { m_dry_run_nodes = *has_dry_run_nodes; }

} else {
for(const auto& warn : parsed_and_validated_envs.errors()) {
for(const auto& warn : parsed_and_validated_envs.warnings()) {
CELERITY_WARN(warn.what());
}
for(const auto& err : parsed_and_validated_envs.warnings()) {
for(const auto& err : parsed_and_validated_envs.errors()) {
CELERITY_ERROR(err.what());
}
throw std::runtime_error("Failed to parse/validate environment variables.");
Expand Down
17 changes: 10 additions & 7 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ namespace detail {
dry_run_with_nodes(nodes);
}

TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars]") {
TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars][config]") {
std::unordered_map<std::string, std::string> valid_test_env_vars{
{"CELERITY_LOG_LEVEL", "debug"},
{"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"},
Expand All @@ -1027,16 +1027,19 @@ namespace detail {
const auto test_env = env::scoped_test_environment(valid_test_env_vars);
auto cfg = config(nullptr, nullptr);

REQUIRE(spdlog::get_level() == spdlog::level::debug);
REQUIRE(cfg.get_graph_print_max_verts() == 1);
REQUIRE(config_testspy::get_device_list(cfg) == std::vector<size_t>{0, 0});
CHECK(spdlog::get_level() == spdlog::level::debug);
CHECK(cfg.get_graph_print_max_verts() == 1);
const auto dev_cfg = config_testspy::get_device_list(cfg);
REQUIRE(dev_cfg != std::nullopt);
CHECK(dev_cfg->platform_id == 0);
CHECK(dev_cfg->device_id == 0);
const auto has_prof = cfg.get_enable_device_profiling();
REQUIRE(has_prof.has_value());
REQUIRE((*has_prof) == true);
REQUIRE(cfg.get_dry_run_nodes() == 4);
CHECK((*has_prof) == true);
CHECK(cfg.get_dry_run_nodes() == 4);
}

TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars]") {
TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars][config]") {
const std::string error_string{"Failed to parse/validate environment variables."};
{
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_LOG_LEVEL", "a"}};
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ namespace detail {
struct config_testspy {
static void set_mock_device_cfg(config& cfg, const device_config& d_cfg) { cfg.m_device_cfg = d_cfg; }
static void set_mock_host_cfg(config& cfg, const host_config& h_cfg) { cfg.m_host_cfg = h_cfg; }
static std::vector<size_t> get_device_list(config& cfg) { return cfg.m_devices_list; }
static std::optional<device_config> get_device_list(config& cfg) { return cfg.m_device_cfg; }
};


Expand Down

0 comments on commit cecd852

Please sign in to comment.