Skip to content

Commit

Permalink
Minor changes requested on review.
Browse files Browse the repository at this point in the history
  • Loading branch information
facuMH committed Feb 22, 2023
1 parent 7b294ce commit 8dc777d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 101 deletions.
16 changes: 2 additions & 14 deletions include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>
#include <string_view>
#include <vector>

namespace celerity {
Expand Down Expand Up @@ -50,18 +49,7 @@ namespace detail {
std::optional<device_config> m_device_cfg;
std::optional<bool> m_enable_device_profiling;
size_t m_graph_print_max_verts = 200;
int m_dry_run_nodes = 0;

// required for test
std::vector<size_t> m_devices_list;

// parsing functions
size_t parse_validate_graph_print_max_verts(const std::string_view str);
bool parse_validate_profile_kernel(const std::string_view str);
size_t parse_validate_dry_run_nodes(const std::string_view str);
std::vector<size_t> parse_validate_devices(const std::string_view str);
bool parse_validate_force_wg(const std::string_view str);
bool parse_validate_profile_ocl(const std::string_view str);
size_t m_dry_run_nodes = 0;
};

} // namespace detail
Expand Down
166 changes: 88 additions & 78 deletions src/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,90 +28,101 @@ namespace env {
template <>
struct default_parser<celerity::detail::log_level> {
celerity::detail::log_level operator()(const std::string_view str) const {
auto opt = celerity::detail::log_level::info;

if(str == "trace") {
opt = celerity::detail::log_level::trace;
} else if(str == "debug") {
opt = celerity::detail::log_level::debug;
} else if(str == "off") {
opt = celerity::detail::log_level::off;
} else if(str == "warn") {
opt = celerity::detail::log_level::warn;
} else if(str == "err") {
opt = celerity::detail::log_level::err;
} else if(str == "critical") {
opt = celerity::detail::log_level::critical;
} else if(str == "info") {
opt = celerity::detail::log_level::info;
} else {
throw parser_error{std::string("Unable to parse '") + std::string(str) + "'"};
const std::vector<std::pair<celerity::detail::log_level, std::string>> possible_values = {
{celerity::detail::log_level::trace, "trace"},
{celerity::detail::log_level::debug, "debug"},
{celerity::detail::log_level::info, "info"},
{celerity::detail::log_level::warn, "warn"},
{celerity::detail::log_level::err, "err"},
{celerity::detail::log_level::critical, "critical"},
{celerity::detail::log_level::off, "off"},
};

auto lvl = celerity::detail::log_level::info;
bool valid = false;
for(const auto& pv : possible_values) {
if(str == pv.second) {
lvl = pv.first;
valid = true;
break;
}
}
auto err_msg = fmt::format("Unable to parse '{}'. Possible values are:", str);
for(size_t i = 0; i < possible_values.size(); ++i) {
err_msg += fmt::format(" {}{}", possible_values[i].second, (i < possible_values.size() - 1 ? ", " : "."));
}
if(!valid) throw parser_error{err_msg};

return opt;
return lvl;
}
};
} // namespace env

namespace celerity {
namespace detail {
namespace {

size_t config::parse_validate_graph_print_max_verts(const std::string_view str) {
const auto gmpv = env::default_parser<size_t>{}(str);
if(spdlog::should_log(celerity::detail::log_level::trace)) {
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace.");
}
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS new value: {}.", gmpv);
return gmpv;
size_t parse_validate_graph_print_max_verts(const std::string_view str) {
const auto gmpv = env::default_parser<size_t>{}(str);
if(!spdlog::should_log(celerity::detail::log_level::trace)) {
CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace.");
}
CELERITY_DEBUG("CELERITY_GRAPH_PRINT_MAX_VERTS={}.", gmpv);
return gmpv;
}

bool config::parse_validate_profile_kernel(const std::string_view str) {
const auto pk = env::default_parser<bool>{}(str);
CELERITY_WARN("CELERITY_PROFILE_KERNEL is {}.", pk ? "on" : "off");
return pk;
bool parse_validate_profile_kernel(const std::string_view str) {
const auto pk = env::default_parser<bool>{}(str);
CELERITY_DEBUG("CELERITY_PROFILE_KERNEL={}.", pk ? "on" : "off");
return pk;
}

size_t parse_validate_dry_run_nodes(const std::string_view str) {
const size_t drn = env::default_parser<size_t>{}(str);
int world_size = 0;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
if(world_size != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.");
CELERITY_WARN("Performing a dry run with {} simulated nodes", drn);
return drn;
}

std::vector<size_t> parse_validate_devices(const std::string_view str, const celerity::detail::host_config host_cfg) {
std::vector<size_t> devices;
const auto split_str = split(str, ' ');
// Delegate parsing of primitive types to the default_parser
for(size_t i = 0; i < split_str.size(); ++i) {
devices.push_back(env::default_parser<size_t>{}(split_str[i]));
}
if(devices.size() < 2) {
throw env::validation_error{fmt::format(
"Found {} IDs.\nExpected the following format: CELERITY_DEVICES=\"<platform_id> <first device_id> <second device_id> ... <nth device_id>\"",
devices.size())};
}

size_t config::parse_validate_dry_run_nodes(const std::string_view str) {
const size_t drn = env::default_parser<size_t>{}(str);
if(m_host_cfg.node_count != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.");
CELERITY_WARN("Performing a dry run with {} simulated nodes", drn);
return drn;
if(static_cast<long>(host_cfg.local_rank) > static_cast<long>(devices.size()) - 2) {
throw env::validation_error{fmt::format(
"Process has local rank {}, but CELERITY_DEVICES only includes {} device(s)", host_cfg.local_rank, devices.empty() ? 0 : devices.size() - 1)};
}
if(static_cast<long>(devices.size()) - 1 > static_cast<long>(host_cfg.node_count)) {
throw env::validation_error{fmt::format(
"CELERITY_DEVICES contains {} device indices, but only {} worker processes were spawned on this host", devices.size() - 1, host_cfg.node_count)};
}

std::vector<size_t> config::parse_validate_devices(const std::string_view str) {
std::vector<size_t> devices;
const auto split_str = split(str, ' ');
// Delegate parsing of primitive types to the default_parser
for(size_t i = 0; i < split_str.size(); ++i) {
devices.push_back(env::default_parser<size_t>{}(split_str[i]));
}
if(devices.size() < 2) {
throw env::validation_error{fmt::format(
"Found {} IDs.\nExpected the following format: CELERITY_DEVICES=\"<platform_id> <first device_id> <second device_id> ... <nth device_id>\"",
devices.size())};
}
return devices;
}

if(static_cast<long>(m_host_cfg.local_rank) > static_cast<long>(devices.size()) - 2) {
throw env::validation_error{fmt::format(
"Process has local rank {}, but CELERITY_DEVICES only includes {} device(s)", m_host_cfg.local_rank, devices.empty() ? 0 : devices.size() - 1)};
}
if(static_cast<long>(devices.size()) - 1 > static_cast<long>(m_host_cfg.node_count)) {
throw env::validation_error{fmt::format("CELERITY_DEVICES contains {} device indices, but only {} worker processes were spawned on this host",
devices.size() - 1, m_host_cfg.node_count)};
}
bool parse_validate_force_wg(const std::string_view str) {
throw env::validation_error{"Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."};
return false;
}

return devices;
}
bool parse_validate_profile_ocl(const std::string_view str) {
throw env::validation_error{"CELERITY_PROFILE_OCL has been renamed to CELERITY_PROFILE_KERNEL with Celerity 0.3.0."};
return false;
}

bool config::parse_validate_force_wg(const std::string_view str) {
throw env::validation_error{"Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."};
return false;
}
} // namespace

bool config::parse_validate_profile_ocl(const std::string_view str) {
throw env::validation_error{"CELERITY_PROFILE_OCL has been renamed to CELERITY_PROFILE_KERNEL with Celerity 0.3.0."};
return false;
}
namespace celerity {
namespace detail {

config::config(int* argc, char** argv[]) {
// TODO: At some point we might want to parse arguments from argv as well
Expand All @@ -126,7 +137,7 @@ namespace detail {
// TODO: Assert that shared memory is available (i.e. not explicitly disabled)
#define SPLIT_TYPE MPI_COMM_TYPE_SHARED
#endif
MPI_Comm host_comm;
MPI_Comm host_comm = nullptr;
MPI_Comm_split_type(MPI_COMM_WORLD, SPLIT_TYPE, 0, MPI_INFO_NULL, &host_comm);

int local_rank = 0;
Expand All @@ -144,17 +155,17 @@ namespace detail {
const auto env_log_level = pref.register_option<log_level>(
"LOG_LEVEL", {log_level::trace, log_level::debug, log_level::info, log_level::warn, log_level::err, log_level::critical, log_level::off});
const auto env_gpmv =
pref.register_variable<size_t>("GRAPH_PRINT_MAX_VERTS", [this](const std::string_view str) { return parse_validate_graph_print_max_verts(str); });
pref.register_variable<size_t>("GRAPH_PRINT_MAX_VERTS", [](const std::string_view str) { return parse_validate_graph_print_max_verts(str); });
const auto env_devs =
pref.register_variable<std::vector<size_t>>("DEVICES", [this](const std::string_view str) { return parse_validate_devices(str); });
pref.register_variable<std::vector<size_t>>("DEVICES", [this](const std::string_view str) { return parse_validate_devices(str, m_host_cfg); });
const auto env_profile_kernel =
pref.register_variable<bool>("PROFILE_KERNEL", [this](const std::string_view str) { return parse_validate_profile_kernel(str); });
pref.register_variable<bool>("PROFILE_KERNEL", [](const std::string_view str) { return parse_validate_profile_kernel(str); });
const auto env_dry_run_nodes =
pref.register_variable<size_t>("DRY_RUN_NODES", [this](const std::string_view str) { return parse_validate_dry_run_nodes(str); });
pref.register_variable<size_t>("DRY_RUN_NODES", [](const std::string_view str) { return parse_validate_dry_run_nodes(str); });
[[maybe_unused]] const auto env_force_wg =
pref.register_variable<bool>("FORCE_WG", [this](const std::string_view str) { return parse_validate_force_wg(str); });
pref.register_variable<bool>("FORCE_WG", [](const std::string_view str) { return parse_validate_force_wg(str); });
[[maybe_unused]] const auto env_profile_ocl =
pref.register_variable<bool>("PROFILE_OCL", [this](const std::string_view str) { return parse_validate_profile_ocl(str); });
pref.register_variable<bool>("PROFILE_OCL", [](const std::string_view str) { return parse_validate_profile_ocl(str); });

const auto parsed_and_validated_envs = pref.parse_and_validate();
if(parsed_and_validated_envs.ok()) {
Expand Down Expand Up @@ -182,7 +193,6 @@ namespace detail {

const auto has_devs = parsed_and_validated_envs.get(env_devs);
if(has_devs) {
m_devices_list = *has_devs;
const auto pid_parsed = (*has_devs)[0];
const auto did_parsed = (*has_devs)[m_host_cfg.local_rank + 1];
m_device_cfg = device_config{pid_parsed, did_parsed};
Expand All @@ -199,10 +209,10 @@ namespace detail {
if(has_dry_run_nodes) { m_dry_run_nodes = *has_dry_run_nodes; }

} else {
for(const auto& warn : parsed_and_validated_envs.errors()) {
CELERITY_WARN(warn.what());
for(const auto& warn : parsed_and_validated_envs.warnings()) {
CELERITY_ERROR(warn.what());
}
for(const auto& err : parsed_and_validated_envs.warnings()) {
for(const auto& err : parsed_and_validated_envs.errors()) {
CELERITY_ERROR(err.what());
}
throw std::runtime_error("Failed to parse/validate environment variables.");
Expand Down
19 changes: 11 additions & 8 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1015,28 +1015,31 @@ namespace detail {
dry_run_with_nodes(nodes);
}

TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars]") {
TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars][config]") {
std::unordered_map<std::string, std::string> valid_test_env_vars{
{"CELERITY_LOG_LEVEL", "debug"},
{"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"},
{"CELERITY_DEVICES", "0 0"},
{"CELERITY_DEVICES", "1 1"},
{"CELERITY_PROFILE_KERNEL", "1"},
{"CELERITY_DRY_RUN_NODES", "4"},
};

const auto test_env = env::scoped_test_environment(valid_test_env_vars);
auto cfg = config(nullptr, nullptr);

REQUIRE(spdlog::get_level() == spdlog::level::debug);
REQUIRE(cfg.get_graph_print_max_verts() == 1);
REQUIRE(config_testspy::get_device_list(cfg) == std::vector<size_t>{0, 0});
CHECK(spdlog::get_level() == spdlog::level::debug);
CHECK(cfg.get_graph_print_max_verts() == 1);
const auto dev_cfg = config_testspy::get_device_config(cfg);
REQUIRE(dev_cfg != std::nullopt);
CHECK(dev_cfg->platform_id == 1);
CHECK(dev_cfg->device_id == 1);
const auto has_prof = cfg.get_enable_device_profiling();
REQUIRE(has_prof.has_value());
REQUIRE((*has_prof) == true);
REQUIRE(cfg.get_dry_run_nodes() == 4);
CHECK((*has_prof) == true);
CHECK(cfg.get_dry_run_nodes() == 4);
}

TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars]") {
TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars][config]") {
const std::string error_string{"Failed to parse/validate environment variables."};
{
std::unordered_map<std::string, std::string> invalid_test_env_var{{"CELERITY_LOG_LEVEL", "a"}};
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ namespace detail {
struct config_testspy {
static void set_mock_device_cfg(config& cfg, const device_config& d_cfg) { cfg.m_device_cfg = d_cfg; }
static void set_mock_host_cfg(config& cfg, const host_config& h_cfg) { cfg.m_host_cfg = h_cfg; }
static std::vector<size_t> get_device_list(config& cfg) { return cfg.m_devices_list; }
static std::optional<device_config> get_device_config(config& cfg) { return cfg.m_device_cfg; }
};


Expand Down

0 comments on commit 8dc777d

Please sign in to comment.