From 66664b1d742d8ac4f197663aade3fbf57c59d0d8 Mon Sep 17 00:00:00 2001
From: Katie Graham
Date: Tue, 22 Nov 2022 08:58:49 -0800
Subject: [PATCH] Added user input args to mlperf callback, moved mlperf data class out of separate class

---
Review notes (free-form area: text between the "---" line above and the
first file diff is not part of the change itself).

Changes made relative to the patch as originally posted:

 * build_mlperf_logging_callback_from_pbuf: the constructor is declared as
   (output_filename, sub_benchmark, sub_org, sub_division, sub_status,
   sub_platform) but was being called with output_filename LAST. All six
   parameters are std::string, so this compiled and silently shifted every
   field by one. The call now passes output_filename first.

 * mlperf_logging::print: now that print() itself calls H2_INFO(os.str()),
   the stream has to be emptied afterwards. os.flush() does not discard the
   contents of a std::ostringstream, so every later print() on the same
   stream would have re-logged all earlier records. Replaced by os.str("").

 * callbacks.proto: replaced the "FIXME(KLG): document these" marker with a
   short comment per field (log key and empty-string default, as implemented
   by the constructor and setup()).

 * Formatting: line breaks were restored. Template argument lists that had
   been stripped from this copy of the patch were re-added. The ones on
   context lines -- std::unique_ptr<callback_base>, the dynamic_cast target
   and static_cast<double> -- as well as context-line indentation were
   reconstructed; TODO confirm them against the tree before applying.
   The "index" blob ids below predate the two code fixes.

Open points NOT changed here (reviewer observations, please confirm):

 * The free print_value() helpers have external linkage; an anonymous
   namespace would answer the FIXME in mlperf_logging.cpp.
 * The generic print_value(T) fallback is an exact match for any type
   without a dedicated overload (e.g. int, float) and then prints
   UNKNOWN_DATA_TYPE instead of the value.
 * "epoch_stop" and "run_stop" are logged with event_type::INT_START
   (pre-existing context lines); that looks unintended.
 * The constructor doc comment still documents only output_filename.

 include/lbann/callbacks/mlperf_logging.hpp |  48 +++++----
 src/callbacks/mlperf_logging.cpp           | 107 +++++++++------------
 src/proto/callbacks.proto                  |   7 +-
 3 files changed, 84 insertions(+), 78 deletions(-)

diff --git a/include/lbann/callbacks/mlperf_logging.hpp b/include/lbann/callbacks/mlperf_logging.hpp
index 061823bba89..708fb2f892a 100644
--- a/include/lbann/callbacks/mlperf_logging.hpp
+++ b/include/lbann/callbacks/mlperf_logging.hpp
@@ -53,11 +53,28 @@ class mlperf_logging : public callback_base {
   /** @brief mlperf_logging Constructor.
    * @param output_filename Output filename (default = results.txt)
    */
-  mlperf_logging(std::string output_filename)
+  mlperf_logging(std::string output_filename, std::string sub_benchmark,
+                 std::string sub_org, std::string sub_division,
+                 std::string sub_status, std::string sub_platform)
     : callback_base(/*batch_interval=*/1),
       m_output_filename{output_filename.size() ?
                         std::move(output_filename) :
-                        std::string("results.txt")}
+                        std::string("results.txt")},
+      m_sub_benchmark{sub_benchmark.size() ?
+                      std::move(sub_benchmark) :
+                      std::string("UNKNOWN_SUBMISSION_BENCHMARK")},
+      m_sub_org{sub_org.size() ?
+                std::move(sub_org) :
+                std::string("LBANN")},
+      m_sub_division{sub_division.size() ?
+                     std::move(sub_division) :
+                     std::string("UNKNOWN_SUBMISSION_DIVISION")},
+      m_sub_status{sub_status.size() ?
+                   std::move(sub_status) :
+                   std::string("UNKNOWN_SUBMISSION_STATUS")},
+      m_sub_platform{sub_platform.size() ?
+                     std::move(sub_platform) :
+                     std::string("UNKNOWN_SUBMISSION_PLATFORM")}
   {}
 
   /** @brief Copy interface */
@@ -69,7 +86,7 @@ class mlperf_logging : public callback_base {
   std::string name() const override { return "mlperf_logging"; }
 
   /** @brief Push mlperf formatted log string to stream object.
-   * @param ostream os Stores log strings.
+   * @param ostringstream os Stores log strings.
    * @param event_type et Type of mlperf style event.
    * @param string key Mlperf log key.
    * @param T value Mlperf log value.
@@ -78,7 +95,7 @@ class mlperf_logging : public callback_base {
    * @param double epoch Current epoch number.
    */
   template <typename T>
-  void print(std::ostream& os, mlperf_logging::event_type et, std::string key,
+  void print(std::ostringstream& os, mlperf_logging::event_type et, std::string key,
              T value, char const* file, size_t line, double epoch = -1) const;
 
   void setup(model *m) override;
@@ -93,22 +110,15 @@ class mlperf_logging : public callback_base {
 private:
 
   /** @brief Populate log with mlperf event type.
-   * @param ostream os Stores log string.
+   * @param ostringstream os Stores log string.
    * @param event_type et Type of mlperf style event.
    */
-  void print_event_type(std::ostream& os, mlperf_logging::event_type et) const;
+  void print_event_type(std::ostringstream& os, mlperf_logging::event_type et) const;
 
   /** @brief Populate log with value.
-   * @param ostream os Stores log string.
+   * @param ostringstream os Stores log string.
    * @param event_type et Mlperf log value.
    */
-  void print_value(std::ostream& os, double value) const;
-  void print_value(std::ostream& os, long value) const;
-  void print_value(std::ostream& os, size_t value) const;
-  void print_value(std::ostream& os, std::string value) const;
-  //FIXME: Always picks this function first
-  //template <typename T>
-  //void print_value(std::ostream& os, T value) const;
 
   static size_t get_ms_since_epoch();
 
@@ -117,10 +127,14 @@ class mlperf_logging : public callback_base {
   //FIXME: get logger to output file
   /* @brief name of output file. Default = results.txt */
   std::string m_output_filename;
-
-  //FIXME: Add custom logging tag
   /* @brief DiHydrogen logger */
-  h2::Logger m_logger;
+  h2::Logger m_logger{":::MLLOG", m_output_filename};
+  std::string m_sub_benchmark;
+  std::string m_sub_org;
+  std::string m_sub_division;
+  std::string m_sub_status;
+  std::string m_sub_platform;
+
 }; // class mlperf_logging
 
 
diff --git a/src/callbacks/mlperf_logging.cpp b/src/callbacks/mlperf_logging.cpp
index 18244ec33f4..c48ed59837c 100644
--- a/src/callbacks/mlperf_logging.cpp
+++ b/src/callbacks/mlperf_logging.cpp
@@ -42,8 +42,36 @@
 namespace lbann {
 namespace callback {
 
+// FIXME Does this need an anon namespace since it's only in the cpp file?
+void print_value(std::ostringstream& os, double value)
+{
+  os << value;
+}
+void print_value(std::ostringstream& os, long value)
+{
+  os << value;
+}
+void print_value(std::ostringstream& os, size_t value)
+{
+  os << value;
+}
+void print_value(std::ostringstream& os, std::string const& value)
+{
+  os << "\"" << value << "\"";
+}
+void print_value(std::ostringstream& os, char const* value)
+{
+  os << "\"" << value << "\"";
+}
 template <typename T>
-void mlperf_logging::print(std::ostream& os, mlperf_logging::event_type et,
+void print_value(std::ostringstream& os, T value)
+{
+  //FIXME: Should I push the value anyway?
+  os << "UNKNOWN_DATA_TYPE";
+}
+
+template <typename T>
+void mlperf_logging::print(std::ostringstream& os, mlperf_logging::event_type et,
                            std::string key, T value, char const* file,
                            size_t line, double epoch) const
 {
@@ -54,19 +82,22 @@ void mlperf_logging::print(std::ostream& os, mlperf_logging::event_type et,
 
   print_event_type(os, et);
   os << "\", "
-     << "\"key\": " << key << "\", "
+     << "\"key\": \"" << key << "\", "
      << "\"value\": ";
   print_value(os, value);
   os << ", "
      << "\"metadata\": {\"file\": \"" << file << "\", "
      << "\"lineno\": " << line;
   if(epoch < 0)
-    os << "}}\n";
+    os << "}}";
   else
-    os << ", " << "\"epoch_num\": " << epoch << "}}\n";
+    os << ", " << "\"epoch_num\": " << epoch << "}}";
+
+  H2_INFO(os.str());
+  os.str("");
 }
 
-void mlperf_logging::print_event_type(std::ostream& os, mlperf_logging::event_type et) const
+void mlperf_logging::print_event_type(std::ostringstream& os, mlperf_logging::event_type et) const
 {
   switch (et) {
   case mlperf_logging::event_type::TIME_POINT: os << "POINT_IN_TIME"; break;
@@ -76,30 +107,6 @@ void mlperf_logging::print_event_type(std::ostream& os, mlperf_logging::event_ty
   }
 }
 
-void mlperf_logging::print_value(std::ostream& os, double value) const
-{
-  os << value;
-}
-void mlperf_logging::print_value(std::ostream& os, long value) const
-{
-  os << value;
-}
-void mlperf_logging::print_value(std::ostream& os, size_t value) const
-{
-  os << value;
-}
-void mlperf_logging::print_value(std::ostream& os, std::string value) const
-{
-  os << value;
-}
-/*template <typename T>
-void mlperf_logging::print_value(std::ostream& os, T value) const
-{
-  //FIXME: Should I push the value anyway?
-  os << "UNKNOWN_DATA_TYPE";
-}
-*/
-
 size_t mlperf_logging::get_ms_since_epoch()
 {
   using namespace std::chrono;
@@ -117,35 +124,24 @@ void mlperf_logging::setup(model *m)
   print(os, mlperf_logging::event_type::TIME_POINT, "cache_clear",
         value, __FILE__, __LINE__);
 
-  //FIXME: Make these user input vars
-  value = "oc20";
   print(os, mlperf_logging::event_type::TIME_POINT, "submission_benchmark",
-        value, __FILE__, __LINE__);
+        m_sub_benchmark, __FILE__, __LINE__);
 
-  value = "LBANN";
   print(os, mlperf_logging::event_type::TIME_POINT, "submission_org",
-        value, __FILE__, __LINE__);
+        m_sub_org, __FILE__, __LINE__);
 
-  //FIXME: value = closed?
-  value = "closed";
   print(os, mlperf_logging::event_type::TIME_POINT, "submission_division",
-        value, __FILE__, __LINE__);
+        m_sub_division, __FILE__, __LINE__);
 
-  //FIXME: value = onprem?
-  value = "onprem";
   print(os, mlperf_logging::event_type::TIME_POINT, "submission_status",
-        value, __FILE__, __LINE__);
+        m_sub_status, __FILE__, __LINE__);
 
-  //FIXME: value = SUBMISSION_PLATFORM_PLACEHOLDER?
-  value = "?";
   print(os, mlperf_logging::event_type::TIME_POINT, "submission_platform",
-        value, __FILE__, __LINE__);
+        m_sub_platform, __FILE__, __LINE__);
 
   value = "null";
   print(os, mlperf_logging::event_type::TIME_POINT, "init_start",
         value, __FILE__, __LINE__);
-
-  H2_INFO(os.str());
 }
 void mlperf_logging::on_setup_end(model *m)
 {
@@ -227,8 +223,6 @@ void mlperf_logging::on_setup_end(model *m)
 
   print(os, mlperf_logging::event_type::TIME_POINT, "init_stop", "null",
         __FILE__, __LINE__);
-
-  H2_INFO(os.str());
 }
 
 void mlperf_logging::on_epoch_begin(model *m)
@@ -239,8 +233,6 @@ void mlperf_logging::on_epoch_begin(model *m)
 
   print(os, mlperf_logging::event_type::INT_START, "epoch_start", "null",
         __FILE__, __LINE__, epoch);
-
-  H2_INFO(os.str());
 }
 
 void mlperf_logging::on_epoch_end(model *m)
@@ -251,8 +243,6 @@ void mlperf_logging::on_epoch_end(model *m)
 
   print(os, mlperf_logging::event_type::INT_START, "epoch_stop", "null",
         __FILE__, __LINE__, epoch);
-
-  H2_INFO(os.str());
 }
 
 void mlperf_logging::on_train_begin(model *m)
@@ -264,8 +254,6 @@ void mlperf_logging::on_train_begin(model *m)
   //FIXME: run_start? Same time stamp as epoch 1 in results
   print(os, mlperf_logging::event_type::INT_START, "run_start", "null",
         __FILE__, __LINE__, epoch);
-
-  H2_INFO(os.str());
 }
 
 void mlperf_logging::on_train_end(model *m)
@@ -277,8 +265,6 @@ void mlperf_logging::on_train_end(model *m)
   //FIXME: run_stop? End of training?
   print(os, mlperf_logging::event_type::INT_START, "run_stop", "null",
         __FILE__, __LINE__, epoch);
-
-  H2_INFO(os.str());
 }
 
 void mlperf_logging::on_batch_evaluate_begin(model *m)
@@ -289,8 +275,6 @@ void mlperf_logging::on_batch_evaluate_begin(model *m)
 
   print(os, mlperf_logging::event_type::INT_START, "eval_start", "null",
         __FILE__, __LINE__, epoch);
-
-  H2_INFO(os.str());
 }
 
 void mlperf_logging::on_batch_evaluate_end(model *m)
@@ -307,8 +291,6 @@ void mlperf_logging::on_batch_evaluate_end(model *m)
 
   print(os, mlperf_logging::event_type::TIME_POINT, "eval_error",
         static_cast<double>(eval_error), __FILE__, __LINE__, epoch);
-
-  H2_INFO(os.str());
 }
 
 std::unique_ptr<callback_base>
@@ -318,7 +300,12 @@ build_mlperf_logging_callback_from_pbuf(
 {
   const auto& params =
     dynamic_cast<const lbann_data::Callback::CallbackMlperfLogging&>(proto_msg);
-  return std::make_unique<mlperf_logging>(params.output_filename());
+  return std::make_unique<mlperf_logging>(params.output_filename(),
+                                          params.sub_benchmark(),
+                                          params.sub_org(),
+                                          params.sub_division(),
+                                          params.sub_status(),
+                                          params.sub_platform());
 }
 } // namespace callback
 } // namespace lbann
diff --git a/src/proto/callbacks.proto b/src/proto/callbacks.proto
index a277cb6c4aa..23ab3982acd 100644
--- a/src/proto/callbacks.proto
+++ b/src/proto/callbacks.proto
@@ -428,6 +428,11 @@ message Callback {
 
   /** @brief Prints mlperf compliant benchmark logs */
   message CallbackMlperfLogging {
-    string output_filename = 1;
+    string output_filename = 1; // Output filename (empty => "results.txt")
+    string sub_benchmark = 2; // Logged as "submission_benchmark"
+    string sub_org = 3; // Logged as "submission_org" (empty => "LBANN")
+    string sub_division = 4; // Logged as "submission_division"
+    string sub_status = 5; // Logged as "submission_status"
+    string sub_platform = 6; // Logged as "submission_platform"
   }
 }