diff --git a/include/flux/core/functional.hpp b/include/flux/core/functional.hpp index 5ff0dc6c..a3a4b9a9 100644 --- a/include/flux/core/functional.hpp +++ b/include/flux/core/functional.hpp @@ -253,6 +253,39 @@ FLUX_EXPORT inline constexpr auto odd = detail::predicate([](auto const& val) -> } // namespace pred +namespace cmp { + +namespace detail { + +struct min_fn { + template + requires std::strict_weak_order + [[nodiscard]] + constexpr auto operator()(T&& t, U&& u, Cmp cmp = Cmp{}) const + -> std::common_reference_t + { + return std::invoke(cmp, u, t) ? FLUX_FWD(u) : FLUX_FWD(t); + }; +}; + +struct max_fn { + template + requires std::strict_weak_order + [[nodiscard]] + constexpr auto operator()(T&& t, U&& u, Cmp cmp = Cmp{}) const + -> std::common_reference_t + { + return !std::invoke(cmp, u, t) ? FLUX_FWD(u) : FLUX_FWD(t); + }; +}; + +} // namespace detail + +FLUX_EXPORT inline constexpr auto min = detail::min_fn{}; +FLUX_EXPORT inline constexpr auto max = detail::max_fn{}; + +} // namespace cmp + } // namespace flux #endif diff --git a/test/test_predicates.cpp b/test/test_predicates.cpp index bc14ee30..22654614 100644 --- a/test/test_predicates.cpp +++ b/test/test_predicates.cpp @@ -118,6 +118,134 @@ constexpr bool test_predicate_combiners() } static_assert(test_predicate_combiners()); +// Not really predicates, but we'll test them here anyway +constexpr bool test_comparisons() +{ + namespace cmp = flux::cmp; + + struct Test { + int i; + double d; + + bool operator==(Test const&) const = default; + }; + + // min of two same-type non-const lvalue references is an lvalue + { + int i = 0, j = 1; + cmp::min(i, j) = 99; + STATIC_CHECK(i == 99); + STATIC_CHECK(j == 1); + } + + // min of same-type mixed-const lvalue refs is a const ref + { + int i = 1; + int const j = 0; + auto& m = cmp::min(i, j); + static_assert(std::same_as); + STATIC_CHECK(m == 0); + } + + // min of same-type lvalue and prvalue is a prvalue + { + int const i = 1; + using M = decltype(cmp::min(i, i + 1)); + static_assert(std::same_as); + STATIC_CHECK(cmp::min(i, i + 1) == 1); + } + + // mixed-type min is a prvalue + { + int const i = 10; + long const j = 5; + using M = decltype(cmp::min(i, j)); + static_assert(std::same_as); + STATIC_CHECK(cmp::min(i, j) == 5); + } + + // Custom comparators work okay with min() + { + Test t1{1, 3.0}; + Test t2{1, 2.0}; + + auto cmp_test = [](Test t1, Test t2) { return t1.d < t2.d; }; + + STATIC_CHECK(cmp::min(t1, t2, cmp_test) == t2); + } + + // If arguments are equal, min() returns the first + { + int i = 1, j = 1; + int& m = cmp::min(i, j); + STATIC_CHECK(&m == &i); + + Test t1{1, 3.0}; + Test t2{1, 2.0}; + + STATIC_CHECK(cmp::min(t1, t2, flux::proj(std::less{}, &Test::i)) == t1); + } + + // max of two same-type non-const lvalue references is an lvalue + { + int i = 0, j = 1; + cmp::max(i, j) = 99; + STATIC_CHECK(i == 0); + STATIC_CHECK(j == 99); + } + + // max of same-type mixed-const lvalue refs is a const ref + { + int i = 1; + int const j = 0; + auto& m = cmp::max(i, j); + static_assert(std::same_as); + STATIC_CHECK(m == 1); + } + + // max of same-type lvalue and prvalue is a prvalue + { + int const i = 1; + using M = decltype(cmp::max(i, i + 1)); + static_assert(std::same_as); + STATIC_CHECK(cmp::max(i, i + 1) == 2); + } + + // mixed-type max is a prvalue + { + int const i = 10; + long const j = 5; + using M = decltype(cmp::max(i, j)); + static_assert(std::same_as); + STATIC_CHECK(cmp::max(i, j) == 10); + } + + // Custom comparators work okay with max() + { + Test t1{1, 3.0}; + Test t2{1, 2.0}; + + auto cmp_test = [](Test t1, Test t2) { return t1.d < t2.d; }; + + STATIC_CHECK(cmp::max(t1, t2, cmp_test) == t1); + } + + // If arguments are equal, max() returns the second + { + int i = 1, j = 1; + int& m = cmp::max(i, j); + STATIC_CHECK(&m == &j); + + Test t1{1, 3.0}; + Test t2{1, 2.0}; + + STATIC_CHECK(cmp::max(t1, t2, flux::proj(std::less{}, &Test::i)) == t2); + } + + return true; +} +static_assert(test_comparisons()); + } TEST_CASE("predicates") @@ -126,3 +254,8 @@ TEST_CASE("predicates") REQUIRE(test_predicate_combiners()); } +TEST_CASE("comparators") +{ + REQUIRE(test_comparisons()); +} +