diff --git a/.gitmodules b/.gitmodules index ab4d2748a..3775323c5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "vendor/Catch2"] path = vendor/Catch2 url = https://github.com/catchorg/Catch2.git +[submodule "vendor/libenvpp"] + path = vendor/libenvpp + url = https://github.com/ph3at/libenvpp.git \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 388f4a3b1..c1167cc94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,7 +64,7 @@ else() set(CELERITY_SYCL_IMPL "${AVAILABLE_SYCL_IMPLS}") message(STATUS "Automatically chooosing ${CELERITY_SYCL_IMPL} because it is the only SYCL implementation available") endif() -endif () +endif() set(CELERITY_DPCPP_TARGETS "spir64" CACHE STRING "Intel DPC++ targets") if(CELERITY_SYCL_IMPL STREQUAL "DPC++") @@ -179,6 +179,9 @@ endif() fetch_content_from_submodule(Catch2 vendor/Catch2) list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) +set(LIBENVPP_INSTALL ON CACHE BOOL "" FORCE) +fetch_content_from_submodule(libenvpp vendor/libenvpp) + configure_file(include/version.h.in include/version.h @ONLY) # Add includes to library so they show up in IDEs @@ -259,6 +262,7 @@ target_link_libraries(celerity_runtime PUBLIC fmt::fmt spdlog::spdlog gch::small_vector + libenvpp::libenvpp ${SYCL_LIB} ) diff --git a/cmake/celerity-config.cmake.in b/cmake/celerity-config.cmake.in index aa60091e1..13e4dee81 100644 --- a/cmake/celerity-config.cmake.in +++ b/cmake/celerity-config.cmake.in @@ -22,6 +22,7 @@ find_dependency(Threads REQUIRED) find_dependency(fmt REQUIRED) find_dependency(spdlog REQUIRED) find_dependency(small_vector REQUIRED) +find_dependency(libenvpp REQUIRED) if(CELERITY_SYCL_IMPL STREQUAL "hipSYCL") if(NOT DEFINED HIPSYCL_TARGETS AND NOT "@HIPSYCL_TARGETS@" STREQUAL "") diff --git a/include/config.h b/include/config.h index f310bd8eb..24d590c3b 100644 --- a/include/config.h +++ b/include/config.h @@ -2,6 +2,8 @@ #include #include +#include +#include namespace celerity { namespace detail { @@ -47,7 +49,7 @@ namespace detail { std::optional m_device_cfg; std::optional m_enable_device_profiling; size_t m_graph_print_max_verts = 200; - int m_dry_run_nodes = 0; + size_t m_dry_run_nodes = 0; }; } // namespace detail diff --git a/src/config.cc b/src/config.cc index 97c9db81f..98290ab4e 100644 --- a/src/config.cc +++ b/src/config.cc @@ -5,7 +5,6 @@ #include #include #include -#include #include @@ -13,110 +12,170 @@ #include -std::pair get_env(const char* key) { - bool exists = false; - std::string str; -#ifdef _MSC_VER - char* buf; - _dupenv_s(&buf, nullptr, key); - if(buf != nullptr) { - exists = true; - str = buf; - delete buf; +#include + +static std::vector split(const std::string_view str, const char delimiter) { + auto result = std::vector{}; + auto sstream = std::istringstream(std::string(str)); + auto item = std::string{}; + while(std::getline(sstream, item, delimiter)) { + result.push_back(std::move(item)); } -#else - const auto value = std::getenv(key); - if(value != nullptr) { - exists = true; - str = value; + return result; +} + +namespace env { +template <> +struct default_parser { + celerity::detail::log_level operator()(const std::string_view str) const { + const std::vector> possible_values = { + {celerity::detail::log_level::trace, "trace"}, + {celerity::detail::log_level::debug, "debug"}, + {celerity::detail::log_level::info, "info"}, + {celerity::detail::log_level::warn, "warn"}, + {celerity::detail::log_level::err, "err"}, + {celerity::detail::log_level::critical, "critical"}, + {celerity::detail::log_level::off, "off"}, + }; + + auto lvl = celerity::detail::log_level::info; + bool valid = false; + for(const auto& pv : possible_values) { + if(str == pv.second) { + lvl = pv.first; + valid = true; + break; + } + } + auto err_msg = fmt::format("Unable to parse '{}'. Possible values are:", str); + for(size_t i = 0; i < possible_values.size(); ++i) { + err_msg += fmt::format(" {}{}", possible_values[i].second, (i < possible_values.size() - 1 ? ", " : ".")); + } + if(!valid) throw parser_error{err_msg}; + + return lvl; + } +}; +} // namespace env + +namespace { + +size_t parse_validate_graph_print_max_verts(const std::string_view str) { + const auto gmpv = env::default_parser{}(str); + if(!spdlog::should_log(celerity::detail::log_level::trace)) { + CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace."); + } + CELERITY_DEBUG("CELERITY_GRAPH_PRINT_MAX_VERTS={}.", gmpv); + return gmpv; +} + +bool parse_validate_profile_kernel(const std::string_view str) { + const auto pk = env::default_parser{}(str); + CELERITY_DEBUG("CELERITY_PROFILE_KERNEL={}.", pk ? "on" : "off"); + return pk; +} + +size_t parse_validate_dry_run_nodes(const std::string_view str) { + const size_t drn = env::default_parser{}(str); + int world_size = 0; + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + if(world_size != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used."); + CELERITY_WARN("Performing a dry run with {} simulated nodes", drn); + return drn; +} + +std::vector parse_validate_devices(const std::string_view str, const celerity::detail::host_config host_cfg) { + std::vector devices; + const auto split_str = split(str, ' '); + // Delegate parsing of primitive types to the default_parser + for(size_t i = 0; i < split_str.size(); ++i) { + devices.push_back(env::default_parser{}(split_str[i])); + } + if(devices.size() < 2) { + throw env::validation_error{fmt::format( + "Found {} IDs.\nExpected the following format: CELERITY_DEVICES=\" ... \"", + devices.size())}; + } + + if(static_cast(host_cfg.local_rank) > static_cast(devices.size()) - 2) { + throw env::validation_error{fmt::format( + "Process has local rank {}, but CELERITY_DEVICES only includes {} device(s)", host_cfg.local_rank, devices.empty() ? 0 : devices.size() - 1)}; + } + if(static_cast(devices.size()) - 1 > static_cast(host_cfg.node_count)) { + throw env::validation_error{fmt::format( + "CELERITY_DEVICES contains {} device indices, but only {} worker processes were spawned on this host", devices.size() - 1, host_cfg.node_count)}; } -#endif - std::string_view sv = str; - sv.remove_prefix(std::min(sv.find_first_not_of(" "), sv.size())); - sv.remove_suffix(std::min(sv.size() - sv.find_last_not_of(" ") - 1, sv.size())); + return devices; +} - return {exists, std::string{sv}}; +bool parse_validate_force_wg(const std::string_view str) { + throw env::validation_error{"Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."}; + return false; } -std::pair parse_uint(const char* str) { - errno = 0; - char* eptr = nullptr; - const auto value = std::strtoul(str, &eptr, 10); - if(errno == 0 && eptr != str) { return {true, value}; } - return {false, 0}; +bool parse_validate_profile_ocl(const std::string_view str) { + throw env::validation_error{"CELERITY_PROFILE_OCL has been renamed to CELERITY_PROFILE_KERNEL with Celerity 0.3.0."}; + return false; } +} // namespace + namespace celerity { namespace detail { + config::config(int* argc, char** argv[]) { // TODO: At some point we might want to parse arguments from argv as well - { - // Determine the "host config", i.e., how many nodes are spawned on this host, - // and what this node's local rank is. We do this by finding all world-ranks - // that can use a shared-memory transport (if running on OpenMPI, use the - // per-host split instead). + // Determine the "host config", i.e., how many nodes are spawned on this host, + // and what this node's local rank is. We do this by finding all world-ranks + // that can use a shared-memory transport (if running on OpenMPI, use the + // per-host split instead). #ifdef OPEN_MPI #define SPLIT_TYPE OMPI_COMM_TYPE_HOST #else - // TODO: Assert that shared memory is available (i.e. not explicitly disabled) + // TODO: Assert that shared memory is available (i.e. not explicitly disabled) #define SPLIT_TYPE MPI_COMM_TYPE_SHARED #endif - MPI_Comm host_comm; - MPI_Comm_split_type(MPI_COMM_WORLD, SPLIT_TYPE, 0, MPI_INFO_NULL, &host_comm); - - int local_rank = 0; - MPI_Comm_rank(host_comm, &local_rank); - - int node_count = 0; - MPI_Comm_size(host_comm, &node_count); + MPI_Comm host_comm = nullptr; + MPI_Comm_split_type(MPI_COMM_WORLD, SPLIT_TYPE, 0, MPI_INFO_NULL, &host_comm); + + int local_rank = 0; + MPI_Comm_rank(host_comm, &local_rank); + + int node_count = 0; + MPI_Comm_size(host_comm, &node_count); + + m_host_cfg.local_rank = local_rank; + m_host_cfg.node_count = node_count; + + MPI_Comm_free(&host_comm); + + auto pref = env::prefix("CELERITY"); + const auto env_log_level = pref.register_option( + "LOG_LEVEL", {log_level::trace, log_level::debug, log_level::info, log_level::warn, log_level::err, log_level::critical, log_level::off}); + const auto env_gpmv = + pref.register_variable("GRAPH_PRINT_MAX_VERTS", [](const std::string_view str) { return parse_validate_graph_print_max_verts(str); }); + const auto env_devs = + pref.register_variable>("DEVICES", [this](const std::string_view str) { return parse_validate_devices(str, m_host_cfg); }); + const auto env_profile_kernel = + pref.register_variable("PROFILE_KERNEL", [](const std::string_view str) { return parse_validate_profile_kernel(str); }); + const auto env_dry_run_nodes = + pref.register_variable("DRY_RUN_NODES", [](const std::string_view str) { return parse_validate_dry_run_nodes(str); }); + [[maybe_unused]] const auto env_force_wg = + pref.register_variable("FORCE_WG", [](const std::string_view str) { return parse_validate_force_wg(str); }); + [[maybe_unused]] const auto env_profile_ocl = + pref.register_variable("PROFILE_OCL", [](const std::string_view str) { return parse_validate_profile_ocl(str); }); + + const auto parsed_and_validated_envs = pref.parse_and_validate(); + if(parsed_and_validated_envs.ok()) { + // ------------------------------- CELERITY_LOG_LEVEL --------------------------------- - m_host_cfg.local_rank = local_rank; - m_host_cfg.node_count = node_count; - - MPI_Comm_free(&host_comm); - } - - // ------------------------------- CELERITY_LOG_LEVEL --------------------------------- - - { #if defined(CELERITY_DETAIL_ENABLE_DEBUG) - auto log_lvl = log_level::debug; + const auto log_lvl = parsed_and_validated_envs.get_or(env_log_level, log_level::debug); #else - auto log_lvl = log_level::info; + const auto log_lvl = parsed_and_validated_envs.get_or(env_log_level, log_level::info); #endif - const std::vector> possible_values = { - {log_level::trace, "trace"}, - {log_level::debug, "debug"}, - {log_level::info, "info"}, - {log_level::warn, "warn"}, - {log_level::err, "err"}, - {log_level::critical, "critical"}, - {log_level::off, "off"}, - }; - - const auto result = get_env("CELERITY_LOG_LEVEL"); - if(result.first) { - bool valid = false; - for(auto& pv : possible_values) { - if(result.second == pv.second) { - log_lvl = pv.first; - valid = true; - break; - } - } - if(!valid) { - std::ostringstream oss; - oss << "Invalid value \"" << result.second << "\" provided for CELERITY_LOG_LEVEL. "; - oss << "Possible values are: "; - for(size_t i = 0; i < possible_values.size(); ++i) { - oss << possible_values[i].second << (i < possible_values.size() - 1 ? ", " : "."); - } - CELERITY_WARN(oss.str()); - } - } - // Set both the global log level and the default sink level so that the console logger adheres to CELERITY_LOG_LEVEL even if we temporarily // override the global level in test_utils::log_capture. // TODO do not modify global state in the constructor, but factor the LOG_LEVEL part out of detail::config entirely. @@ -124,81 +183,39 @@ namespace detail { for(auto& sink : spdlog::default_logger_raw()->sinks()) { sink->set_level(log_lvl); } - } - // ------------------------- CELERITY_GRAPH_PRINT_MAX_VERTS --------------------------- + // ------------------------- CELERITY_GRAPH_PRINT_MAX_VERTS --------------------------- - { - const auto [is_set, value] = get_env("CELERITY_GRAPH_PRINT_MAX_VERTS"); - if(is_set) { - if(spdlog::should_log(log_level::trace)) { - CELERITY_WARN("CELERITY_GRAPH_PRINT_MAX_VERTS: Graphs will only be printed for CELERITY_LOG_LEVEL=trace."); - } - const auto [is_valid, parsed] = parse_uint(value.c_str()); - if(is_valid) { m_graph_print_max_verts = parsed; } - } - } + const auto has_gpmv = parsed_and_validated_envs.get(env_gpmv); + if(has_gpmv) { m_graph_print_max_verts = *has_gpmv; } - // --------------------------------- CELERITY_DEVICES --------------------------------- - - { - const auto [is_set, value] = get_env("CELERITY_DEVICES"); - if(is_set) { - if(value.empty()) { - CELERITY_WARN("CELERITY_DEVICES is set but empty - ignoring"); - } else { - std::istringstream ss{value}; - std::vector values{std::istream_iterator{ss}, std::istream_iterator{}}; - if(static_cast(m_host_cfg.local_rank) > static_cast(values.size()) - 2) { - throw std::runtime_error(fmt::format("Process has local rank {}, but CELERITY_DEVICES only includes {} device(s)", - m_host_cfg.local_rank, values.empty() ? 0 : values.size() - 1)); - } - - if(static_cast(values.size()) - 1 > static_cast(m_host_cfg.node_count)) { - CELERITY_WARN("CELERITY_DEVICES contains {} device indices, but only {} worker processes were spawned on this host", values.size() - 1, - m_host_cfg.node_count); - } - - const auto pid_parsed = parse_uint(values[0].c_str()); - const auto did_parsed = parse_uint(values[m_host_cfg.local_rank + 1].c_str()); - if(!pid_parsed.first || !did_parsed.first) { - CELERITY_WARN("CELERITY_DEVICES contains invalid value(s) - will be ignored"); - } else { - m_device_cfg = device_config{pid_parsed.second, did_parsed.second}; - } - } - } - } + // --------------------------------- CELERITY_DEVICES --------------------------------- - // ----------------------------- CELERITY_PROFILE_KERNEL ------------------------------ - { - const auto result = get_env("CELERITY_PROFILE_OCL"); - if(result.first) { - CELERITY_WARN("CELERITY_PROFILE_OCL has been renamed to CELERITY_PROFILE_KERNEL with Celerity 0.3.0."); - m_enable_device_profiling = result.second == "1"; + const auto has_devs = parsed_and_validated_envs.get(env_devs); + if(has_devs) { + const auto pid_parsed = (*has_devs)[0]; + const auto did_parsed = (*has_devs)[m_host_cfg.local_rank + 1]; + m_device_cfg = device_config{pid_parsed, did_parsed}; } - } - { - const auto result = get_env("CELERITY_PROFILE_KERNEL"); - if(result.first) { m_enable_device_profiling = result.second == "1"; } - } + // ----------------------------- CELERITY_PROFILE_KERNEL ------------------------------ - // -------------------------------- CELERITY_FORCE_WG --------------------------------- + const auto has_profile_kernel = parsed_and_validated_envs.get(env_profile_kernel); + if(has_profile_kernel) { m_enable_device_profiling = *has_profile_kernel; } - { - const auto result = get_env("CELERITY_FORCE_WG"); - if(result.first) { CELERITY_WARN("Support for CELERITY_FORCE_WG has been removed with Celerity 0.3.0."); } - } + // -------------------------------- CELERITY_DRY_RUN_NODES --------------------------------- - // -------------------------------- CELERITY_DRY_RUN_NODES --------------------------------- - { - const auto [is_set, value] = get_env("CELERITY_DRY_RUN_NODES"); - if(is_set) { - const auto [is_valid, num_nodes] = parse_uint(value.c_str()); - if(!is_valid) { CELERITY_WARN("CELERITY_DRY_RUN_NODES contains invalid value - will be ignored"); } - m_dry_run_nodes = num_nodes; + const auto has_dry_run_nodes = parsed_and_validated_envs.get(env_dry_run_nodes); + if(has_dry_run_nodes) { m_dry_run_nodes = *has_dry_run_nodes; } + + } else { + for(const auto& warn : parsed_and_validated_envs.warnings()) { + CELERITY_ERROR(warn.what()); + } + for(const auto& err : parsed_and_validated_envs.errors()) { + CELERITY_ERROR(err.what()); } + throw std::runtime_error("Failed to parse/validate environment variables."); } } } // namespace detail diff --git a/src/runtime.cc b/src/runtime.cc index 382e16bf2..f5e496c21 100644 --- a/src/runtime.cc +++ b/src/runtime.cc @@ -103,11 +103,7 @@ namespace detail { m_num_nodes = world_size; m_cfg = std::make_unique(argc, argv); - if(m_cfg->is_dry_run()) { - if(m_num_nodes != 1) throw std::runtime_error("In order to run with CELERITY_DRY_RUN_NODES a single MPI process/rank must be used.\n"); - m_num_nodes = m_cfg->get_dry_run_nodes(); - CELERITY_WARN("Performing a dry run with {} simulated nodes", m_num_nodes); - } + if(m_cfg->is_dry_run()) { m_num_nodes = m_cfg->get_dry_run_nodes(); } int world_rank; MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); diff --git a/test/device_selection_tests.cc b/test/device_selection_tests.cc index bd45c2943..f1a238b52 100644 --- a/test/device_selection_tests.cc +++ b/test/device_selection_tests.cc @@ -106,13 +106,6 @@ struct mock_platform { std::vector m_devices; }; -namespace celerity::detail { -struct config_testspy { - static void set_mock_device_cfg(config& cfg, const device_config& d_cfg) { cfg.m_device_cfg = d_cfg; } - static void set_mock_host_cfg(config& cfg, const host_config& h_cfg) { cfg.m_host_cfg = h_cfg; } -}; -} // namespace celerity::detail - TEST_CASE_METHOD(celerity::test_utils::mpi_fixture, "pick_device prefers user specified device pointer", "[device-selection]") { celerity::detail::config cfg(nullptr, nullptr); mock_platform_factory mpf; diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index d68b50e9f..539534755 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -15,6 +15,8 @@ #include #include +#include + #include #include "affinity.h" @@ -1013,5 +1015,73 @@ namespace detail { dry_run_with_nodes(nodes); } + TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reads environment variables correctly", "[env-vars][config]") { + std::unordered_map valid_test_env_vars{ + {"CELERITY_LOG_LEVEL", "debug"}, + {"CELERITY_GRAPH_PRINT_MAX_VERTS", "1"}, + {"CELERITY_DEVICES", "1 1"}, + {"CELERITY_PROFILE_KERNEL", "1"}, + {"CELERITY_DRY_RUN_NODES", "4"}, + }; + + const auto test_env = env::scoped_test_environment(valid_test_env_vars); + auto cfg = config(nullptr, nullptr); + + CHECK(spdlog::get_level() == spdlog::level::debug); + CHECK(cfg.get_graph_print_max_verts() == 1); + const auto dev_cfg = config_testspy::get_device_config(cfg); + REQUIRE(dev_cfg != std::nullopt); + CHECK(dev_cfg->platform_id == 1); + CHECK(dev_cfg->device_id == 1); + const auto has_prof = cfg.get_enable_device_profiling(); + REQUIRE(has_prof.has_value()); + CHECK((*has_prof) == true); + CHECK(cfg.get_dry_run_nodes() == 4); + } + + TEST_CASE_METHOD(test_utils::mpi_fixture, "Config reports incorrect environment varibles", "[env-vars][config]") { + const std::string error_string{"Failed to parse/validate environment variables."}; + { + std::unordered_map invalid_test_env_var{{"CELERITY_LOG_LEVEL", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + + { + std::unordered_map invalid_test_env_var{{"CELERITY_GRAPH_PRINT_MAX_VERTS", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + + { + std::unordered_map invalid_test_env_var{{"CELERITY_DEVICES", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + + { + std::unordered_map invalid_test_env_var{{"CELERITY_DRY_RUN_NODES", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + + { + std::unordered_map invalid_test_env_var{{"CELERITY_PROFILE_KERNEL", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + { + std::unordered_map invalid_test_env_var{{"CELERITY_FORCE_WG", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + + { + std::unordered_map invalid_test_env_var{{"CELERITY_PROFILE_OCL", "a"}}; + const auto test_env = env::scoped_test_environment(invalid_test_env_var); + CHECK_THROWS_WITH((celerity::detail::config(nullptr, nullptr)), error_string); + } + } + } // namespace detail } // namespace celerity diff --git a/test/test_utils.h b/test/test_utils.h index cad1f2342..3fef20f31 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -80,6 +80,14 @@ namespace detail { static void create_task_slot(task_manager& tm) { task_ring_buffer_testspy::create_task_slot(tm.m_task_buffer); } }; + + struct config_testspy { + static void set_mock_device_cfg(config& cfg, const device_config& d_cfg) { cfg.m_device_cfg = d_cfg; } + static void set_mock_host_cfg(config& cfg, const host_config& h_cfg) { cfg.m_host_cfg = h_cfg; } + static std::optional get_device_config(config& cfg) { return cfg.m_device_cfg; } + }; + + inline bool has_dependency(const task_manager& tm, task_id dependent, task_id dependency, dependency_kind kind = dependency_kind::true_dep) { for(auto dep : tm.get_task(dependent)->get_dependencies()) { if(dep.node->get_id() == dependency && dep.kind == kind) return true; diff --git a/vendor/libenvpp b/vendor/libenvpp new file mode 160000 index 000000000..82f0819cb --- /dev/null +++ b/vendor/libenvpp @@ -0,0 +1 @@ +Subproject commit 82f0819cba606a7ffb6b0ba81acd697b248d0618