diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index e724ebcb7aaa..111975870ac7 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -90,6 +90,7 @@ OBJECTS= \ $(PKGROOT)/src/common/stats.o \ $(PKGROOT)/src/common/survival_util.o \ $(PKGROOT)/src/common/threading_utils.o \ + $(PKGROOT)/src/common/ranking_utils.o \ $(PKGROOT)/src/common/timer.o \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index ebd7c808cb7d..d89aadc3da33 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -90,6 +90,7 @@ OBJECTS= \ $(PKGROOT)/src/common/stats.o \ $(PKGROOT)/src/common/survival_util.o \ $(PKGROOT)/src/common/threading_utils.o \ + $(PKGROOT)/src/common/ranking_utils.o \ $(PKGROOT)/src/common/timer.o \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ diff --git a/include/xgboost/string_view.h b/include/xgboost/string_view.h index 98a3df2baac0..8b5bff7f6d3e 100644 --- a/include/xgboost/string_view.h +++ b/include/xgboost/string_view.h @@ -1,15 +1,15 @@ -/*! - * Copyright 2021 by XGBoost Contributors +/** + * Copyright 2021-2023 by XGBoost Contributors */ #ifndef XGBOOST_STRING_VIEW_H_ #define XGBOOST_STRING_VIEW_H_ -#include -#include +#include // CHECK_LT +#include // Span -#include -#include -#include -#include +#include // std::equal,std::min +#include // std::reverse_iterator +#include // std::ostream +#include // std::char_traits,std::string namespace xgboost { struct StringView { @@ -28,29 +28,31 @@ struct StringView { public: constexpr StringView() = default; - constexpr StringView(CharT const* str, size_t size) : str_{str}, size_{size} {} + constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {} explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {} - StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {} // NOLINT + constexpr StringView(CharT const* str) // NOLINT + : str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {} CharT const& operator[](size_t p) const { return str_[p]; } CharT const& at(size_t p) const { // NOLINT CHECK_LT(p, size_); return str_[p]; } - constexpr size_t size() const { return size_; } // NOLINT - StringView substr(size_t beg, size_t n) const { // NOLINT + constexpr std::size_t size() const { return size_; } // NOLINT + constexpr bool empty() const { return size() == 0; } // NOLINT + StringView substr(size_t beg, size_t n) const { // NOLINT CHECK_LE(beg, size_); size_t len = std::min(n, size_ - beg); return {str_ + beg, len}; } - CharT const* c_str() const { return str_; } // NOLINT + CharT const* c_str() const { return str_; } // NOLINT constexpr CharT const* cbegin() const { return str_; } // NOLINT constexpr CharT const* cend() const { return str_ + size(); } // NOLINT constexpr CharT const* begin() const { return str_; } // NOLINT constexpr CharT const* end() const { return str_ + size(); } // NOLINT - const_reverse_iterator rbegin() const noexcept { // NOLINT + const_reverse_iterator rbegin() const noexcept { // NOLINT return const_reverse_iterator(this->end()); } const_reverse_iterator crbegin() const noexcept { // NOLINT diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc new file mode 100644 index 000000000000..f0b1c1a5ee77 --- /dev/null +++ b/src/common/ranking_utils.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#include "ranking_utils.h" + +#include // std::uint32_t +#include // std::ostringstream +#include // std::string,std::sscanf + +#include "xgboost/string_view.h" // StringView + +namespace xgboost { +namespace ltr { +std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus) { + std::string out_name; + if (!param.empty()) { + std::ostringstream os; + if (std::sscanf(param.c_str(), "%u[-]?", topn) == 1) { + os << name << '@' << param; + out_name = os.str(); + } else { + os << name << param; + out_name = os.str(); + } + if (*param.crbegin() == '-') { + *minus = true; + } + } else { + out_name = name.c_str(); + } + return out_name; +} +} // namespace ltr +} // namespace xgboost diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h new file mode 100644 index 000000000000..35ee36c2185d --- /dev/null +++ b/src/common/ranking_utils.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#ifndef XGBOOST_COMMON_RANKING_UTILS_H_ +#define XGBOOST_COMMON_RANKING_UTILS_H_ + +#include // std::size_t +#include // std::uint32_t +#include // std::string + +#include "xgboost/string_view.h" // StringView + +namespace xgboost { +namespace ltr { +/** + * \brief Construct name for ranking metric given parameters. + * + * \param [in] name Null terminated string for metric name + * \param [in] param Null terminated string for parameter like the `3-` in `ndcg@3-`. + * \param [out] topn Top n documents parsed from param. Unchanged if it's not specified. + * \param [out] minus Whether we should turn the score into loss. Unchanged if it's not + * specified. + * + * \return The name of the metric. + */ +std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus); +} // namespace ltr +} // namespace xgboost +#endif // XGBOOST_COMMON_RANKING_UTILS_H_ diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index ed31a0ebc596..683a0f0ed12e 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -28,6 +28,7 @@ #include "../collective/communicator-inl.h" #include "../common/math.h" +#include "../common/ranking_utils.h" // MakeMetricName #include "../common/threading_utils.h" #include "metric_common.h" #include "xgboost/host_device_vector.h" @@ -232,23 +233,7 @@ struct EvalRank : public Metric, public EvalRankConfig { protected: explicit EvalRank(const char* name, const char* param) { - using namespace std; // NOLINT(*) - - if (param != nullptr) { - std::ostringstream os; - if (sscanf(param, "%u[-]?", &topn) == 1) { - os << name << '@' << param; - this->name = os.str(); - } else { - os << name << param; - this->name = os.str(); - } - if (param[strlen(param) - 1] == '-') { - minus = true; - } - } else { - this->name = name; - } + this->name = ltr::MakeMetricName(name, param, &topn, &minus); } virtual double EvalGroup(PredIndPairContainer *recptr) const = 0; diff --git a/tests/cpp/common/test_ranking_utils.cc b/tests/cpp/common/test_ranking_utils.cc new file mode 100644 index 000000000000..ea72edd9fdb7 --- /dev/null +++ b/tests/cpp/common/test_ranking_utils.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include + +#include // std::uint32_t + +#include "../../../src/common/ranking_utils.h" + +namespace xgboost { +namespace ltr { +TEST(RankingUtils, MakeMetricName) { + std::uint32_t topn{32}; + bool minus{false}; + auto name = MakeMetricName("ndcg", "3-", &topn, &minus); + ASSERT_EQ(name, "ndcg@3-"); + ASSERT_EQ(topn, 3); + ASSERT_TRUE(minus); + + name = MakeMetricName("ndcg", "6", &topn, &minus); + ASSERT_EQ(topn, 6); + ASSERT_TRUE(minus); // unchanged + + minus = false; + name = MakeMetricName("ndcg", "-", &topn, &minus); + ASSERT_EQ(topn, 6); // unchanged + ASSERT_TRUE(minus); + + name = MakeMetricName("ndcg", nullptr, &topn, &minus); + ASSERT_EQ(topn, 6); // unchanged + ASSERT_TRUE(minus); // unchanged + + name = MakeMetricName("ndcg", StringView{}, &topn, &minus); + ASSERT_EQ(topn, 6); // unchanged + ASSERT_TRUE(minus); // unchanged +} +} // namespace ltr +} // namespace xgboost diff --git a/tests/cpp/common/test_ranking_utils.cu b/tests/cpp/common/test_ranking_utils.cu index 7e0f4244cefb..8d240e41ff77 100644 --- a/tests/cpp/common/test_ranking_utils.cu +++ b/tests/cpp/common/test_ranking_utils.cu @@ -1,3 +1,6 @@ +/** + * Copyright 2021 by XGBoost Contributors + */ #include #include "../../../src/common/ranking_utils.cuh" #include "../../../src/common/device_helpers.cuh" diff --git a/tests/cpp/common/test_string_view.cc b/tests/cpp/common/test_string_view.cc index b2ba24c7180e..e89689162661 100644 --- a/tests/cpp/common/test_string_view.cc +++ b/tests/cpp/common/test_string_view.cc @@ -1,9 +1,13 @@ -/*! - * Copyright (c) by XGBoost Contributors 2021 +/** + * Copyright 2021-2023 by XGBoost Contributors */ #include #include -#include + +#include // std::equal +#include // std::stringstream +#include // std::string + namespace xgboost { TEST(StringView, Basic) { StringView str{"This is a string."}; @@ -24,5 +28,16 @@ TEST(StringView, Basic) { ASSERT_FALSE(substr == "i"); ASSERT_TRUE(std::equal(substr.crbegin(), substr.crend(), StringView{"si"}.cbegin())); + + { + StringView empty{nullptr}; + ASSERT_TRUE(empty.empty()); + } + { + StringView empty{""}; + ASSERT_TRUE(empty.empty()); + StringView empty2{nullptr}; + ASSERT_EQ(empty, empty2); + } } } // namespace xgboost