Skip to content

Commit 0bdccf4

Browse files
zhanyongwancopybara-github
authored andcommitted
Add a DistanceFrom() matcher for general distance comparison.
We have a bunch of matchers for asserting that a value is near the target value, e.g. `DoubleNear()` and `FloatNear()`. These matchers only work for specific types (`double` and `float`). They are not flexible enough to support other types that have the notion of a "distance" (e.g. N-dimensional points and vectors, which are commonly used in ML). In this diff, we generalize the idea to a `DistanceFrom(target, get_distance, m)` matcher that works on arbitrary types that have the "distance" concept (the `get_distance` argument is optional and can be omitted for types that support `-`, and `std::abs()`). What it does: 1. compute the distance between the value and the target using `get_distance(value, target)`; if `get_distance` is omitted, compute the distance as `std::abs(value - target)`. 2. match the distance against matcher `m`; if the match succeeds, the `DistanceFrom()` match succeeds. Examples: ``` // 0.5's distance from 0.6 should be <= 0.2. EXPECT_THAT(0.5, DistanceFrom(0.6, Le(0.2))); Vector2D v1(3.0, 4.0), v2(3.2, 6.0); // v1's distance from v2, as computed by EuclideanDistance(v1, v2), // should be >= 1.0. EXPECT_THAT(v1, DistanceFrom(v2, EuclideanDistance, Ge(1.0))); ``` PiperOrigin-RevId: 734593292 Change-Id: Id6bb7074dc4aa4d8abd78b57ad2426637e590de5
1 parent e88cb95 commit 0bdccf4

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed

docs/reference/matchers.md

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ Matcher | Description
4242
| `Lt(value)` | `argument < value` |
4343
| `Ne(value)` | `argument != value` |
4444
| `IsFalse()` | `argument` evaluates to `false` in a Boolean context. |
45+
| `DistanceFrom(target, m)` | The distance between `argument` and `target` (computed by `std::abs(argument - target)`) matches `m`. |
46+
| `DistanceFrom(target, get_distance, m)` | The distance between `argument` and `target` (computed by `get_distance(argument, target)`) matches `m`. |
4547
| `IsTrue()` | `argument` evaluates to `true` in a Boolean context. |
4648
| `IsNull()` | `argument` is a `NULL` pointer (raw or smart). |
4749
| `NotNull()` | `argument` is a non-null pointer (raw or smart). |

googlemock/include/gmock/gmock-matchers.h

+128
Original file line numberDiff line numberDiff line change
@@ -2855,6 +2855,54 @@ class ContainsMatcherImpl : public QuantifierMatcherImpl<Container> {
28552855
}
28562856
};
28572857

2858+
// Implements DistanceFrom(target, get_distance, distance_matcher) for the given
2859+
// argument types:
2860+
// * V is the type of the value to be matched.
2861+
// * T is the type of the target value.
2862+
// * Distance is the type of the distance between V and T.
2863+
// * GetDistance is the type of the functor for computing the distance between
2864+
// V and T.
2865+
template <typename V, typename T, typename Distance, typename GetDistance>
2866+
class DistanceFromMatcherImpl : public MatcherInterface<V> {
2867+
public:
2868+
// Arguments:
2869+
// * target: the target value.
2870+
// * get_distance: the functor for computing the distance between the value
2871+
// being matched and target.
2872+
// * distance_matcher: the matcher for checking the distance.
2873+
DistanceFromMatcherImpl(T target, GetDistance get_distance,
2874+
Matcher<const Distance&> distance_matcher)
2875+
: target_(std::move(target)),
2876+
get_distance_(std::move(get_distance)),
2877+
distance_matcher_(std::move(distance_matcher)) {}
2878+
2879+
// Describes what this matcher does.
2880+
void DescribeTo(::std::ostream* os) const override {
2881+
distance_matcher_.DescribeTo(os);
2882+
*os << " away from " << PrintToString(target_);
2883+
}
2884+
2885+
void DescribeNegationTo(::std::ostream* os) const override {
2886+
distance_matcher_.DescribeNegationTo(os);
2887+
*os << " away from " << PrintToString(target_);
2888+
}
2889+
2890+
bool MatchAndExplain(V value, MatchResultListener* listener) const override {
2891+
const auto distance = get_distance_(value, target_);
2892+
const bool match = distance_matcher_.Matches(distance);
2893+
if (!match && listener->IsInterested()) {
2894+
*listener << "which is " << PrintToString(distance) << " away from "
2895+
<< PrintToString(target_);
2896+
}
2897+
return match;
2898+
}
2899+
2900+
private:
2901+
const T target_;
2902+
const GetDistance get_distance_;
2903+
const Matcher<const Distance&> distance_matcher_;
2904+
};
2905+
28582906
// Implements Each(element_matcher) for the given argument type Container.
28592907
// Symmetric to ContainsMatcherImpl.
28602908
template <typename Container>
@@ -2990,6 +3038,50 @@ auto Second(T& x, Rank1) -> decltype((x.second)) { // NOLINT
29903038
}
29913039
} // namespace pair_getters
29923040

3041+
// Default functor for computing the distance between two values.
3042+
struct DefaultGetDistance {
3043+
template <typename T, typename U>
3044+
auto operator()(const T& lhs, const U& rhs) const {
3045+
return std::abs(lhs - rhs);
3046+
}
3047+
};
3048+
3049+
// Implements polymorphic DistanceFrom(target, get_distance, distance_matcher)
3050+
// matcher. Template arguments:
3051+
// * T is the type of the target value.
3052+
// * GetDistance is the type of the functor for computing the distance between
3053+
// the value being matched and the target.
3054+
// * DistanceMatcher is the type of the matcher for checking the distance.
3055+
template <typename T, typename GetDistance, typename DistanceMatcher>
3056+
class DistanceFromMatcher {
3057+
public:
3058+
// Arguments:
3059+
// * target: the target value.
3060+
// * get_distance: the functor for computing the distance between the value
3061+
// being matched and target.
3062+
// * distance_matcher: the matcher for checking the distance.
3063+
DistanceFromMatcher(T target, GetDistance get_distance,
3064+
DistanceMatcher distance_matcher)
3065+
: target_(std::move(target)),
3066+
get_distance_(std::move(get_distance)),
3067+
distance_matcher_(std::move(distance_matcher)) {}
3068+
3069+
DistanceFromMatcher(const DistanceFromMatcher& other) = default;
3070+
3071+
// Implicitly converts to a monomorphic matcher of the given type.
3072+
template <typename V>
3073+
operator Matcher<V>() const { // NOLINT
3074+
using Distance = decltype(get_distance_(std::declval<V>(), target_));
3075+
return Matcher<V>(new DistanceFromMatcherImpl<V, T, Distance, GetDistance>(
3076+
target_, get_distance_, distance_matcher_));
3077+
}
3078+
3079+
private:
3080+
const T target_;
3081+
const GetDistance get_distance_;
3082+
const DistanceMatcher distance_matcher_;
3083+
};
3084+
29933085
// Implements Key(inner_matcher) for the given argument pair type.
29943086
// Key(inner_matcher) matches an std::pair whose 'first' field matches
29953087
// inner_matcher. For example, Contains(Key(Ge(5))) can be used to match an
@@ -4372,6 +4464,42 @@ inline internal::FloatingEqMatcher<double> DoubleNear(double rhs,
43724464
return internal::FloatingEqMatcher<double>(rhs, false, max_abs_error);
43734465
}
43744466

4467+
// The DistanceFrom(target, get_distance, m) and DistanceFrom(target, m)
4468+
// matchers work on arbitrary types that have the "distance" concept. What they
4469+
// do:
4470+
//
4471+
// 1. compute the distance between the value and the target using
4472+
// get_distance(value, target) if get_distance is provided; otherwise compute
4473+
// the distance as std::abs(value - target).
4474+
// 2. match the distance against the user-provided matcher m; if the match
4475+
// succeeds, the DistanceFrom() match succeeds.
4476+
//
4477+
// Examples:
4478+
//
4479+
// // 0.5's distance from 0.6 should be <= 0.2.
4480+
// EXPECT_THAT(0.5, DistanceFrom(0.6, Le(0.2)));
4481+
//
4482+
// Vector2D v1(3.0, 4.0), v2(3.2, 6.0);
4483+
// // v1's distance from v2, as computed by EuclideanDistance(v1, v2),
4484+
// // should be >= 1.0.
4485+
// EXPECT_THAT(v1, DistanceFrom(v2, EuclideanDistance, Ge(1.0)));
4486+
4487+
template <typename T, typename GetDistance, typename DistanceMatcher>
4488+
inline internal::DistanceFromMatcher<T, GetDistance, DistanceMatcher>
4489+
DistanceFrom(T target, GetDistance get_distance,
4490+
DistanceMatcher distance_matcher) {
4491+
return internal::DistanceFromMatcher<T, GetDistance, DistanceMatcher>(
4492+
std::move(target), std::move(get_distance), std::move(distance_matcher));
4493+
}
4494+
4495+
template <typename T, typename DistanceMatcher>
4496+
inline internal::DistanceFromMatcher<T, internal::DefaultGetDistance,
4497+
DistanceMatcher>
4498+
DistanceFrom(T target, DistanceMatcher distance_matcher) {
4499+
return DistanceFrom(std::move(target), internal::DefaultGetDistance(),
4500+
std::move(distance_matcher));
4501+
}
4502+
43754503
// Creates a matcher that matches any double argument approximately equal to
43764504
// rhs, up to the specified max absolute error bound, including NaN values when
43774505
// rhs is NaN. The max absolute error bound must be non-negative.

googlemock/test/gmock-matchers-arithmetic_test.cc

+146
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <cmath>
3535
#include <limits>
3636
#include <memory>
37+
#include <ostream>
3738
#include <string>
3839

3940
#include "gmock/gmock.h"
@@ -398,6 +399,151 @@ TEST(NanSensitiveDoubleNearTest, CanDescribeSelfWithNaNs) {
398399
EXPECT_EQ("are an almost-equal pair", Describe(m));
399400
}
400401

402+
// Tests that DistanceFrom() can describe itself properly.
403+
TEST(DistanceFrom, CanDescribeSelf) {
404+
Matcher<double> m = DistanceFrom(1.5, Lt(0.1));
405+
EXPECT_EQ(Describe(m), "is < 0.1 away from 1.5");
406+
407+
m = DistanceFrom(2.5, Gt(0.2));
408+
EXPECT_EQ(Describe(m), "is > 0.2 away from 2.5");
409+
}
410+
411+
// Tests that DistanceFrom() can explain match failure.
412+
TEST(DistanceFrom, CanExplainMatchFailure) {
413+
Matcher<double> m = DistanceFrom(1.5, Lt(0.1));
414+
EXPECT_EQ(Explain(m, 2.0), "which is 0.5 away from 1.5");
415+
}
416+
417+
// Tests that DistanceFrom() matches a double that is within the given range of
418+
// the given value.
419+
TEST(DistanceFrom, MatchesDoubleWithinRange) {
420+
const Matcher<double> m = DistanceFrom(0.5, Le(0.1));
421+
EXPECT_TRUE(m.Matches(0.45));
422+
EXPECT_TRUE(m.Matches(0.5));
423+
EXPECT_TRUE(m.Matches(0.55));
424+
EXPECT_FALSE(m.Matches(0.39));
425+
EXPECT_FALSE(m.Matches(0.61));
426+
}
427+
428+
// Tests that DistanceFrom() matches a double reference that is within the given
429+
// range of the given value.
430+
TEST(DistanceFrom, MatchesDoubleRefWithinRange) {
431+
const Matcher<const double&> m = DistanceFrom(0.5, Le(0.1));
432+
EXPECT_TRUE(m.Matches(0.45));
433+
EXPECT_TRUE(m.Matches(0.5));
434+
EXPECT_TRUE(m.Matches(0.55));
435+
EXPECT_FALSE(m.Matches(0.39));
436+
EXPECT_FALSE(m.Matches(0.61));
437+
}
438+
439+
// Tests that DistanceFrom() can be implicitly converted to a matcher depending
440+
// on the type of the argument.
441+
TEST(DistanceFrom, CanBeImplicitlyConvertedToMatcher) {
442+
EXPECT_THAT(0.58, DistanceFrom(0.5, Le(0.1)));
443+
EXPECT_THAT(0.2, Not(DistanceFrom(0.5, Le(0.1))));
444+
445+
EXPECT_THAT(0.58f, DistanceFrom(0.5f, Le(0.1f)));
446+
EXPECT_THAT(0.7f, Not(DistanceFrom(0.5f, Le(0.1f))));
447+
}
448+
449+
// Tests that DistanceFrom() can be used on compatible types (i.e. not
450+
// everything has to be of the same type).
451+
TEST(DistanceFrom, CanBeUsedOnCompatibleTypes) {
452+
EXPECT_THAT(0.58, DistanceFrom(0.5, Le(0.1f)));
453+
EXPECT_THAT(0.2, Not(DistanceFrom(0.5, Le(0.1f))));
454+
455+
EXPECT_THAT(0.58, DistanceFrom(0.5f, Le(0.1)));
456+
EXPECT_THAT(0.2, Not(DistanceFrom(0.5f, Le(0.1))));
457+
458+
EXPECT_THAT(0.58, DistanceFrom(0.5f, Le(0.1f)));
459+
EXPECT_THAT(0.2, Not(DistanceFrom(0.5f, Le(0.1f))));
460+
461+
EXPECT_THAT(0.58f, DistanceFrom(0.5, Le(0.1)));
462+
EXPECT_THAT(0.2f, Not(DistanceFrom(0.5, Le(0.1))));
463+
464+
EXPECT_THAT(0.58f, DistanceFrom(0.5, Le(0.1f)));
465+
EXPECT_THAT(0.2f, Not(DistanceFrom(0.5, Le(0.1f))));
466+
467+
EXPECT_THAT(0.58f, DistanceFrom(0.5f, Le(0.1)));
468+
EXPECT_THAT(0.2f, Not(DistanceFrom(0.5f, Le(0.1))));
469+
}
470+
471+
// A 2-dimensional point. For testing using DistanceFrom() with a custom type
472+
// that doesn't have a built-in distance function.
473+
class Point {
474+
public:
475+
Point(double x, double y) : x_(x), y_(y) {}
476+
double x() const { return x_; }
477+
double y() const { return y_; }
478+
479+
private:
480+
double x_;
481+
double y_;
482+
};
483+
484+
// Returns the distance between two points.
485+
double PointDistance(const Point& lhs, const Point& rhs) {
486+
return std::sqrt(std::pow(lhs.x() - rhs.x(), 2) +
487+
std::pow(lhs.y() - rhs.y(), 2));
488+
}
489+
490+
// Tests that DistanceFrom() can be used on a type with a custom distance
491+
// function.
492+
TEST(DistanceFrom, CanBeUsedOnTypeWithCustomDistanceFunction) {
493+
const Matcher<Point> m =
494+
DistanceFrom(Point(0.5, 0.5), PointDistance, Le(0.1));
495+
EXPECT_THAT(Point(0.45, 0.45), m);
496+
EXPECT_THAT(Point(0.2, 0.45), Not(m));
497+
}
498+
499+
// A wrapper around a double value. For testing using DistanceFrom() with a
500+
// custom type that has neither a built-in distance function nor a built-in
501+
// distance comparator.
502+
class Double {
503+
public:
504+
explicit Double(double value) : value_(value) {}
505+
Double(const Double& other) = default;
506+
double value() const { return value_; }
507+
508+
// Defines how to print a Double value. We don't use the AbslStringify API
509+
// because googletest doesn't require absl yet.
510+
friend void PrintTo(const Double& value, std::ostream* os) {
511+
*os << "Double(" << value.value() << ")";
512+
}
513+
514+
private:
515+
double value_;
516+
};
517+
518+
// Returns the distance between two Double values.
519+
Double DoubleDistance(Double lhs, Double rhs) {
520+
return Double(std::abs(lhs.value() - rhs.value()));
521+
}
522+
523+
MATCHER_P(DoubleLe, rhs, (negation ? "is > " : "is <= ") + PrintToString(rhs)) {
524+
return arg.value() <= rhs.value();
525+
}
526+
527+
// Tests that DistanceFrom() can describe itself properly for a type with a
528+
// custom printer.
529+
TEST(DistanceFrom, CanDescribeWithCustomPrinter) {
530+
const Matcher<Double> m =
531+
DistanceFrom(Double(0.5), DoubleDistance, DoubleLe(Double(0.1)));
532+
EXPECT_EQ(Describe(m), "is <= Double(0.1) away from Double(0.5)");
533+
EXPECT_EQ(DescribeNegation(m), "is > Double(0.1) away from Double(0.5)");
534+
}
535+
536+
// Tests that DistanceFrom() can be used with a custom distance function and
537+
// comparator.
538+
TEST(DistanceFrom, CanCustomizeDistanceAndComparator) {
539+
const Matcher<Double> m =
540+
DistanceFrom(Double(0.5), DoubleDistance, DoubleLe(Double(0.1)));
541+
EXPECT_TRUE(m.Matches(Double(0.45)));
542+
EXPECT_TRUE(m.Matches(Double(0.5)));
543+
EXPECT_FALSE(m.Matches(Double(0.39)));
544+
EXPECT_FALSE(m.Matches(Double(0.61)));
545+
}
546+
401547
// Tests that Not(m) matches any value that doesn't match m.
402548
TEST(NotTest, NegatesMatcher) {
403549
Matcher<int> m;

0 commit comments

Comments
 (0)