diff --git a/common/interpolation/include/interpolation/interpolation_utils.hpp b/common/interpolation/include/interpolation/interpolation_utils.hpp index ed381a3108410..9c0372f788ecb 100644 --- a/common/interpolation/include/interpolation/interpolation_utils.hpp +++ b/common/interpolation/include/interpolation/interpolation_utils.hpp @@ -15,6 +15,7 @@ #ifndef INTERPOLATION__INTERPOLATION_UTILS_HPP_ #define INTERPOLATION__INTERPOLATION_UTILS_HPP_ +#include #include #include #include @@ -51,7 +52,7 @@ inline bool isNotDecreasing(const std::vector & x) return true; } -inline void validateKeys( +inline std::vector validateKeys( const std::vector & base_keys, const std::vector & query_keys) { // when vectors are empty @@ -71,9 +72,20 @@ inline void validateKeys( } // when query_keys is out of base_keys (This function does not allow exterior division.) - if (query_keys.front() < base_keys.front() || base_keys.back() < query_keys.back()) { + constexpr double epsilon = 1e-3; + if ( + query_keys.front() < base_keys.front() - epsilon || + base_keys.back() + epsilon < query_keys.back()) { throw std::invalid_argument("query_keys is out of base_keys"); } + + // NOTE: Due to calculation error of double, a query key may be slightly out of base keys. + // Therefore, query keys are cropped here. + auto validated_query_keys = query_keys; + validated_query_keys.front() = std::max(validated_query_keys.front(), base_keys.front()); + validated_query_keys.back() = std::min(validated_query_keys.back(), base_keys.back()); + + return validated_query_keys; } template diff --git a/common/interpolation/include/interpolation/zero_order_hold.hpp b/common/interpolation/include/interpolation/zero_order_hold.hpp index e48da814c5740..1142cb544c174 100644 --- a/common/interpolation/include/interpolation/zero_order_hold.hpp +++ b/common/interpolation/include/interpolation/zero_order_hold.hpp @@ -27,20 +27,20 @@ std::vector zero_order_hold( const std::vector & query_keys, const double overlap_threshold = 1e-3) { // throw exception for invalid arguments - interpolation_utils::validateKeys(base_keys, query_keys); + const auto validated_query_keys = interpolation_utils::validateKeys(base_keys, query_keys); interpolation_utils::validateKeysAndValues(base_keys, base_values); std::vector query_values; size_t closest_segment_idx = 0; - for (size_t i = 0; i < query_keys.size(); ++i) { + for (size_t i = 0; i < validated_query_keys.size(); ++i) { // Check if query_key is closes to the terminal point of the base keys - if (base_keys.back() - overlap_threshold < query_keys.at(i)) { + if (base_keys.back() - overlap_threshold < validated_query_keys.at(i)) { closest_segment_idx = base_keys.size() - 1; } else { for (size_t j = closest_segment_idx; j < base_keys.size() - 1; ++j) { if ( - base_keys.at(j) - overlap_threshold < query_keys.at(i) && - query_keys.at(i) < base_keys.at(j + 1)) { + base_keys.at(j) - overlap_threshold < validated_query_keys.at(i) && + validated_query_keys.at(i) < base_keys.at(j + 1)) { // find closest segment in base keys closest_segment_idx = j; } diff --git a/common/interpolation/src/linear_interpolation.cpp b/common/interpolation/src/linear_interpolation.cpp index 32d3654dbbdd8..f74d085dfee9e 100644 --- a/common/interpolation/src/linear_interpolation.cpp +++ b/common/interpolation/src/linear_interpolation.cpp @@ -28,13 +28,13 @@ std::vector lerp( const std::vector & query_keys) { // throw exception for invalid arguments - interpolation_utils::validateKeys(base_keys, query_keys); + const auto validated_query_keys = interpolation_utils::validateKeys(base_keys, query_keys); interpolation_utils::validateKeysAndValues(base_keys, base_values); // calculate linear interpolation std::vector query_values; size_t key_index = 0; - for (const auto query_key : query_keys) { + for (const auto query_key : validated_query_keys) { while (base_keys.at(key_index + 1) < query_key) { ++key_index; } diff --git a/common/interpolation/src/spherical_linear_interpolation.cpp b/common/interpolation/src/spherical_linear_interpolation.cpp index 014e9011e2a61..c3595d212f349 100644 --- a/common/interpolation/src/spherical_linear_interpolation.cpp +++ b/common/interpolation/src/spherical_linear_interpolation.cpp @@ -34,13 +34,13 @@ std::vector slerp( const std::vector & query_keys) { // throw exception for invalid arguments - interpolation_utils::validateKeys(base_keys, query_keys); + const auto validated_query_keys = interpolation_utils::validateKeys(base_keys, query_keys); interpolation_utils::validateKeysAndValues(base_keys, base_values); // calculate linear interpolation std::vector query_values; size_t key_index = 0; - for (const auto query_key : query_keys) { + for (const auto query_key : validated_query_keys) { while (base_keys.at(key_index + 1) < query_key) { ++key_index; } diff --git a/common/interpolation/src/spline_interpolation.cpp b/common/interpolation/src/spline_interpolation.cpp index cf00452f1d850..bd92af1007b50 100644 --- a/common/interpolation/src/spline_interpolation.cpp +++ b/common/interpolation/src/spline_interpolation.cpp @@ -222,7 +222,7 @@ std::vector SplineInterpolation::getSplineInterpolatedValues( const std::vector & query_keys) const { // throw exceptions for invalid arguments - interpolation_utils::validateKeys(base_keys_, query_keys); + const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys); const auto & a = multi_spline_coef_.a; const auto & b = multi_spline_coef_.b; @@ -231,7 +231,7 @@ std::vector SplineInterpolation::getSplineInterpolatedValues( std::vector res; size_t j = 0; - for (const auto & query_key : query_keys) { + for (const auto & query_key : validated_query_keys) { while (base_keys_.at(j + 1) < query_key) { ++j; } @@ -247,7 +247,7 @@ std::vector SplineInterpolation::getSplineInterpolatedDiffValues( const std::vector & query_keys) const { // throw exceptions for invalid arguments - interpolation_utils::validateKeys(base_keys_, query_keys); + const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys); const auto & a = multi_spline_coef_.a; const auto & b = multi_spline_coef_.b; @@ -255,7 +255,7 @@ std::vector SplineInterpolation::getSplineInterpolatedDiffValues( std::vector res; size_t j = 0; - for (const auto & query_key : query_keys) { + for (const auto & query_key : validated_query_keys) { while (base_keys_.at(j + 1) < query_key) { ++j; } diff --git a/common/interpolation/test/src/test_interpolation_utils.cpp b/common/interpolation/test/src/test_interpolation_utils.cpp index 3eb40fd439c56..8b3a3b9faa0c6 100644 --- a/common/interpolation/test/src/test_interpolation_utils.cpp +++ b/common/interpolation/test/src/test_interpolation_utils.cpp @@ -95,6 +95,23 @@ TEST(interpolation_utils, validateKeys) const std::vector back_out_query_keys{0.0, 1.0, 2.0, 4.0}; EXPECT_THROW(validateKeys(base_keys, back_out_query_keys), std::invalid_argument); + + { // validated key check in normal case + const std::vector normal_query_keys{0.5, 1.5, 3.0}; + const auto validated_query_keys = validateKeys(base_keys, normal_query_keys); + for (size_t i = 0; i < normal_query_keys.size(); ++i) { + EXPECT_EQ(normal_query_keys.at(i), validated_query_keys.at(i)); + } + } + + { // validated key check in case slightly out of range + constexpr double slightly_out_of_range_epsilon = 1e-6; + const std::vector slightly_out_of_range__query_keys{ + 0.0 - slightly_out_of_range_epsilon, 3.0 + slightly_out_of_range_epsilon}; + const auto validated_query_keys = validateKeys(base_keys, slightly_out_of_range__query_keys); + EXPECT_NEAR(validated_query_keys.at(0), 0.0, 1e-10); + EXPECT_NEAR(validated_query_keys.at(1), 3.0, 1e-10); + } } TEST(interpolation_utils, validateKeysAndValues)