Skip to content

Commit

Permalink
Add implementation for gamma(a, b) and distance_under_gamma(a, b, g)
Browse files Browse the repository at this point in the history
These are building stones for uniform floating point distribution.
  • Loading branch information
horenmar committed Nov 19, 2023
1 parent 6635987 commit 124b0a2
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ set(IMPL_HEADERS
${SOURCES_DIR}/internal/catch_polyfills.hpp
${SOURCES_DIR}/internal/catch_preprocessor.hpp
${SOURCES_DIR}/internal/catch_preprocessor_remove_parens.hpp
${SOURCES_DIR}/internal/catch_random_floating_point_helpers.hpp
${SOURCES_DIR}/internal/catch_random_number_generator.hpp
${SOURCES_DIR}/internal/catch_random_seed_generation.hpp
${SOURCES_DIR}/internal/catch_reporter_registry.hpp
Expand Down
19 changes: 3 additions & 16 deletions src/catch2/benchmark/detail/catch_stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <catch2/benchmark/detail/catch_stats.hpp>

#include <catch2/internal/catch_compiler_capabilities.hpp>
#include <catch2/internal/catch_floating_point_helpers.hpp>

#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -184,20 +185,6 @@ namespace Catch {
return std::sqrt( variance );
}

#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wfloat-equal"
#endif
// Used when we know we want == comparison of two doubles
// to centralize warning suppression
static bool directCompare( double lhs, double rhs ) {
return lhs == rhs;
}
#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic pop
#endif


static sample jackknife( double ( *estimator )( double const*,
double const* ),
double* first,
Expand Down Expand Up @@ -234,7 +221,7 @@ namespace Catch {
double g = idx - j;
std::nth_element(first, first + j, last);
auto xj = first[j];
if ( directCompare( g, 0 ) ) {
if ( Catch::Detail::directCompare( g, 0 ) ) {
return xj;
}

Expand Down Expand Up @@ -338,7 +325,7 @@ namespace Catch {
[point]( double x ) { return x < point; } ) /
static_cast<double>( n );
// degenerate case with uniform samples
if ( directCompare( prob_n, 0. ) ) {
if ( Catch::Detail::directCompare( prob_n, 0. ) ) {
return { point, point, point, confidence_level };
}

Expand Down
1 change: 1 addition & 0 deletions src/catch2/catch_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
#include <catch2/internal/catch_preprocessor.hpp>
#include <catch2/internal/catch_preprocessor_internal_stringify.hpp>
#include <catch2/internal/catch_preprocessor_remove_parens.hpp>
#include <catch2/internal/catch_random_floating_point_helpers.hpp>
#include <catch2/internal/catch_random_number_generator.hpp>
#include <catch2/internal/catch_random_seed_generation.hpp>
#include <catch2/internal/catch_reporter_registry.hpp>
Expand Down
11 changes: 11 additions & 0 deletions src/catch2/internal/catch_floating_point_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ namespace Catch {
return i;
}

#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wfloat-equal"
#endif
bool directCompare( float lhs, float rhs ) { return lhs == rhs; }
bool directCompare( double lhs, double rhs ) { return lhs == rhs; }
#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic pop
#endif


} // end namespace Detail
} // end namespace Catch

5 changes: 5 additions & 0 deletions src/catch2/internal/catch_floating_point_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ namespace Catch {
uint32_t convertToBits(float f);
uint64_t convertToBits(double d);

// Used when we know we want == comparison of two doubles
// to centralize warning suppression
bool directCompare( float lhs, float rhs );
bool directCompare( double lhs, double rhs );

} // end namespace Detail


Expand Down
86 changes: 86 additions & 0 deletions src/catch2/internal/catch_random_floating_point_helpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

// Copyright Catch2 Authors
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE.txt or copy at
// https://www.boost.org/LICENSE_1_0.txt)

// SPDX-License-Identifier: BSL-1.0
#ifndef CATCH_GAMMA_HPP_INCLUDED
#define CATCH_GAMMA_HPP_INCLUDED

#include <catch2/internal/catch_polyfills.hpp>

#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

namespace Catch {

namespace Detail {
/**
* Returns the largest magnitude of 1-ULP step inside the [a, b] range.
*
* Assumes `a < b`.
*/
template <typename FloatType>
FloatType gamma(FloatType a, FloatType b) {
static_assert( std::is_floating_point<FloatType>::value,
"gamma returns the largest ULP magnitude within "
"floating point range [a, b]. This only makes sense "
"for floating point types" );
assert( a < b );

const auto gamma_up = Catch::nextafter( a, std::numeric_limits<FloatType>::infinity() ) - a;
const auto gamma_down = b - Catch::nextafter( b, -std::numeric_limits<FloatType>::infinity() );

return gamma_up < gamma_down ? gamma_down : gamma_up;
}

template <typename FloatingPoint>
struct IntegerPicker;
template <>
struct IntegerPicker<float> {
using type = std::uint32_t;
};
template <>
struct IntegerPicker<double> {
using type = std::uint64_t;
};

template <typename T>
using PickedType = typename IntegerPicker<T>::type;

#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wfloat-equal"
#endif
/**
* TODO: explain
*/
template <typename FloatType>
PickedType<FloatType>
distance_under_gamma( FloatType a, FloatType b, FloatType g ) {
assert( a < b );

const auto ag = a / g;
const auto bg = b / g;

const auto s = bg - ag;
const auto err = ( std::fabs( a ) <= std::fabs( b ) )
? -ag - ( s - bg )
: bg - ( s + ag );
const auto ceil_s = static_cast<PickedType<FloatType>>( std::ceil( s ) );

return ( ceil_s != s ) ? ceil_s : ceil_s + ( err > 0 );
}
#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic pop
#endif

}

} // end namespace Catch

#endif // CATCH_GAMMA_HPP_INCLUDED
1 change: 1 addition & 0 deletions src/catch2/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ internal_headers = [
'internal/catch_preprocessor.hpp',
'internal/catch_preprocessor_internal_stringify.hpp',
'internal/catch_preprocessor_remove_parens.hpp',
'internal/catch_random_floating_point_helpers.hpp',
'internal/catch_random_number_generator.hpp',
'internal/catch_random_seed_generation.hpp',
'internal/catch_reporter_registry.hpp',
Expand Down
60 changes: 60 additions & 0 deletions tests/SelfTest/IntrospectiveTests/FloatingPoint.tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_template_test_macros.hpp>
#include <catch2/internal/catch_floating_point_helpers.hpp>
#include <catch2/internal/catch_random_floating_point_helpers.hpp>

#include <limits>

TEST_CASE("convertToBits", "[floating-point][conversion]") {
using Catch::Detail::convertToBits;
Expand Down Expand Up @@ -72,3 +74,61 @@ TEST_CASE("UlpDistance", "[floating-point][ulp][approvals]") {
CHECK( ulpDistance( 1.f, 2.f ) == 0x80'00'00 );
CHECK( ulpDistance( -2.f, 2.f ) == 0x80'00'00'00 );
}



TEMPLATE_TEST_CASE("gamma", "[approvals][floating-point][ulp][gamma]", float, double) {
using Catch::Detail::gamma;
using Catch::Detail::directCompare;

// We need to butcher the equal tests with the directCompare helper,
// because the Wfloat-equal triggers in decomposer rather than here,
// so we cannot locally disable it. Goddamn GCC.
CHECK( directCompare( gamma( TestType( -1. ), TestType( 1. ) ),
gamma( TestType( 0.2332 ), TestType( 1.0 ) ) ) );
CHECK( directCompare( gamma( TestType( -2. ), TestType( 0 ) ),
gamma( TestType( 1. ), TestType( 1.5 ) ) ) );
CHECK( gamma( TestType( 0. ), TestType( 1.0 ) ) <
gamma( TestType( 1.0 ), TestType( 1.5 ) ) );
CHECK( gamma( TestType( 0 ), TestType( 1. ) ) <
std::numeric_limits<TestType>::epsilon() );
CHECK( gamma( TestType( -1. ), TestType( -0. ) ) <
std::numeric_limits<TestType>::epsilon() );
CHECK( directCompare( gamma( TestType( 1. ), TestType( 2. ) ),
std::numeric_limits<TestType>::epsilon() ) );
CHECK( directCompare( gamma( TestType( -2. ), TestType( -1. ) ),
std::numeric_limits<TestType>::epsilon() ) );
}

TEMPLATE_TEST_CASE("distance_under_gamma",
"[approvals][floating-point][distance]",
float,
double) {
using Catch::Detail::distance_under_gamma;
auto count_steps = []( TestType a, TestType b ) {
return distance_under_gamma( a, b, Catch::Detail::gamma( a, b ) );
};

CHECK( count_steps( TestType( -1. ), TestType( 1. ) ) ==
2 * count_steps( TestType( 0. ), TestType( 1. ) ) );
}

TEST_CASE("distance_under_gamma", "[approvals][floating-point][distance]") {
using Catch::Detail::distance_under_gamma;
auto count_steps = []( auto a, auto b ) {
return distance_under_gamma( a, b, Catch::Detail::gamma( a, b ) );
};

CHECK( count_steps( 1., 1.5 ) == 1ull << 51 );
CHECK( count_steps( 1.25, 1.5 ) == 1ull << 50 );
CHECK( count_steps( 1.f, 1.5f ) == 1 << 22 );

STATIC_REQUIRE(
std::is_same<std::uint64_t, decltype( count_steps( 0., 1. ) )>::value );
STATIC_REQUIRE( std::is_same<std::uint32_t,
decltype(count_steps( 0.f, 1.f ))>::value );
}

#if defined( __GNUC__ ) || defined( __clang__ )
# pragma GCC diagnostic pop
#endif

0 comments on commit 124b0a2

Please sign in to comment.