Skip to content

Commit

Permalink
Extract `MakeMetricName` from the ranking metric.
Browse files Browse the repository at this point in the history
- Extract the metric parsing routine from ranking.
- Add test.
  • Loading branch information
trivialfis committed Feb 9, 2023
1 parent 48cefa0 commit e7b9792
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 34 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ OBJECTS= \
$(PKGROOT)/src/common/stats.o \
$(PKGROOT)/src/common/survival_util.o \
$(PKGROOT)/src/common/threading_utils.o \
$(PKGROOT)/src/common/ranking_utils.o \
$(PKGROOT)/src/common/timer.o \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ OBJECTS= \
$(PKGROOT)/src/common/stats.o \
$(PKGROOT)/src/common/survival_util.o \
$(PKGROOT)/src/common/threading_utils.o \
$(PKGROOT)/src/common/ranking_utils.o \
$(PKGROOT)/src/common/timer.o \
$(PKGROOT)/src/common/version.o \
$(PKGROOT)/src/c_api/c_api.o \
Expand Down
30 changes: 16 additions & 14 deletions include/xgboost/string_view.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
/*!
* Copyright 2021 by XGBoost Contributors
/**
* Copyright 2021-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_STRING_VIEW_H_
#define XGBOOST_STRING_VIEW_H_
#include <xgboost/logging.h>
#include <xgboost/span.h>
#include <xgboost/logging.h> // CHECK_LT
#include <xgboost/span.h> // Span

#include <algorithm>
#include <iterator>
#include <ostream>
#include <string>
#include <algorithm> // std::equal,std::min
#include <iterator> // std::reverse_iterator
#include <ostream> // std::ostream
#include <string> // std::char_traits,std::string

namespace xgboost {
struct StringView {
Expand All @@ -28,29 +28,31 @@ struct StringView {

public:
constexpr StringView() = default;
constexpr StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {}
explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {}
StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {} // NOLINT
constexpr StringView(CharT const* str) // NOLINT
: str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {}

CharT const& operator[](size_t p) const { return str_[p]; }
CharT const& at(size_t p) const { // NOLINT
CHECK_LT(p, size_);
return str_[p];
}
constexpr size_t size() const { return size_; } // NOLINT
StringView substr(size_t beg, size_t n) const { // NOLINT
constexpr std::size_t size() const { return size_; } // NOLINT
constexpr bool empty() const { return size() == 0; } // NOLINT
StringView substr(size_t beg, size_t n) const { // NOLINT
CHECK_LE(beg, size_);
size_t len = std::min(n, size_ - beg);
return {str_ + beg, len};
}
CharT const* c_str() const { return str_; } // NOLINT
CharT const* c_str() const { return str_; } // NOLINT

constexpr CharT const* cbegin() const { return str_; } // NOLINT
constexpr CharT const* cend() const { return str_ + size(); } // NOLINT
constexpr CharT const* begin() const { return str_; } // NOLINT
constexpr CharT const* end() const { return str_ + size(); } // NOLINT

const_reverse_iterator rbegin() const noexcept { // NOLINT
const_reverse_iterator rbegin() const noexcept { // NOLINT
return const_reverse_iterator(this->end());
}
const_reverse_iterator crbegin() const noexcept { // NOLINT
Expand Down
34 changes: 34 additions & 0 deletions src/common/ranking_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include "ranking_utils.h"

#include <cstdint> // std::uint32_t
#include <cstdio> // std::sscanf
#include <sstream> // std::ostringstream
#include <string> // std::string

#include "xgboost/string_view.h" // StringView

namespace xgboost {
namespace ltr {
/**
 * \brief Construct name for ranking metric given parameters.
 *
 * An empty `param` yields `name` unchanged. A `param` that starts with an unsigned integer
 * yields `name@param`; any other non-empty `param` is appended to `name` without a
 * separator.
 *
 * \param [in]  name  Metric name.
 * \param [in]  param Null terminated string for parameter like the `3-` in `ndcg@3-`. It is
 *                    handed to `std::sscanf`, hence the terminator requirement.
 * \param [out] topn  Receives the leading integer of `param`. Unchanged if there is none.
 * \param [out] minus Set to true if `param` ends with `-`. Never reset to false here.
 *
 * \return The name of the metric.
 */
std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus) {
  if (param.empty()) {
    // Copy by length rather than through c_str(): a StringView may wrap a null pointer, and
    // building a std::string from a null `char const*` is undefined behaviour.
    if (name.empty()) {
      return std::string{};
    }
    return std::string{name.cbegin(), name.cend()};
  }

  std::ostringstream os;
  os << name;
  // sscanf stores into `topn` only when the conversion succeeds, so the caller's value
  // survives a parameter without a leading number. The format is just "%u": scanf formats
  // are not regular expressions, and trailing literal text cannot change the returned
  // conversion count.
  if (std::sscanf(param.c_str(), "%u", topn) == 1) {
    os << '@';
  }
  os << param;

  // `param` is non-empty here, so reading its last character is safe.
  if (*param.crbegin() == '-') {
    *minus = true;
  }
  return os.str();
}
} // namespace ltr
} // namespace xgboost
29 changes: 29 additions & 0 deletions src/common/ranking_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
#define XGBOOST_COMMON_RANKING_UTILS_H_

#include <cstddef> // std::size_t
#include <cstdint> // std::uint32_t
#include <string> // std::string

#include "xgboost/string_view.h" // StringView

namespace xgboost {
namespace ltr {
/**
 * \brief Construct name for ranking metric given parameters.
 *
 * If `param` is empty the result is `name` alone. If `param` starts with an unsigned
 * integer the result is `name@param` (for example `ndcg` and `3-` give `ndcg@3-`).
 * Otherwise `param` is appended to `name` without a separator.
 *
 * \param [in]  name  Null terminated string for metric name
 * \param [in]  param Null terminated string for parameter like the `3-` in `ndcg@3-`. May be
 *                    empty, including a view constructed from `nullptr`.
 * \param [out] topn  Top n documents parsed from the leading integer of param. Unchanged if
 *                    param does not start with an integer.
 * \param [out] minus Whether we should turn the score into loss. Set to true when param ends
 *                    with `-`; unchanged otherwise (it is never reset to false).
 *
 * \return The name of the metric.
 */
std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus);
} // namespace ltr
} // namespace xgboost
#endif // XGBOOST_COMMON_RANKING_UTILS_H_
19 changes: 2 additions & 17 deletions src/metric/rank_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/ranking_utils.h" // MakeMetricName
#include "../common/threading_utils.h"
#include "metric_common.h"
#include "xgboost/host_device_vector.h"
Expand Down Expand Up @@ -232,23 +233,7 @@ struct EvalRank : public Metric, public EvalRankConfig {

protected:
explicit EvalRank(const char* name, const char* param) {
using namespace std; // NOLINT(*)

if (param != nullptr) {
std::ostringstream os;
if (sscanf(param, "%u[-]?", &topn) == 1) {
os << name << '@' << param;
this->name = os.str();
} else {
os << name << param;
this->name = os.str();
}
if (param[strlen(param) - 1] == '-') {
minus = true;
}
} else {
this->name = name;
}
this->name = ltr::MakeMetricName(name, param, &topn, &minus);
}

virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
Expand Down
37 changes: 37 additions & 0 deletions tests/cpp/common/test_ranking_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <cstdint> // std::uint32_t

#include "../../../src/common/ranking_utils.h"

namespace xgboost {
namespace ltr {
// Exercises every parameter form accepted by MakeMetricName and checks all three results:
// the returned name, the parsed top-n and the minus flag. `topn` and `minus` are shared
// across calls on purpose so that the "unchanged" guarantees are observable.
TEST(RankingUtils, MakeMetricName) {
  std::uint32_t topn{32};
  bool minus{false};

  // Number followed by `-`: both outputs are written.
  auto name = MakeMetricName("ndcg", "3-", &topn, &minus);
  ASSERT_EQ(name, "ndcg@3-");
  ASSERT_EQ(topn, 3);
  ASSERT_TRUE(minus);

  // Number only: top-n is updated, minus keeps its previous value.
  name = MakeMetricName("ndcg", "6", &topn, &minus);
  ASSERT_EQ(name, "ndcg@6");
  ASSERT_EQ(topn, 6);
  ASSERT_TRUE(minus);  // unchanged

  // `-` only: no number, so no `@` separator and top-n is untouched.
  name = MakeMetricName("ndcg", "-", &topn, &minus);
  ASSERT_EQ(name, "ndcg-");
  ASSERT_EQ(topn, 6);  // unchanged
  ASSERT_TRUE(minus);

  // Null parameter: treated as empty, the name is returned as is.
  name = MakeMetricName("ndcg", nullptr, &topn, &minus);
  ASSERT_EQ(name, "ndcg");
  ASSERT_EQ(topn, 6);  // unchanged
  ASSERT_TRUE(minus);  // unchanged

  // Default-constructed view: same as the null parameter.
  name = MakeMetricName("ndcg", StringView{}, &topn, &minus);
  ASSERT_EQ(name, "ndcg");
  ASSERT_EQ(topn, 6);  // unchanged
  ASSERT_TRUE(minus);  // unchanged
}
} // namespace ltr
} // namespace xgboost
3 changes: 3 additions & 0 deletions tests/cpp/common/test_ranking_utils.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
/**
* Copyright 2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/ranking_utils.cuh"
#include "../../../src/common/device_helpers.cuh"
Expand Down
21 changes: 18 additions & 3 deletions tests/cpp/common/test_string_view.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
/*!
* Copyright (c) by XGBoost Contributors 2021
/**
* Copyright 2021-2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/string_view.h>
#include <string_view>

#include <algorithm> // std::equal
#include <sstream> // std::stringstream
#include <string> // std::string

namespace xgboost {
TEST(StringView, Basic) {
StringView str{"This is a string."};
Expand All @@ -24,5 +28,16 @@ TEST(StringView, Basic) {
ASSERT_FALSE(substr == "i");

ASSERT_TRUE(std::equal(substr.crbegin(), substr.crend(), StringView{"si"}.cbegin()));

{
StringView empty{nullptr};
ASSERT_TRUE(empty.empty());
}
{
StringView empty{""};
ASSERT_TRUE(empty.empty());
StringView empty2{nullptr};
ASSERT_EQ(empty, empty2);
}
}
} // namespace xgboost

0 comments on commit e7b9792

Please sign in to comment.