Skip to content

Commit

Permalink
fix(interpolation): query key is out of range due to double calculati…
Browse files Browse the repository at this point in the history
…on error (#2204)

* fix issue sligtly out of range due to double's calculation error

Signed-off-by: Takayuki Murooka <takayuki5168@gmail.com>

* update test

Signed-off-by: Takayuki Murooka <takayuki5168@gmail.com>

Signed-off-by: Takayuki Murooka <takayuki5168@gmail.com>
  • Loading branch information
takayuki5168 authored Nov 4, 2022
1 parent 7d761a5 commit be5b0dc
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 15 deletions.
16 changes: 14 additions & 2 deletions common/interpolation/include/interpolation/interpolation_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef INTERPOLATION__INTERPOLATION_UTILS_HPP_
#define INTERPOLATION__INTERPOLATION_UTILS_HPP_

#include <algorithm>
#include <array>
#include <stdexcept>
#include <vector>
Expand Down Expand Up @@ -51,7 +52,7 @@ inline bool isNotDecreasing(const std::vector<double> & x)
return true;
}

inline void validateKeys(
inline std::vector<double> validateKeys(
const std::vector<double> & base_keys, const std::vector<double> & query_keys)
{
// when vectors are empty
Expand All @@ -71,9 +72,20 @@ inline void validateKeys(
}

// when query_keys is out of base_keys (This function does not allow exterior division.)
if (query_keys.front() < base_keys.front() || base_keys.back() < query_keys.back()) {
constexpr double epsilon = 1e-3;
if (
query_keys.front() < base_keys.front() - epsilon ||
base_keys.back() + epsilon < query_keys.back()) {
throw std::invalid_argument("query_keys is out of base_keys");
}

// NOTE: Due to calculation error of double, a query key may be slightly out of base keys.
// Therefore, query keys are cropped here.
auto validated_query_keys = query_keys;
validated_query_keys.front() = std::max(validated_query_keys.front(), base_keys.front());
validated_query_keys.back() = std::min(validated_query_keys.back(), base_keys.back());

return validated_query_keys;
}

template <class T>
Expand Down
10 changes: 5 additions & 5 deletions common/interpolation/include/interpolation/zero_order_hold.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ std::vector<T> zero_order_hold(
const std::vector<double> & query_keys, const double overlap_threshold = 1e-3)
{
// throw exception for invalid arguments
interpolation_utils::validateKeys(base_keys, query_keys);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys, query_keys);
interpolation_utils::validateKeysAndValues(base_keys, base_values);

std::vector<T> query_values;
size_t closest_segment_idx = 0;
for (size_t i = 0; i < query_keys.size(); ++i) {
for (size_t i = 0; i < validated_query_keys.size(); ++i) {
// Check if query_key is closes to the terminal point of the base keys
if (base_keys.back() - overlap_threshold < query_keys.at(i)) {
if (base_keys.back() - overlap_threshold < validated_query_keys.at(i)) {
closest_segment_idx = base_keys.size() - 1;
} else {
for (size_t j = closest_segment_idx; j < base_keys.size() - 1; ++j) {
if (
base_keys.at(j) - overlap_threshold < query_keys.at(i) &&
query_keys.at(i) < base_keys.at(j + 1)) {
base_keys.at(j) - overlap_threshold < validated_query_keys.at(i) &&
validated_query_keys.at(i) < base_keys.at(j + 1)) {
// find closest segment in base keys
closest_segment_idx = j;
}
Expand Down
4 changes: 2 additions & 2 deletions common/interpolation/src/linear_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ std::vector<double> lerp(
const std::vector<double> & query_keys)
{
// throw exception for invalid arguments
interpolation_utils::validateKeys(base_keys, query_keys);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys, query_keys);
interpolation_utils::validateKeysAndValues(base_keys, base_values);

// calculate linear interpolation
std::vector<double> query_values;
size_t key_index = 0;
for (const auto query_key : query_keys) {
for (const auto query_key : validated_query_keys) {
while (base_keys.at(key_index + 1) < query_key) {
++key_index;
}
Expand Down
4 changes: 2 additions & 2 deletions common/interpolation/src/spherical_linear_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ std::vector<geometry_msgs::msg::Quaternion> slerp(
const std::vector<double> & query_keys)
{
// throw exception for invalid arguments
interpolation_utils::validateKeys(base_keys, query_keys);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys, query_keys);
interpolation_utils::validateKeysAndValues(base_keys, base_values);

// calculate linear interpolation
std::vector<geometry_msgs::msg::Quaternion> query_values;
size_t key_index = 0;
for (const auto query_key : query_keys) {
for (const auto query_key : validated_query_keys) {
while (base_keys.at(key_index + 1) < query_key) {
++key_index;
}
Expand Down
8 changes: 4 additions & 4 deletions common/interpolation/src/spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
interpolation_utils::validateKeys(base_keys_, query_keys);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
Expand All @@ -231,7 +231,7 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedValues(

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : query_keys) {
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}
Expand All @@ -247,15 +247,15 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
interpolation_utils::validateKeys(base_keys_, query_keys);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : query_keys) {
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}
Expand Down
17 changes: 17 additions & 0 deletions common/interpolation/test/src/test_interpolation_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ TEST(interpolation_utils, validateKeys)

const std::vector<double> back_out_query_keys{0.0, 1.0, 2.0, 4.0};
EXPECT_THROW(validateKeys(base_keys, back_out_query_keys), std::invalid_argument);

{ // validated key check in normal case
const std::vector<double> normal_query_keys{0.5, 1.5, 3.0};
const auto validated_query_keys = validateKeys(base_keys, normal_query_keys);
for (size_t i = 0; i < normal_query_keys.size(); ++i) {
EXPECT_EQ(normal_query_keys.at(i), validated_query_keys.at(i));
}
}

{ // validated key check in case slightly out of range
constexpr double slightly_out_of_range_epsilon = 1e-6;
const std::vector<double> slightly_out_of_range__query_keys{
0.0 - slightly_out_of_range_epsilon, 3.0 + slightly_out_of_range_epsilon};
const auto validated_query_keys = validateKeys(base_keys, slightly_out_of_range__query_keys);
EXPECT_NEAR(validated_query_keys.at(0), 0.0, 1e-10);
EXPECT_NEAR(validated_query_keys.at(1), 3.0, 1e-10);
}
}

TEST(interpolation_utils, validateKeysAndValues)
Expand Down

0 comments on commit be5b0dc

Please sign in to comment.