diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a73a4b31..6f6a3c636 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.12.4) # Set the build version (specify a tweak version to indicated post-release if needed) -set(BUILD_VERSION 1.8.7) +set(BUILD_VERSION 1.8.8) # MSVC runtime library flags are defined by 'CMAKE_MSVC_RUNTIME_LIBRARY' if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.15.7) diff --git a/core/include/jiminy/core/engine/engine.h b/core/include/jiminy/core/engine/engine.h index a5730c1d4..c73dfb13f 100644 --- a/core/include/jiminy/core/engine/engine.h +++ b/core/include/jiminy/core/engine/engine.h @@ -292,10 +292,13 @@ namespace jiminy config["groundProfile"] = HeightmapFunction( [](const Eigen::Vector2d & /* xy */, double & height, - Eigen::Vector3d & normal) -> void + std::optional> normal) -> void { height = 0.0; - normal = Eigen::Vector3d::UnitZ(); + if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } }); return config; diff --git a/core/include/jiminy/core/fwd.h b/core/include/jiminy/core/fwd.h index 3cd45f272..acd54ff6d 100644 --- a/core/include/jiminy/core/fwd.h +++ b/core/include/jiminy/core/fwd.h @@ -1,20 +1,20 @@ #ifndef JIMINY_FORWARD_H #define JIMINY_FORWARD_H -#include // `std::string_view` #include // `int32_t`, `int64_t`, `uint32_t`, `uint64_t`, ... -#include // `std::function`, `std::invoke` -#include // `std::numeric_limits` #include // `std::map` +#include // `std::unordered_map` #include // `std::deque` +#include // `std::vector` +#include // `std::pair` +#include // `std::optional` #include // `std::string` #include // `std::ostringstream` -#include // `std::unordered_map` -#include // `std::pair` -#include // `std::vector` +#include // `std::function`, `std::invoke` +#include // `std::numeric_limits` #include // `std::addressof` #include // `std::forward` -#include // `std::runtime_error`, `std::logic_error` +#include // `std::logic_error` #include // `std::enable_if_t`, `std::decay_t`, `std::add_pointer_t`, `std::is_same_v`, ... #include "pinocchio/fwd.hpp" // To avoid having to include it everywhere @@ -189,9 +189,17 @@ namespace jiminy using std::logic_error::logic_error::what; }; - // Ground profile functors - using HeightmapFunction = std::function; + template::min(), + ResultType max_ = std::numeric_limits::max()> + class uniform_random_bit_generator_ref; + + // Ground profile functors. 
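+    // The `normal` output argument is optional: implementations may skip the normal
+    // computation altogether when the caller only needs the height and passes `std::nullopt`.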
+ // FIXME: use `std::move_only_function` instead of `std::function` when moving to C++23 + using HeightmapFunction = + std::function> /* normal */)>; // Flexible joints struct FlexibilityJointConfig diff --git a/core/include/jiminy/core/robot/model.h b/core/include/jiminy/core/robot/model.h index 8efd64c94..97cea4cfe 100644 --- a/core/include/jiminy/core/robot/model.h +++ b/core/include/jiminy/core/robot/model.h @@ -1,6 +1,8 @@ #ifndef JIMINY_MODEL_H #define JIMINY_MODEL_H +#include + #include "pinocchio/spatial/fwd.hpp" // `pinocchio::SE3` #include "pinocchio/multibody/model.hpp" // `pinocchio::Model` #include "pinocchio/multibody/data.hpp" // `pinocchio::Data` @@ -8,7 +10,6 @@ #include "pinocchio/multibody/frame.hpp" // `pinocchio::FrameType` (C-style enum cannot be forward declared) #include "jiminy/core/fwd.h" -#include "jiminy/core/utilities/random.h" // `uniform_random_bit_generator_ref` namespace jiminy diff --git a/core/include/jiminy/core/stepper/lie_group.h b/core/include/jiminy/core/stepper/lie_group.h index 4342f3d3f..f129a15b6 100644 --- a/core/include/jiminy/core/stepper/lie_group.h +++ b/core/include/jiminy/core/stepper/lie_group.h @@ -1238,7 +1238,7 @@ namespace Eigen #define StateDerivative_SHARED_ADDON \ template::ValueType>::value, \ void>> \ @@ -1268,7 +1268,7 @@ namespace Eigen template< \ typename Derived, \ typename OtherDerived, \ - typename = typename std::enable_if_t< \ + typename = std::enable_if_t< \ is_base_of_template_v::ValueType>::value && \ is_base_of_template_v::ValueType>::value, \ void>> \ @@ -1301,7 +1301,7 @@ namespace Eigen } \ \ template::ValueType>::value, \ void>> \ @@ -1320,7 +1320,7 @@ namespace Eigen template< \ typename Derived, \ typename OtherDerived, \ - typename = typename std::enable_if_t< \ + typename = std::enable_if_t< \ is_base_of_template_v::ValueType>::value && \ is_base_of_template_v // `std::array` #include // `std::unique_ptr` #include // `std::optional` #include // `std::pair`, `std::declval` @@ -85,16 +84,14 @@ namespace jiminy /// /// \sa For technical reference about type-erasure for random generators: /// https://stackoverflow.com/a/77809228/4820605 - template::min(), - ResultType max_ = std::numeric_limits::max()> + template class uniform_random_bit_generator_ref : private function_ref { public: using result_type = ResultType; template::min() == min_ && std::decay_t::max() == max_>::value> > constexpr uniform_random_bit_generator_ref(F && f) noexcept : @@ -220,9 +217,9 @@ namespace jiminy template std::enable_if_t< - (is_eigen_any_v || - is_eigen_any_v)&&(!std::is_arithmetic_v> || - !std::is_arithmetic_v>), + (is_eigen_any_v || is_eigen_any_v) && + (!std::is_arithmetic_v> || + !std::is_arithmetic_v>), Eigen::CwiseNullaryOp< scalar_random_op &, float, float), Generator &, @@ -271,9 +268,9 @@ namespace jiminy /// optimizations enabled (level 01 is enough), probably due to inlining. template std::enable_if_t< - (is_eigen_any_v || - is_eigen_any_v)&&(!std::is_arithmetic_v> || - !std::is_arithmetic_v>), + (is_eigen_any_v || is_eigen_any_v) && + (!std::is_arithmetic_v> || + !std::is_arithmetic_v>), Eigen::CwiseNullaryOp< scalar_random_op &, float, float), Generator &, @@ -314,31 +311,40 @@ namespace jiminy /// \param[in] coeffs First row of the matrix to decompose. 
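+        /// \param[in] reg Extra regularization added to the leading coefficient of the Schur
+        ///                generator (`g(0, 0) += reg` in the implementation), which presumably
+        ///                guards the decomposition against numerically semi-definite inputs.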
        template<typename Derived>
        MatrixX<typename Derived::Scalar>
-       standardToeplitzCholeskyLower(const Eigen::MatrixBase<Derived> & coeffs);
+       standardToeplitzCholeskyLower(const Eigen::MatrixBase<Derived> & coeffs, double reg = 0.0);
     }
 
-    class JIMINY_DLLAPI PeriodicGaussianProcess
+    class JIMINY_TEMPLATE_DLLAPI PeriodicTabularProcess
     {
     public:
-        JIMINY_DISABLE_COPY(PeriodicGaussianProcess)
+        explicit PeriodicTabularProcess(double wavelength, double period);
 
-    public:
-        explicit PeriodicGaussianProcess(double wavelength, double period) noexcept;
-
-        void reset(const uniform_random_bit_generator_ref<uint32_t> & g) noexcept;
+        virtual void reset(const uniform_random_bit_generator_ref<uint32_t> & g) = 0;
 
-        double operator()(float t);
+        double operator()(double t) const noexcept;
+        double grad(double t) const noexcept;
 
         double getWavelength() const noexcept;
         double getPeriod() const noexcept;
 
-    private:
+    protected:
         const double wavelength_;
         const double period_;
+        const Eigen::Index numTimes_{static_cast<Eigen::Index>(std::ceil(period_ / (0.1 * wavelength_)))};
+        const double dt_{period_ / static_cast<double>(numTimes_)};
 
-        const double dt_{0.02 * wavelength_};
-        const Eigen::Index numTimes_{static_cast<Eigen::Index>(std::ceil(period_ / dt_))};
+        Eigen::VectorXd values_{numTimes_};
+        Eigen::VectorXd grads_{numTimes_};
+    };
 
+    class JIMINY_TEMPLATE_DLLAPI PeriodicGaussianProcess final : public PeriodicTabularProcess
     {
     public:
+        explicit PeriodicGaussianProcess(double wavelength, double period);
+
+        void reset(const uniform_random_bit_generator_ref<uint32_t> & g) noexcept override;
+
+    private:
         /// \brief Cholesky decomposition (LLT) of the covariance matrix.
         ///
         /// \details All decompositions are equivalent as the covariance matrix is a symmetric,
         ///          positive semi-definite Toeplitz matrix, which means that the computational
         ///          complexity can be reduced even further using a specialized Cholesky
         ///          decomposition algorithm. See: https://math.stackexchange.com/q/22825/375496
-        Eigen::MatrixXd covSqrtRoot_{
-            internal::standardToeplitzCholeskyLower(Eigen::VectorXd::NullaryExpr(
+        ///          Ultimately, the algorithmic complexity can be reduced from O(n^3) to O(n^2),
+        ///          which is lower than the matrix multiplication itself.
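+        ///
+        ///          Sketch of the idea (following the implementation in `random.hxx`): the
+        ///          two-row Schur generator `g` fully encodes a Toeplitz matrix, and each
+        ///          column of the lower Cholesky factor is read off `g` after a cheap
+        ///          rotation-like update of its rows, i.e. O(n) work per column and O(n^2)
+        ///          overall instead of O(n^3) for a dense LLT decomposition.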
+ Eigen::MatrixXd covSqrtRoot_{internal::standardToeplitzCholeskyLower( + Eigen::VectorXd::NullaryExpr( numTimes_, [numTimes = static_cast(numTimes_), wavelength = wavelength_](double i) { return std::exp(-2.0 * std::pow(std::sin(M_PI / numTimes * i) / wavelength, 2)); - }))}; - Eigen::VectorXd values_{numTimes_}; + }), + 1e-9)}; + Eigen::MatrixXd covJacobian_{Eigen::MatrixXd::NullaryExpr( + numTimes_, + numTimes_, + [numTimes = static_cast(numTimes_), + wavelength = wavelength_, + period = period_](double i, double j) + { + return -2 * M_PI / period / std::pow(wavelength, 2) * + std::sin(2 * M_PI / numTimes * (i - j)) * + std::exp(-2.0 * + std::pow(std::sin(M_PI / numTimes * (i - j)) / wavelength, 2)); + })}; }; // **************************** Continuous 1D Fourier processes **************************** // /// \see Based on "Smooth random functions, random ODEs, and Gaussian processes": - /// https://hal.inria.fr/hal-01944992/file/random_revision2.pdf */ - class JIMINY_DLLAPI PeriodicFourierProcess + /// https://hal.inria.fr/hal-01944992/file/random_revision2.pdf + /// + /// \see For references about the derivatives of a Gaussian Process: + /// http://herbsusmann.com/2020/07/06/gaussian-process-derivatives + /// https://arxiv.org/abs/1810.12283 + class JIMINY_TEMPLATE_DLLAPI PeriodicFourierProcess final : public PeriodicTabularProcess { public: - JIMINY_DISABLE_COPY(PeriodicFourierProcess) + explicit PeriodicFourierProcess(double wavelength, double period); - public: - explicit PeriodicFourierProcess(double wavelength, double period) noexcept; - - void reset(const uniform_random_bit_generator_ref & g) noexcept; - - double operator()(float t); - - double getWavelength() const noexcept; - double getPeriod() const noexcept; + void reset(const uniform_random_bit_generator_ref & g) noexcept override; private: - const double wavelength_; - const double period_; - - const double dt_{0.02 * wavelength_}; - const Eigen::Index numTimes_{static_cast(std::ceil(period_ / dt_))}; const Eigen::Index numHarmonics_{ static_cast(std::ceil(period_ / wavelength_))}; @@ -389,79 +400,110 @@ namespace jiminy numTimes_, numHarmonics_, [numTimes = static_cast(numTimes_)](double i, double j) - { return std::cos(2 * M_PI / numTimes * i * j); })}; + { return std::cos(2 * M_PI / numTimes * i * (j + 1)); })}; const Eigen::MatrixXd sinMat_{Eigen::MatrixXd::NullaryExpr( numTimes_, numHarmonics_, [numTimes = static_cast(numTimes_)](double i, double j) - { return std::sin(2 * M_PI / numTimes * i * j); })}; - Eigen::VectorXd values_{numTimes_}; + { return std::sin(2 * M_PI / numTimes * i * (j + 1)); })}; }; - // ***************************** Continuous 1D Perlin processes **************************** // + // ****************************** Continuous Perlin processes ****************************** // + + /// \brief Non-cryptographic hash function initially designed for hash-based lookup. + /// + /// \sa Murmursh algorithms were proposed by Austin Appleby and placed in public domain. + /// The author hereby disclaims copyright to this source code: + /// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp + uint32_t MurmurHash3(const void * key, int32_t len, uint32_t seed) noexcept; + + /// \brief Non-cryptographic hash function initially designed for hash-based lookup. + /// + /// \sa xxHash algorithms were proposed by Yann Collet and placed in the public domain. 
+ /// The author hereby disclaims copyright to this source code: + /// https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h + uint32_t xxHash(const void * key, int32_t len, uint32_t seed) noexcept; - class JIMINY_DLLAPI AbstractPerlinNoiseOctave + template class DerivedPerlinNoiseOctave, unsigned int N> + class JIMINY_TEMPLATE_DLLAPI AbstractPerlinNoiseOctave { + public: + template + using VectorN = Eigen::Matrix; + public: explicit AbstractPerlinNoiseOctave(double wavelength); - virtual ~AbstractPerlinNoiseOctave() = default; - virtual void reset(const uniform_random_bit_generator_ref & g) noexcept; + void reset(const uniform_random_bit_generator_ref & g) noexcept; - double operator()(double t) const; + double operator()(const VectorN & x) const; + VectorN grad(const VectorN & x) const; double getWavelength() const noexcept; - protected: - virtual double grad(int32_t knot, double delta) const noexcept = 0; - - /// \brief Improved Smoothstep function by Ken Perlin (aka Smootherstep). - /// - /// \details It has zero 1st and 2nd-order derivatives at dt = 0.0, and 1.0. - /// - /// \sa For reference, see: - /// https://en.wikipedia.org/wiki/Smoothstep#Variations - static double fade(double delta) noexcept; - static double lerp(double ratio, double yLeft, double yRight) noexcept; + private: + template + std::conditional_t, double> + evaluate(const VectorN & x) const; protected: const double wavelength_; - double shift_{0.0}; + VectorN shift_{}; + + mutable VectorN cellIndex_ = + VectorN::Constant(std::numeric_limits::max()); + mutable std::array, (1U << N)> gradKnots_{}; }; - class JIMINY_DLLAPI RandomPerlinNoiseOctave : public AbstractPerlinNoiseOctave + template + class JIMINY_TEMPLATE_DLLAPI RandomPerlinNoiseOctave : + public AbstractPerlinNoiseOctave { + public: + template + using VectorN = Eigen::Matrix; + + friend class AbstractPerlinNoiseOctave; + public: explicit RandomPerlinNoiseOctave(double wavelength); - ~RandomPerlinNoiseOctave() override = default; - void reset(const uniform_random_bit_generator_ref & g) noexcept override; + void reset(const uniform_random_bit_generator_ref & g) noexcept; protected: - double grad(int32_t knot, double delta) const noexcept override; + VectorN gradKnot(const VectorN & knot) const noexcept; private: uint32_t seed_{0U}; }; - class JIMINY_DLLAPI PeriodicPerlinNoiseOctave : public AbstractPerlinNoiseOctave + template + class JIMINY_TEMPLATE_DLLAPI PeriodicPerlinNoiseOctave : + public AbstractPerlinNoiseOctave { + public: + template + using VectorN = Eigen::Matrix; + + friend class AbstractPerlinNoiseOctave; + public: explicit PeriodicPerlinNoiseOctave(double wavelength, double period); - ~PeriodicPerlinNoiseOctave() override = default; - void reset(const uniform_random_bit_generator_ref & g) noexcept override; + void reset(const uniform_random_bit_generator_ref & g) noexcept; double getPeriod() const noexcept; protected: - double grad(int32_t knot, double delta) const noexcept override; + VectorN gradKnot(const VectorN & knot) const noexcept; private: const double period_; - std::array perm_{}; + const int32_t size_{static_cast(period_ / this->wavelength_)}; + std::vector> grads_ = + std::vector>(static_cast(std::pow(size_, N))); }; /// \brief Sum of Perlin noise octaves. 
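For intuition, the octave assembly that follows reduces to a short fractal sum: each successive octave has its wavelength divided by the lacunarity and its weight multiplied by the persistence (constants `PERLIN_NOISE_LACUNARITY` and `PERLIN_NOISE_PERSISTENCE` defined later in `random.hxx`), and the result is rescaled to stay within [-1.0, 1.0]. A minimal self-contained sketch, with a hypothetical `noise1d` callback standing in for a single octave:

#include <cmath>
#include <cstddef>

// Fractal sum performed by `AbstractPerlinProcess`: accumulate weighted octaves,
// then normalize by the root of the summed squared weights so that the output
// remains within [-1.0, 1.0].
template<typename F>
double fractalNoise1D(F && noise1d,
                      double x,
                      double wavelength,
                      std::size_t numOctaves,
                      double lacunarity,
                      double persistence)
{
    double value = 0.0;
    double amplitudeSquared = 0.0;
    double scale = 1.0;
    for (std::size_t i = 0; i < numOctaves; ++i)
    {
        value += scale * noise1d(x, wavelength);
        amplitudeSquared += scale * scale;
        wavelength /= lacunarity;
        scale *= persistence;
    }
    return value / std::sqrt(amplitudeSquared);
}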
@@ -484,18 +526,27 @@ namespace jiminy /// https://github.com/bradykieffer/SimplexNoise/blob/master/simplexnoise/noise.py /// https://github.com/sol-prog/Perlin_Noise/blob/master/PerlinNoise.cpp /// https://github.com/ashima/webgl-noise/blob/master/src/classicnoise2D.glsl - class JIMINY_DLLAPI AbstractPerlinProcess + template class DerivedPerlinNoiseOctave, + unsigned int N, + typename = std::enable_if_t< + std::is_base_of_v, + DerivedPerlinNoiseOctave>>> + class AbstractPerlinProcess; + + template class DerivedPerlinNoiseOctave, unsigned int N> + class JIMINY_TEMPLATE_DLLAPI AbstractPerlinProcess { public: - JIMINY_DISABLE_COPY(AbstractPerlinProcess) + template + using VectorN = Eigen::Matrix; - using OctaveScalePair = - std::pair, const double>; + using OctaveScalePair = std::pair, const double>; public: void reset(const uniform_random_bit_generator_ref & g) noexcept; - double operator()(float t); + double operator()(const VectorN & x) const; + VectorN grad(const VectorN & x) const; double getWavelength() const noexcept; std::size_t getNumOctaves() const noexcept; @@ -510,13 +561,16 @@ namespace jiminy double amplitude_{0.0}; }; - class JIMINY_DLLAPI RandomPerlinProcess : public AbstractPerlinProcess + template + class JIMINY_TEMPLATE_DLLAPI RandomPerlinProcess : + public AbstractPerlinProcess { public: explicit RandomPerlinProcess(double wavelength, std::size_t numOctaves = 6U); }; - class PeriodicPerlinProcess : public AbstractPerlinProcess + template + class PeriodicPerlinProcess : public AbstractPerlinProcess { public: explicit PeriodicPerlinProcess( diff --git a/core/include/jiminy/core/utilities/random.hxx b/core/include/jiminy/core/utilities/random.hxx index 8813f32e7..a68e67d40 100644 --- a/core/include/jiminy/core/utilities/random.hxx +++ b/core/include/jiminy/core/utilities/random.hxx @@ -8,6 +8,9 @@ namespace jiminy { + static inline constexpr double PERLIN_NOISE_PERSISTENCE{1.50}; + static inline constexpr double PERLIN_NOISE_LACUNARITY{0.85}; + // ***************************** Uniform random bit generators ***************************** // namespace internal @@ -57,9 +60,9 @@ namespace jiminy template std::enable_if_t< - (is_eigen_any_v || - is_eigen_any_v)&&(!std::is_arithmetic_v> || - !std::is_arithmetic_v>), + (is_eigen_any_v || is_eigen_any_v) && + (!std::is_arithmetic_v> || + !std::is_arithmetic_v>), Eigen::CwiseNullaryOp< scalar_random_op &, float, float), Generator &, @@ -103,9 +106,9 @@ namespace jiminy template std::enable_if_t< - (is_eigen_any_v || - is_eigen_any_v)&&(!std::is_arithmetic_v> || - !std::is_arithmetic_v>), + (is_eigen_any_v || is_eigen_any_v) && + (!std::is_arithmetic_v> || + !std::is_arithmetic_v>), Eigen::CwiseNullaryOp< scalar_random_op &, float, float), Generator &, @@ -153,7 +156,7 @@ namespace jiminy { template MatrixX - standardToeplitzCholeskyLower(const Eigen::MatrixBase & coeffs) + standardToeplitzCholeskyLower(const Eigen::MatrixBase & coeffs, double reg) { using Scalar = typename Derived::Scalar; @@ -165,6 +168,7 @@ namespace jiminy It coincides with the Schur generator for Toepliz matrices. */ Eigen::Matrix g{2, n}; g.rowwise() = coeffs.transpose(); + g(0, 0) += reg; // Run progressive Schur algorithm, adapted to Toepliz matrices l.col(0) = g.row(0); @@ -184,6 +188,516 @@ namespace jiminy return l; } } + + // ****************************** Continuous Perlin processes ****************************** // + + /// \brief Improved Smoothstep function by Ken Perlin (aka Smootherstep). 
+ /// + /// \details It has zero 1st and 2nd-order derivatives at dt = 0.0, and 1.0. + /// + /// \sa For reference, see: + /// https://en.wikipedia.org/wiki/Smoothstep#Variations + static inline double fade(double delta) noexcept + { + return delta * delta * delta * (delta * (delta * 6.0 - 15.0) + 10.0); + } + + static inline double derivativeFade(double delta) noexcept + { + return 30.0 * delta * delta * (delta * (delta - 2.0) + 1.0); + } + + template + static std::common_type_t, std::decay_t> + lerp(double ratio, T1 && yLeft, T2 && yRight) noexcept + { + return yLeft + ratio * (yRight - yLeft); + } + + template + static std::common_type_t, std::decay_t> + derivativeLerp(double dratio, T1 && yLeft, T2 && yRight) noexcept + { + return dratio * (yRight - yLeft); + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + AbstractPerlinNoiseOctave::AbstractPerlinNoiseOctave( + double wavelength) : + wavelength_{wavelength} + { + if (wavelength_ <= 0.0) + { + JIMINY_THROW(std::invalid_argument, "'wavelength' must be strictly larger than 0.0."); + } + reset(std::random_device{}); + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + void AbstractPerlinNoiseOctave::reset( + const uniform_random_bit_generator_ref & g) noexcept + { + // Sample random cell shift + shift_ = uniform(N, 1, g).cast(); + + // Clear cache index + cellIndex_.setConstant(std::numeric_limits::max()); + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + double AbstractPerlinNoiseOctave::getWavelength() const noexcept + { + return wavelength_; + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + template + std::conditional_t< + isGradient, + typename AbstractPerlinNoiseOctave::template VectorN, + double> + AbstractPerlinNoiseOctave::evaluate( + const VectorN & x) const + { + // Get current cell + const VectorN cell = x / wavelength_ + shift_; + + // Compute the bottom left corner knot + const VectorN cellIndexLeft = cell.array().floor().template cast(); + const VectorN cellIndexRight = cellIndexLeft.array() + 1; + + // Compute smoothed ratio of query point wrt to the bottom left corner knot + const VectorN deltaLeft = cell - cellIndexLeft.template cast(); + const VectorN deltaRight = deltaLeft.array() - 1.0; + + // Compute gradients at knots (on a meshgrid), then corresponding offsets at query point + bool isCacheValid = (cellIndexLeft.array() == cellIndex_.array()).all(); + std::array offsets; + if (isCacheValid) + { + VectorN delta; + for (uint32_t k = 0; k < (1U << N); k++) + { + // Mapping from index to knot + for (uint32_t i = 0; i < N; i++) + { + if (k & (1U << i)) + { + delta[i] = deltaRight[i]; + } + else + { + delta[i] = deltaLeft[i]; + } + } + + // Compute the offset at query point + offsets[k] = gradKnots_[k].dot(delta); + } + } + else + { + VectorN knot; + VectorN delta; + const auto & derived = static_cast &>(*this); + for (uint32_t k = 0; k < (1U << N); k++) + { + // Mapping from index to knot + for (uint32_t i = 0; i < N; i++) + { + if (k & (1U << i)) + { + knot[i] = cellIndexRight[i]; + delta[i] = deltaRight[i]; + } + else + { + knot[i] = cellIndexLeft[i]; + delta[i] = deltaLeft[i]; + } + } + + // Evaluate the gradient at knot + gradKnots_[k] = derived.gradKnot(knot); + + // Compute the offset at query point + offsets[k] = gradKnots_[k].dot(delta); + } + } + + // Update cache index + cellIndex_ = cellIndexLeft; + + // Compute the derivative along each axis + const VectorN ratio = deltaLeft.array().unaryExpr(std::ref(fade)); + if constexpr (isGradient) + { + 
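+            /* The gradient mixes two contributions (product rule): the multi-linearly
+               interpolated knot gradients themselves, plus the interpolated offsets
+               weighted by the derivative of the smoothing ratio along the axis being
+               differentiated. Both terms are accumulated below, then rescaled by
+               1 / wavelength_ (chain rule for the mapping x -> x / wavelength_ + shift_). */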
const VectorN dratio = deltaLeft.array().unaryExpr(std::ref(derivativeFade)); + std::array, (1U << N)> _interpGrads = gradKnots_; + for (int32_t i = N - 1; i >= 0; --i) + { + for (uint32_t k = 0; k < (1U << i); k++) + { + VectorN & gradLeft = _interpGrads[k]; + const VectorN gradRight = _interpGrads[k | (1U << i)]; + gradLeft = lerp(ratio[i], gradLeft, gradRight); + } + } + for (int32_t j = 0; j < static_cast(N); ++j) + { + std::array _interpOffsets = offsets; + for (int32_t i = N - 1; i >= 0; --i) + { + for (uint32_t k = 0; k < (1U << i); k++) + { + double & offsetLeft = _interpOffsets[k]; + const double offsetRight = _interpOffsets[k | (1U << i)]; + if (i == j) + { + offsetLeft = derivativeLerp(dratio[i], offsetLeft, offsetRight); + } + else + { + offsetLeft = lerp(ratio[i], offsetLeft, offsetRight); + } + } + } + _interpGrads[0][j] += _interpOffsets[0]; + } + return _interpGrads[0] / wavelength_; + } + else + { + // Perform linear interpolation on each dimension recursively until to get a scalar + for (int32_t i = N - 1; i >= 0; --i) + { + for (uint32_t k = 0; k < (1U << i); k++) + { + double & offsetLeft = offsets[k]; + const double offsetRight = offsets[k | (1U << i)]; + offsetLeft = lerp(ratio[i], offsetLeft, offsetRight); + } + } + return offsets[0]; + } + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + double AbstractPerlinNoiseOctave::operator()( + const VectorN & x) const + { + return evaluate(x); + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + typename AbstractPerlinNoiseOctave::template VectorN + AbstractPerlinNoiseOctave::grad(const VectorN & x) const + { + return evaluate(x); + } + + template + RandomPerlinNoiseOctave::RandomPerlinNoiseOctave(double wavelength) : + AbstractPerlinNoiseOctave(wavelength) + { + reset(std::random_device{}); + } + + template + void RandomPerlinNoiseOctave::reset( + const uniform_random_bit_generator_ref & g) noexcept + { + // Call base implementation + AbstractPerlinNoiseOctave::reset(g); + + // Sample new random seed + seed_ = g(); + } + + template + typename RandomPerlinNoiseOctave::template VectorN + RandomPerlinNoiseOctave::gradKnot(const VectorN & knot) const noexcept + { + constexpr float fHashMax = static_cast(std::numeric_limits::max()); + + // Compute knot hash + uint32_t hash = xxHash(knot.data(), static_cast(sizeof(int32_t) * N), seed_); + + /* Generate random gradient uniformly distributed on n-ball. + For technical reference, see: + https://extremelearning.com.au/how-to-generate-uniformly-random-points-on-n-spheres-and-n-balls/ + */ + if constexpr (N == 1) + { + // Sample random scalar in [0.0, 1.0) + const float s = static_cast(hash) / fHashMax; + + // Compute rescaled gradient between [-1.0, 1.0) + return VectorN{2.0 * s - 1.0}; + } + else if constexpr (N == 2) + { + // Sample random vector on a 2-ball (disk) using + // const double theta = 2 * M_PI * static_cast(hash) / fHashMax; + // hash = xxHash(&hash, sizeof(uint32_t), seed_); + // const float radius = std::sqrt(static_cast(hash) / fHashMax); + // return VectorN{radius * std::cos(theta), radius * std::sin(theta)}; + + /* The rejection method is much fast in 2d because it does not involve complex math + (sqrt, sincos) and the acceptance rate is high (~78%) compared to the cost of + sampling random numbers using `xxHash`. 
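               More precisely, the acceptance probability is the ratio between the area of
               the unit disk and that of the enclosing square, i.e. pi / 4 ~ 78.5%, so about
               1.27 draws are needed on average.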
*/ + while (true) + { + const float x = 2 * static_cast(hash) / fHashMax - 1.0F; + hash = xxHash(&hash, sizeof(uint32_t), seed_); + const float y = 2 * static_cast(hash) / fHashMax - 1.0F; + if (x * x + y * y <= 1.0F) + { + return VectorN{x, y}; + } + } + } + else + { + // Generate a uniformly distributed random vector on n-sphere + VectorN dir; + for (uint32_t i = 0; i < N; i += 2) + { + // Generate 2 uniformly distributed random variables + const float u1 = static_cast(hash) / fHashMax; + hash = xxHash(&hash, sizeof(uint32_t), seed_); + const float u2 = static_cast(hash) / fHashMax; + hash = xxHash(&hash, sizeof(uint32_t), seed_); + + // Apply Box-Mueller algorithm to deduce 2 normally distributed random variables + const double theta = 2 * M_PI * u2; + const float radius = std::sqrt(-2 * std::log(u1)); + dir[i] = radius * std::cos(theta); + if (i + 1 < N) + { + dir[i + 1] = radius * std::sin(theta); + } + } + dir.normalize(); + + // Sample radius + const double radius = std::pow(static_cast(hash) / fHashMax, 1.0 / N); + + // Return the resulting random vector on n-ball using Muller method + return radius * dir; + } + } + + template + PeriodicPerlinNoiseOctave::PeriodicPerlinNoiseOctave(double wavelength, double period) : + AbstractPerlinNoiseOctave( + period / std::max(std::round(period / wavelength), 1.0)), + period_{period} + { + // Make sure the period is larger than the wavelength + if (period < wavelength) + { + JIMINY_THROW(std::invalid_argument, "'period' must be larger than 'wavelength'."); + } + + // Initialize the pre-computed hash table + reset(std::random_device{}); + } + + template + void PeriodicPerlinNoiseOctave::reset( + const uniform_random_bit_generator_ref & g) noexcept + { + // Call base implementation + AbstractPerlinNoiseOctave::reset(g); + + // Re-initialize the pre-computed hash table + for (auto & grad : grads_) + { + if constexpr (N == 1) + { + grad = VectorN{uniform(g, -1.0F, 1.0F)}; + } + else if constexpr (N == 2) + { + const double theta = 2 * M_PI * uniform(g); + const float radius = std::sqrt(uniform(g)); + grad = VectorN{radius * std::cos(theta), radius * std::sin(theta)}; + } + else + { + const VectorN dir = normal(N, 1, g).cast().normalized(); + const double radius = std::pow(uniform(g), 1.0 / N); + grad = radius * dir; + } + } + } + + template + typename PeriodicPerlinNoiseOctave::template VectorN + PeriodicPerlinNoiseOctave::gradKnot(const VectorN & knot) const noexcept + { + // Wrap knot is period interval + int32_t index = 0; + int32_t shift = 1; + for (uint_fast8_t i = 0; i < N; ++i) + { + int32_t coord = knot[i] % size_; + if (coord < 0) + { + coord += size_; + } + index += coord * shift; + shift *= size_; + } + + // Return the gradient + return grads_[index]; + } + + template + double PeriodicPerlinNoiseOctave::getPeriod() const noexcept + { + return period_; + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + AbstractPerlinProcess::AbstractPerlinProcess( + std::vector && octaveScalePairs) noexcept : + octaveScalePairs_(std::move(octaveScalePairs)) + { + // Compute the scaling factor to keep values within range [-1.0, 1.0] + double amplitudeSquared = 0.0; + for (const OctaveScalePair & octaveScale : octaveScalePairs_) + { + // FIXME: replaced `std::get` by placeholder `_` when moving to C++26 (P2169R4) + amplitudeSquared += std::pow(std::get<1>(octaveScale), 2); + } + amplitude_ = std::sqrt(amplitudeSquared); + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + void AbstractPerlinProcess::reset( + const 
uniform_random_bit_generator_ref & g) noexcept + { + // Reset octaves + for (OctaveScalePair & octaveScale : octaveScalePairs_) + { + // FIXME: replaced `std::get` by placeholder `_` when moving to C++26 (P2169R4) + std::get<0>(octaveScale).reset(g); + } + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + double + AbstractPerlinProcess::operator()(const VectorN & x) const + { + // Compute sum of octaves' values + double value = 0.0; + for (const auto & [octave, scale] : octaveScalePairs_) + { + value += scale * octave(x); + } + + // Scale sum by maximum amplitude + return value / amplitude_; + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + typename AbstractPerlinProcess::template VectorN + AbstractPerlinProcess::grad(const VectorN & x) const + { + // Compute sum of octaves' values + VectorN value = VectorN::Zero(); + for (const auto & [octave, scale] : octaveScalePairs_) + { + value += scale * octave.grad(x); + } + + // Scale sum by maximum amplitude + return value / amplitude_; + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + double AbstractPerlinProcess::getWavelength() const noexcept + { + double wavelength = INF; + for (const OctaveScalePair & octaveScale : octaveScalePairs_) + { + // FIXME: replaced `std::get` by placeholder `_` when moving to C++26 (P2169R4) + wavelength = std::min(wavelength, std::get<0>(octaveScale).getWavelength()); + } + return wavelength; + } + + template class DerivedPerlinNoiseOctave, unsigned int N> + std::size_t AbstractPerlinProcess::getNumOctaves() const noexcept + { + return octaveScalePairs_.size(); + } + + template class DerivedPerlinNoiseOctave, + unsigned int N, + typename... ExtraArgs> + static std::vector, const double>> + buildPerlinNoiseOctaves(double wavelength, std::size_t numOctaves, ExtraArgs &&... 
args) + { + // Make sure that at least one octave has been requested + if (numOctaves < 1) + { + JIMINY_THROW(std::invalid_argument, "'numOctaves' must at least 1."); + } + + // Make sure that wavelength of all the octaves is consistent with period if application + if constexpr (std::is_base_of_v, DerivedPerlinNoiseOctave>) + { + const double period = std::get<0>(std::tuple{std::forward(args)...}); + const double wavelengthFinal = + wavelength / std::pow(PERLIN_NOISE_LACUNARITY, numOctaves - 1); + if (period < std::max(wavelength, wavelengthFinal)) + { + JIMINY_THROW(std::invalid_argument, + "'period' must be larger than the wavelength of all the octaves (", + std::max(wavelength, wavelengthFinal), + "), ie 'wavelength' / ", + PERLIN_NOISE_LACUNARITY, + "^i for i in [1, ..., 'numOctaves']."); + } + } + + std::vector, const double>> octaveScalePairs; + octaveScalePairs.reserve(numOctaves); + double scale = 1.0; + for (std::size_t i = 0; i < numOctaves; ++i) + { + octaveScalePairs.emplace_back(DerivedPerlinNoiseOctave(wavelength, args...), scale); + wavelength /= PERLIN_NOISE_LACUNARITY; + scale *= PERLIN_NOISE_PERSISTENCE; + } + return octaveScalePairs; + } + + template + RandomPerlinProcess::RandomPerlinProcess(double wavelength, std::size_t numOctaves) : + AbstractPerlinProcess( + buildPerlinNoiseOctaves(wavelength, numOctaves)) + { + } + + template + PeriodicPerlinProcess::PeriodicPerlinProcess( + double wavelength, double period, std::size_t numOctaves) : + AbstractPerlinProcess( + buildPerlinNoiseOctaves(wavelength, numOctaves, period)), + period_{period} + { + } + + template + double PeriodicPerlinProcess::getPeriod() const noexcept + { + return period_; + } } #endif // JIMINY_RANDOM_HXX diff --git a/core/src/io/serialization.cc b/core/src/io/serialization.cc index 1e8db48e4..57fdb6491 100644 --- a/core/src/io/serialization.cc +++ b/core/src/io/serialization.cc @@ -248,10 +248,15 @@ namespace boost::serialization template void load(Archive & /* ar */, HeightmapFunction & fun, const unsigned int /* version */) { - fun = [](const Eigen::Vector2d & /* xy */, double & height, Eigen::Vector3d & normal) + fun = [](const Eigen::Vector2d & /* xy */, + double & height, + std::optional> normal) { height = 0.0; - normal = Eigen::Vector3d::UnitZ(); + if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } }; } @@ -1006,6 +1011,7 @@ namespace boost::serialization // Restore extended simulation model model.pinocchioModel_ = pinocchioModel; + model.pinocchioData_ = pinocchio::Data(pinocchioModel); } template diff --git a/core/src/solver/constraint_solvers.cc b/core/src/solver/constraint_solvers.cc index 90eaf56cb..d56cde3c1 100644 --- a/core/src/solver/constraint_solvers.cc +++ b/core/src/solver/constraint_solvers.cc @@ -2,7 +2,6 @@ #include "pinocchio/multibody/data.hpp" // `pinocchio::Data` #include "pinocchio/algorithm/cholesky.hpp" // `pinocchio::cholesky::` -#include "jiminy/core/utilities/random.h" #include "jiminy/core/utilities/helpers.h" #include "jiminy/core/constraints/abstract_constraint.h" #include "jiminy/core/robot/pinocchio_overload_algorithms.h" @@ -265,11 +264,11 @@ namespace jiminy they can grow arbitrary large for constraints whose bounds are active. It follows that stagnation of residuals is the only viable criteria. The PGS algorithm has been modified for solving second-order cone LCP, which means - that only the L2-norm of the tangential forces can be expected to converge. 
Because
+           that only the L^2-norm of the tangential forces can be expected to converge. Because
            of this, it is too restrictive to check the element-wise variation of the residuals
            over iterations. It makes more sense to look at the Linf-norm instead, but this
            criterion is very lax. A good compromise may be to look at the constraint-block-wise
-           L2-norm, which is similar to what Drake simulator is doing. For reference, see:
+           L^2-norm, which is similar to what the Drake simulator is doing. For reference, see:
            https://github.com/RobotLocomotion/drake/blob/master/multibody/contact_solvers/pgs_solver.cc */
        const double tol = tolAbs_ + tolRel_ * y_.lpNorm<Eigen::Infinity>() + EPS;
diff --git a/core/src/utilities/geometry.cc b/core/src/utilities/geometry.cc
index 8b11eca07..4832a2147 100644
--- a/core/src/utilities/geometry.cc
+++ b/core/src/utilities/geometry.cc
@@ -1,7 +1,9 @@
 #include "hpp/fcl/BVH/BVH_model.h"  // `hpp::fcl::CollisionGeometry`, `hpp::fcl::BVHModel`, `hpp::fcl::OBBRSS`
 #include "hpp/fcl/shape/geometric_shapes.h"  // `hpp::fcl::Halfspace`
 #include "hpp/fcl/hfield.h"  // `hpp::fcl::HeightField`
+
 #include "jiminy/core/utilities/geometry.h"
+#include "jiminy/core/utilities/random.h"
 
 
 namespace jiminy
@@ -621,8 +623,7 @@ namespace jiminy
         for (Eigen::Index i = 0; i < vertices.rows(); ++i)
         {
             auto vertex = vertices.row(i);
-            Eigen::Vector3d normal;
-            heightmap(vertex.head<2>(), vertex[2], normal);
+            heightmap(vertex.head<2>(), vertex[2], std::nullopt);
         }
 
         // Check if the heightmap is flat
@@ -692,78 +693,126 @@ namespace jiminy
 
     HeightmapFunction sumHeightmaps(const std::vector<HeightmapFunction> & heightmaps)
     {
+        // Make sure that at least one heightmap has been specified
+        if (heightmaps.empty())
+        {
+            JIMINY_THROW(bad_control_flow, "At least one heightmap must be specified.");
+        }
+
+        // Early return if a single heightmap has been specified. Nothing to do.
         if (heightmaps.size() == 1)
         {
             return heightmaps[0];
         }
-        return [heightmaps](
-                   const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void
-        {
-            thread_local static double height_i;
-            thread_local static Eigen::Vector3d normal_i;
-
+        return [heightmaps](const Eigen::Vector2d & pos,
+                            double & height,
+                            std::optional<std::reference_wrapper<Eigen::Vector3d>> normal) -> void
+        {
             height = 0.0;
-            normal.setZero();
-            for (const HeightmapFunction & heightmap : heightmaps)
+            if (normal.has_value())
             {
-                heightmap(pos, height_i, normal_i);
-                height += height_i;
-                normal += normal_i;
+                normal->setZero();
+                for (const HeightmapFunction & heightmap : heightmaps)
+                {
+                    double height_i;
+                    Eigen::Vector3d normal_i;
+                    heightmap(pos, height_i, normal_i);
+                    height += height_i;
+                    normal.value() += normal_i;
+                }
+                normal->normalize();
+            }
+            else
+            {
+                for (const HeightmapFunction & heightmap : heightmaps)
+                {
+                    double height_i;
+                    heightmap(pos, height_i, std::nullopt);
+                    height += height_i;
+                }
             }
-            normal.normalize();
         };
     }
 
     HeightmapFunction mergeHeightmaps(const std::vector<HeightmapFunction> & heightmaps)
     {
+        // Make sure that at least one heightmap has been specified
+        if (heightmaps.empty())
+        {
+            JIMINY_THROW(bad_control_flow, "At least one heightmap must be specified.");
+        }
+
+        // Early return if a single heightmap has been specified. Nothing to do.
if (heightmaps.size() == 1) { return heightmaps[0]; } - return [heightmaps]( - const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void - { - thread_local static double height_i; - thread_local static Eigen::Vector3d normal_i; + return [heightmaps](const Eigen::Vector2d & pos, + double & height, + std::optional> normal) -> void + { height = -INF; - bool is_dirty = false; - for (const HeightmapFunction & heightmap : heightmaps) + if (normal.has_value()) { - heightmap(pos, height_i, normal_i); - if (std::abs(height_i - height) < EPS) + bool is_dirty = false; + for (const HeightmapFunction & heightmap : heightmaps) { - normal += normal_i; - is_dirty = true; + double height_i; + Eigen::Vector3d normal_i; + heightmap(pos, height_i, normal_i); + if (std::abs(height_i - height) < EPS) + { + normal.value() += normal_i; + is_dirty = true; + } + else if (height_i > height) + { + height = height_i; + normal.value() = normal_i; + is_dirty = false; + } } - else if (height_i > height) + if (is_dirty) { - height = height_i; - normal = normal_i; - is_dirty = false; + normal->normalize(); } } - if (is_dirty) + else { - normal.normalize(); + for (const HeightmapFunction & heightmap : heightmaps) + { + double height_i; + heightmap(pos, height_i, std::nullopt); + if (height_i > height) + { + height = height_i; + } + } } }; } - HeightmapFunction stairs( + HeightmapFunction periodicStairs( double stepWidth, double stepHeight, uint32_t stepNumber, double orientation) { const double interpDelta = 0.01; - const Eigen::Rotation2D rot_mat(orientation); - return [stepWidth, stepHeight, stepNumber, rot_mat, interpDelta]( - const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void + // Define the projection axis + const Eigen::Vector2d axis{std::cos(orientation), std::sin(orientation)}; + + return [stepWidth, stepHeight, stepNumber, axis, interpDelta]( + const Eigen::Vector2d & pos, + double & height, + std::optional> normal) -> void { // Compute position in stairs reference frame - Eigen::Vector2d posRel = rot_mat.inverse() * pos; - const double modPos = std::fmod(std::abs(posRel[0]), stepWidth * stepNumber * 2); + // Eigen::Vector2d posRel = rotMat.inverse() * pos; + const double posRel = axis.dot(pos); + const double modPos = std::fmod(std::abs(posRel), stepWidth * stepNumber * 2); - // Compute the default height and normal + // Compute the default height uint32_t stairIndex = static_cast(modPos / stepWidth); int8_t staircaseSlopeSign = 1; if (stairIndex >= stepNumber) @@ -772,27 +821,126 @@ namespace jiminy staircaseSlopeSign = -1; } height = stairIndex * stepHeight; - normal = Eigen::Vector3d::UnitZ(); // Avoid unsupported vertical edge - const double posRelOnStep = std::fmod(modPos, stepWidth) / stepWidth; + const double posRelOnStep = + std::fmod(modPos + std::numeric_limits::epsilon(), stepWidth) / stepWidth; if (1.0 - posRelOnStep < interpDelta) { + // Compute the slope of the vertical edge of the stair const double slope = staircaseSlopeSign * stepHeight / interpDelta; + // Update height height += slope * (posRelOnStep - (1.0 - interpDelta)); + if (normal.has_value()) + { + // Compute the inverse of the normal's Euclidean norm + const double normInv = 1.0 / std::sqrt(1.0 + std::pow(slope, 2)); + + // Update normal vector + // step 1. compute normal in stairs reference frame: + // normal << -slope * normInv, 0.0, normInv; + // step 2. 
Rotate normal vector in world plane reference frame: + // normal.head<2>() = rotMat * normal.head<2>(); + // Or simply in a single operation: + normal.value() << -slope * normInv * axis, normInv; + } + } + else if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } + }; + } + + template class AnyPerlinProcess> + static HeightmapFunction generateHeightmapFromPerlinProcess1D(AnyPerlinProcess<1> && fun, + double orientation) + { + // Define the projection axis + const Eigen::Vector2d axis{std::cos(orientation), std::sin(orientation)}; + + return + [fun = std::move(fun), axis](const Eigen::Vector2d & pos, + double & height, + std::optional> normal) mutable + { + // Compute the position along axis + const Vector1 posAxis = Vector1{axis.dot(pos)}; + + // Compute the height + height = fun(posAxis); + + if (normal.has_value()) + { + // Compute the gradient of the Perlin Proces + const double slope = fun.grad(posAxis)[0]; + // Compute the inverse of the normal's Euclidean norm const double normInv = 1.0 / std::sqrt(1.0 + std::pow(slope, 2)); // Update normal vector - // step 1. compute normal in stairs reference frame: - // normal << -slope * normInv, 0.0, normInv; - // step 2. Rotate normal vector in world plane reference frame: - // normal.head<2>() = rot_mat * normal.head<2>(); - // Or simply in a single operation: - normal << -slope * normInv * rot_mat.toRotationMatrix().col(0), normInv; + normal.value() << -slope * normInv * axis, normInv; } }; } + + template class AnyPerlinProcess> + static HeightmapFunction generateHeightmapFromPerlinProcess2D(AnyPerlinProcess<2> && fun) + { + return [fun = std::move(fun)](const Eigen::Vector2d & pos, + double & height, + std::optional> normal) mutable + { + // Compute the height + height = fun(pos); + + if (normal.has_value()) + { + // Compute the gradient of the Perlin Proces + const auto grad = fun.grad(pos); + + // Compute the inverse of the normal's Euclidean norm + const double normInv = 1.0 / std::sqrt(1.0 + grad.squaredNorm()); + + // Update normal vector + normal.value() << -normInv * grad.template head<2>(), normInv; + } + }; + } + + HeightmapFunction unidirectionalRandomPerlinGround( + double wavelength, std::size_t numOctaves, double orientation, uint32_t seed) + { + auto fun = RandomPerlinProcess<1>(wavelength, numOctaves); + fun.reset(PCG32(seed)); + return generateHeightmapFromPerlinProcess1D(std::move(fun), orientation); + } + + HeightmapFunction randomPerlinGround(double wavelength, std::size_t numOctaves, uint32_t seed) + { + auto fun = RandomPerlinProcess<2>(wavelength, numOctaves); + fun.reset(PCG32(seed)); + return generateHeightmapFromPerlinProcess2D(std::move(fun)); + } + + HeightmapFunction periodicPerlinGround( + double wavelength, double period, std::size_t numOctaves, uint32_t seed) + { + auto fun = PeriodicPerlinProcess<2>(wavelength, period, numOctaves); + fun.reset(PCG32(seed)); + return generateHeightmapFromPerlinProcess2D(std::move(fun)); + } + + HeightmapFunction unidirectionalPeriodicPerlinGround(double wavelength, + double period, + std::size_t numOctaves, + double orientation, + uint32_t seed) + { + auto fun = PeriodicPerlinProcess<1>(wavelength, period, numOctaves); + fun.reset(PCG32(seed)); + return generateHeightmapFromPerlinProcess1D(std::move(fun), orientation); + } } \ No newline at end of file diff --git a/core/src/utilities/json.cc b/core/src/utilities/json.cc index 9e092f5fb..023befe06 100644 --- a/core/src/utilities/json.cc +++ b/core/src/utilities/json.cc @@ -145,12 +145,16 @@ namespace 
jiminy template<> HeightmapFunction convertFromJson(const Json::Value & /* value */) { - return { - [](const Eigen::Vector2d & /* xy */, double & height, Eigen::Vector3d & normal) -> void - { - height = 0.0; - normal = Eigen::Vector3d::UnitZ(); - }}; + return {[](const Eigen::Vector2d & /* xy */, + double & height, + std::optional> normal) -> void + { + height = 0.0; + if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } + }}; } template<> diff --git a/core/src/utilities/random.cc b/core/src/utilities/random.cc index 5d5b84345..73cfc6d8b 100644 --- a/core/src/utilities/random.cc +++ b/core/src/utilities/random.cc @@ -5,9 +5,6 @@ namespace jiminy { - static inline constexpr double PERLIN_NOISE_PERSISTENCE{1.50}; - static inline constexpr double PERLIN_NOISE_LACUNARITY{1.15}; - // ***************************** Uniform random bit generators ***************************** // PCG32::PCG32(uint64_t state) noexcept : @@ -171,18 +168,94 @@ namespace jiminy // **************************** Non-cryptographic hash function **************************** // - static uint32_t rotl32(uint32_t x, int8_t r) noexcept +#ifdef __has_builtin +# define HAS_BUILTIN(x) __has_builtin(x) +#else +# define HAS_BUILTIN(x) 0 +#endif + +#if !defined(NO_CLANG_BUILTIN) && HAS_BUILTIN(__builtin_rotateleft32) +# define rotl32 __builtin_rotateleft32 +/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */ +#elif defined(_MSC_VER) +# define rotl32(x, r) _rotl(x, r) +#else +# define rotl32(x, r) (((x) << (r)) | ((x) >> (32 - (r)))) +#endif + + constexpr uint32_t PRIME32_1 = 0x9E3779B1U; /* 0b10011110001101110111100110110001 */ + constexpr uint32_t PRIME32_2 = 0x85EBCA77U; /* 0b10000101111010111100101001110111 */ + constexpr uint32_t PRIME32_3 = 0xC2B2AE3DU; /* 0b11000010101100101010111000111101 */ + constexpr uint32_t PRIME32_4 = 0x27D4EB2FU; /* 0b00100111110101001110101100101111 */ + constexpr uint32_t PRIME32_5 = 0x165667B1U; /* 0b00010110010101100110011110110001 */ + + static uint32_t XXH32_round(uint32_t acc, const uint32_t input) + { + acc += input * PRIME32_2; + acc = rotl32(acc, 13); + acc *= PRIME32_1; + return acc; + } + + uint32_t xxHash(const void * input, int32_t len, uint32_t seed) noexcept { - return (x << r) | (x >> (32 - r)); + uint32_t hash; + + const auto * data = reinterpret_cast(input); + if (len >= 16) + { + uint32_t v1 = seed + PRIME32_1 + PRIME32_2; + uint32_t v2 = seed + PRIME32_2; + uint32_t v3 = seed + 0; + uint32_t v4 = seed - PRIME32_1; + + const uint8_t * const bEnd = data + len; + const uint8_t * const limit = bEnd - 15; + do + { + v1 = XXH32_round(v1, *reinterpret_cast(data)); + data += 4; + v2 = XXH32_round(v2, *reinterpret_cast(data)); + data += 4; + v3 = XXH32_round(v3, *reinterpret_cast(data)); + data += 4; + v4 = XXH32_round(v4, *reinterpret_cast(data)); + data += 4; + } while (data < limit); + len &= 15; + + hash = rotl32(v1, 1) + rotl32(v2, 7) + rotl32(v3, 12) + rotl32(v4, 18); + } + else + { + hash = seed + PRIME32_5; + } + hash += static_cast(len); + + while (len >= 4) + { + hash += *reinterpret_cast(data) * PRIME32_3; + data += 4; + hash = rotl32(hash, 17) * PRIME32_4; + len -= 4; + } + while (len > 0) + { + hash += *data * PRIME32_5; + data += 1; + hash = rotl32(hash, 11) * PRIME32_1; + --len; + } + + hash ^= hash >> 15; + hash *= PRIME32_2; + hash ^= hash >> 13; + hash *= PRIME32_3; + hash ^= hash >> 16; + return hash; } - /// \brief MurmurHash3 is a non-cryptographic hash function initially designed - /// for hash-based lookup. 
- /// - /// \sa It was written by Austin Appleby, and is placed in the public domain. - /// The author hereby disclaims copyright to this source code: - /// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp - static uint32_t MurmurHash3(const void * key, int32_t len, uint32_t seed) noexcept + uint32_t MurmurHash3(const void * key, int32_t len, uint32_t seed) noexcept { // Define some internal constants constexpr uint32_t c1 = 0xcc9e2d51; @@ -241,349 +314,174 @@ namespace jiminy h1 ^= h1 >> 13; h1 *= 0xc2b2ae35; h1 ^= h1 >> 16; - return h1; } // **************************** Continuous 1D Gaussian processes *************************** // - PeriodicGaussianProcess::PeriodicGaussianProcess(double wavelength, double period) noexcept : - wavelength_{wavelength}, - period_{period} + static std::tuple getClosestKnots(double value, + double delta) { - reset(std::random_device{}); - } + // Compute closest left and right indices + const double quot = value / delta; + const Eigen::Index indexLeft = static_cast(std::floor(quot)); + Eigen::Index indexRight = indexLeft + 1; - void PeriodicGaussianProcess::reset( - const uniform_random_bit_generator_ref & g) noexcept - { - // Sample normal vector - auto normalVec = normal(numTimes_, 1, g); + // Compute the time ratio + const double ratio = quot - static_cast(indexLeft); - // Compute discrete periodic gaussian process values - values_.noalias() = covSqrtRoot_.triangularView() * normalVec.cast(); + return {indexLeft, indexRight, ratio}; } - double PeriodicGaussianProcess::operator()(float t) + static std::tuple getClosestKnots( + double value, double delta, Eigen::Index numTimes) { - // Wrap requested time in gaussian process period - double tWrap = std::fmod(t, period_); - if (tWrap < 0) + // Wrap value in period interval + const double period = static_cast(numTimes) * delta; + value = std::fmod(value, period); + if (value < 0.0) { - tWrap += period_; + value += period; } - // Compute closest left and right indices - const Eigen::Index tLeftIndex = static_cast(std::floor(tWrap / dt_)); - const Eigen::Index tRightIndex = (tLeftIndex + 1) % numTimes_; - - // Perform First order interpolation - const double ratio = tWrap / dt_ - static_cast(tLeftIndex); - return values_[tLeftIndex] + ratio * (values_[tRightIndex] - values_[tLeftIndex]); + // Compute closest left and right indices, wrapping around if needed + auto [indexLeft, indexRight, ratio] = getClosestKnots(value, delta); + if (indexRight == numTimes) + { + indexRight = 0; + } + return {indexLeft, indexRight, ratio}; } - double PeriodicGaussianProcess::getWavelength() const noexcept + template + static std::decay_t cubicInterp( + double ratio, double delta, T && valueLeft, T && valueRight, T && gradLeft, T && gradRight) { - return wavelength_; + const auto dy = valueRight - valueLeft; + const auto a = gradLeft * delta - dy; + const auto b = -gradRight * delta + dy; + return valueLeft + ratio * ((1.0 - ratio) * ((1.0 - ratio) * a + ratio * b) + dy); } - double PeriodicGaussianProcess::getPeriod() const noexcept + template + static std::decay_t derivativeCubicInterp( + double ratio, double delta, T && valueLeft, T && valueRight, T && gradLeft, T && gradRight) { - return period_; + const auto dy = valueRight - valueLeft; + const auto a = gradLeft * delta - dy; + const auto b = -gradRight * delta + dy; + return ((1.0 - ratio) * (1.0 - 3.0 * ratio) * a + ratio * (2.0 - 3.0 * ratio) * b + dy) / + delta; } - // **************************** Continuous 1D Fourier processes 
**************************** // - - PeriodicFourierProcess::PeriodicFourierProcess(double wavelength, double period) noexcept : + PeriodicTabularProcess::PeriodicTabularProcess(double wavelength, double period) : wavelength_{wavelength}, period_{period} { - reset(std::random_device{}); - } - - void PeriodicFourierProcess::reset( - const uniform_random_bit_generator_ref & g) noexcept - { - // Sample normal vectors - auto normalVec1 = normal(numHarmonics_, 1, g); - auto normalVec2 = normal(numHarmonics_, 1, g); - - // Compute discrete periodic gaussian process values - const double scale = M_SQRT2 / std::sqrt(2 * numHarmonics_ + 1); - values_ = scale * cosMat_ * normalVec1.cast(); - values_.noalias() += scale * sinMat_ * normalVec2.cast(); - } - - double PeriodicFourierProcess::operator()(float t) - { - // Wrap requested time in guassian process period - double tWrap = std::fmod(t, period_); - if (tWrap < 0) + // Make sure the period is positive + if (period < 0.0) { - tWrap += period_; + JIMINY_THROW(std::invalid_argument, "'period' must be positive."); } - - // Compute closest left and right indices - const Eigen::Index tLeftIndex = static_cast(std::floor(tWrap / dt_)); - const Eigen::Index tRightIndex = (tLeftIndex + 1) % numTimes_; - - // Perform First order interpolation - const double ratio = tWrap / dt_ - static_cast(tLeftIndex); - return values_[tLeftIndex] + ratio * (values_[tRightIndex] - values_[tLeftIndex]); } - double PeriodicFourierProcess::getWavelength() const noexcept + double PeriodicTabularProcess::operator()(double t) const noexcept { - return wavelength_; - } + // Compute closest left index within time period + const auto [indexLeft, indexRight, ratio] = getClosestKnots(t, dt_, numTimes_); - double PeriodicFourierProcess::getPeriod() const noexcept - { - return period_; + /* Perform cubic spline interpolation to ensure continuity of the derivative: + https://en.wikipedia.org/wiki/Spline_interpolation#Algorithm_to_find_the_interpolating_cubic_spline + */ + return cubicInterp(ratio, + dt_, + values_[indexLeft], + values_[indexRight], + grads_[indexLeft], + grads_[indexRight]); } - // ***************************** Continuous 1D Perlin processes **************************** // - - AbstractPerlinNoiseOctave::AbstractPerlinNoiseOctave(double wavelength) : - wavelength_{wavelength} - { - if (wavelength_ <= 0.0) - { - JIMINY_THROW(std::invalid_argument, "'wavelength' must be strictly larger than 0.0."); - } - shift_ = uniform(std::random_device{}); - } - - void AbstractPerlinNoiseOctave::reset( - const uniform_random_bit_generator_ref & g) noexcept - { - // Sample random phase shift - shift_ = uniform(g); - } - - double AbstractPerlinNoiseOctave::operator()(double t) const + double PeriodicTabularProcess::grad(double t) const noexcept { - // Get current phase - const double phase = t / wavelength_ + shift_; - - // Compute closest right and left knots - const int32_t phaseIndexLeft = static_cast(phase); - const int32_t phaseIndexRight = phaseIndexLeft + 1; - - // Compute smoothed ratio of current phase wrt to the closest knots - const double dtLeft = phase - phaseIndexLeft; - const double dtRight = dtLeft - 1.0; - const double ratio = fade(dtLeft); - - /* Compute gradients at knots, and perform linear interpolation between them to get value - at current phase.*/ - const double yLeft = grad(phaseIndexLeft, dtLeft); - const double yRight = grad(phaseIndexRight, dtRight); - return lerp(ratio, yLeft, yRight); + const auto [indexLeft, indexRight, ratio] = getClosestKnots(t, dt_, 
numTimes_); + return derivativeCubicInterp(ratio, + dt_, + values_[indexLeft], + values_[indexRight], + grads_[indexLeft], + grads_[indexRight]); } - double AbstractPerlinNoiseOctave::getWavelength() const noexcept + double PeriodicTabularProcess::getWavelength() const noexcept { return wavelength_; } - double AbstractPerlinNoiseOctave::fade(double delta) noexcept + double PeriodicTabularProcess::getPeriod() const noexcept { - /* Improved Smoothstep function by Ken Perlin (aka Smootherstep). - It has zero 1st and 2nd-order derivatives at dt = 0.0, and 1.0: - https://en.wikipedia.org/wiki/Smoothstep#Variations */ - return std::pow(delta, 3) * (delta * (delta * 6.0 - 15.0) + 10.0); - } - - double AbstractPerlinNoiseOctave::lerp(double ratio, double yLeft, double yRight) noexcept - { - return yLeft + ratio * (yRight - yLeft); + return period_; } - RandomPerlinNoiseOctave::RandomPerlinNoiseOctave(double wavelength) : - AbstractPerlinNoiseOctave(wavelength) + PeriodicGaussianProcess::PeriodicGaussianProcess(double wavelength, double period) : + PeriodicTabularProcess(wavelength, period) { - seed_ = std::random_device{}(); + reset(std::random_device{}); } - void RandomPerlinNoiseOctave::reset( + void PeriodicGaussianProcess::reset( const uniform_random_bit_generator_ref & g) noexcept { - // Call base implementation - AbstractPerlinNoiseOctave::reset(g); - - // Sample new random seed for MurmurHash - seed_ = g(); - } - - double RandomPerlinNoiseOctave::grad(int32_t knot, double delta) const noexcept - { - // Get hash of knot - const uint32_t hash = MurmurHash3(&knot, sizeof(int32_t), seed_); - - // Convert to double in [0.0, 1.0) - const double s = - static_cast(hash) / static_cast(std::numeric_limits::max()); - - // Compute rescaled gradient between [-1.0, 1.0) - const double grad = 2.0 * s - 1.0; - - // Return scalar product between distance and gradient - return 2.0 * grad * delta; - } - - namespace internal - { - template - void randomizePermutationVector(Generator && g, T & vec) - { - // Re-Initialize the permutation vector with values from 0 to size - std::iota(vec.begin(), vec.end(), 0); + // Sample normal vector + const Eigen::VectorXd normalVec = normal(numTimes_, 1, g).cast(); - // Shuffle the permutation vector - std::shuffle(vec.begin(), vec.end(), g); - } - } + /* Compute discrete periodic gaussian process values. - PeriodicPerlinNoiseOctave::PeriodicPerlinNoiseOctave(double wavelength, double period) : - AbstractPerlinNoiseOctave(wavelength), - period_{period} - { - // Make sure the wavelength is multiple of the period - if (std::abs(std::round(period / wavelength) * wavelength - period) > - std::numeric_limits::epsilon()) - { - JIMINY_THROW(std::invalid_argument, "'wavelength' must be multiple of 'period'."); - } - - // Initialize the permutation vector with values from 0 to 255 and shuffle it - internal::randomizePermutationVector(std::random_device{}, perm_); - } + A gaussian process can be derived from a normally distributed random vector. + More precisely, a Gaussian Process y is uniquely defined by its kernel K and + a normally distributed random vector z ~ N(0, I). Let us consider a timestamp t. + The value of the Gaussian process y at time t is given by: + y(t) = K(t*, t) @ (L^-T @ z), + where: + t* are evenly spaced sampling timestamps associated with z + Cov = K(t*, t*) = L @ L^T is the Cholesky decomposition of the covariance matrix. 
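+           Note that `L^-T @ z` is obtained with a single triangular solve, so sampling
+           values and gradients jointly costs O(n^2) per reset.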
- void PeriodicPerlinNoiseOctave::reset( - const uniform_random_bit_generator_ref & g) noexcept - { - // Call base implementation - AbstractPerlinNoiseOctave::reset(g); + Its analytical derivative can be deduced easily: + dy/dt(t) = dK/dt(t*, t) @ (L^-T @ z). - // Re-Initialize the permutation vector with values from 0 to 255 - internal::randomizePermutationVector(g, perm_); + When the query timestamps corresponds to the sampling timestamps, it yields: + y^* = K(t*, t*) @ (L^-T @ z) = L @ z + dy/dt^* = dK/dt(t*, t*) @ (L^-T @ z). */ + values_.noalias() = covSqrtRoot_.triangularView() * normalVec; + grads_.noalias() = + covJacobian_ * + covSqrtRoot_.transpose().triangularView().solve(normalVec); } - double PeriodicPerlinNoiseOctave::grad(int32_t knot, double delta) const noexcept - { - // Wrap knot is period interval - knot %= static_cast(period_ / wavelength_); - - // Convert to double in [0.0, 1.0) - const double s = perm_[knot] / 256.0; - - // Compute rescaled gradient between [-1.0, 1.0) - const double grad = 2.0 * s - 1.0; - - // Return scalar product between distance and gradient - return 2.0 * grad * delta; - } + // **************************** Continuous 1D Fourier processes **************************** // - AbstractPerlinProcess::AbstractPerlinProcess( - std::vector && octaveScalePairs) noexcept : - octaveScalePairs_(std::move(octaveScalePairs)) + PeriodicFourierProcess::PeriodicFourierProcess(double wavelength, double period) : + PeriodicTabularProcess(wavelength, period) { - // Compute the scaling factor to keep values within range [-1.0, 1.0] - double amplitudeSquared = 0.0; - for (const OctaveScalePair & octaveScale : octaveScalePairs_) - { - // FIXME: replaced `std::get` by placeholder `_` when moving to C++26 (P2169R4) - amplitudeSquared += std::pow(std::get<1>(octaveScale), 2); - } - amplitude_ = std::sqrt(amplitudeSquared); + reset(std::random_device{}); } - void AbstractPerlinProcess::reset( + void PeriodicFourierProcess::reset( const uniform_random_bit_generator_ref & g) noexcept { - // Reset octaves - for (OctaveScalePair & octaveScale : octaveScalePairs_) - { - // FIXME: replaced `std::get` by placeholder `_` when moving to C++26 (P2169R4) - std::get<0>(octaveScale)->reset(g); - } - } - - double AbstractPerlinProcess::operator()(float t) - { - // Compute sum of octaves' values - double value = 0.0; - for (const auto & [octave, scale] : octaveScalePairs_) - { - value += scale * (*octave)(t); - } - - // Scale sum by maximum amplitude - return value / amplitude_; - } - - double AbstractPerlinProcess::getWavelength() const noexcept - { - double wavelength = INF; - for (const OctaveScalePair & octaveScale : octaveScalePairs_) - { - // FIXME: replaced `std::get` by placeholder `_` when moving to C++26 (P2169R4) - wavelength = std::min(wavelength, std::get<0>(octaveScale)->getWavelength()); - } - return wavelength; - } - - std::size_t AbstractPerlinProcess::getNumOctaves() const noexcept - { - return octaveScalePairs_.size(); - } - - std::vector buildPerlinNoiseOctaves( - double wavelength, - std::size_t numOctaves, - std::function(double)> factory) - { - std::vector octaveScalePairs; - octaveScalePairs.reserve(numOctaves); - double scale = 1.0; - for (std::size_t i = 0; i < numOctaves; ++i) - { - octaveScalePairs.emplace_back(factory(wavelength), scale); - wavelength *= PERLIN_NOISE_LACUNARITY; - scale *= PERLIN_NOISE_PERSISTENCE; - } - return octaveScalePairs; - } - - RandomPerlinProcess::RandomPerlinProcess(double wavelength, std::size_t numOctaves) : - 
AbstractPerlinProcess(buildPerlinNoiseOctaves( - wavelength, - numOctaves, - [](double wavelengthIn) -> std::unique_ptr - { return std::make_unique(wavelengthIn); })) - { - } + // Sample normal vectors + const Eigen::VectorXd normalVec1 = normal(numHarmonics_, 1, g).cast(); + const Eigen::VectorXd normalVec2 = normal(numHarmonics_, 1, g).cast(); - PeriodicPerlinProcess::PeriodicPerlinProcess( - double wavelength, double period, std::size_t numOctaves) : - AbstractPerlinProcess(buildPerlinNoiseOctaves( - wavelength, - numOctaves, - [period](double wavelengthIn) -> std::unique_ptr - { return std::make_unique(wavelengthIn, period); })), - period_{period} - { - // Make sure the period is larger than the wavelength - if (period_ < wavelength) - { - JIMINY_THROW(std::invalid_argument, "'period' must be larger than 'wavelength'."); - } - } + // Compute discrete periodic Fourier process values and derivatives + const double scale = M_SQRT2 / std::sqrt(2 * numHarmonics_ + 1); + values_ = scale * sinMat_ * normalVec1; + values_.noalias() += scale * cosMat_ * normalVec2; - double PeriodicPerlinProcess::getPeriod() const noexcept - { - return period_; + const auto diff = + 2 * M_PI / period_ * + Eigen::VectorXd::LinSpaced(numHarmonics_, 1, static_cast(numHarmonics_)); + grads_ = scale * cosMat_ * normalVec1.cwiseProduct(diff); + grads_.noalias() -= scale * sinMat_ * normalVec2.cwiseProduct(diff); } // ******************************* Random terrain generators ******************************* // @@ -593,7 +491,7 @@ namespace jiminy const MatrixX & state, int64_t sparsity, uint32_t seed) noexcept { const auto numBytes = static_cast(sizeof(Scalar) * state.size()); - const uint32_t hash = MurmurHash3(state.data(), numBytes, seed); + const uint32_t hash = xxHash(state.data(), numBytes, seed); if (hash % sparsity == 0) { return static_cast(hash) / @@ -658,7 +556,7 @@ namespace jiminy double orientation, uint32_t seed) { - if ((0.01 < interpDelta.array()).any() || (interpDelta.array() > size.array() / 2.0).any()) + if ((0.01 > interpDelta.array()).any() || (interpDelta.array() > size.array() / 2.0).any()) { JIMINY_WARNING( "All components of 'interpDelta' must be in range [0.01, 'size'/2.0].
Value: ", @@ -675,13 +573,15 @@ namespace jiminy uniformSparseFromState(Vector1::Constant(i), 1, seed); }); - const Eigen::Rotation2D rot_mat(orientation); + const Eigen::Rotation2D rotMat(orientation); - return [size, heightMax, interpDelta, rot_mat, sparsity, interpThr, offset, seed]( - const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void + return [size, heightMax, interpDelta, rotMat, sparsity, interpThr, offset, seed]( + const Eigen::Vector2d & pos, + double & height, + std::optional> normal) -> void { // Compute the tile index and relative coordinate - Eigen::Vector2d posRel = (rot_mat * (pos + offset)).array() / size.array(); + Eigen::Vector2d posRel = (rotMat * (pos + offset)).array() / size.array(); Vector2 posIndices = posRel.array().floor().cast(); posRel -= posIndices.cast(); @@ -694,7 +594,10 @@ namespace jiminy std::tie(height, dheight_x) = tile2dInterp1d( posIndices, posRel, 0, size, sparsity, heightMax, interpThr, seed); const double norm_inv = 1.0 / std::sqrt(dheight_x * dheight_x + 1.0); - normal << -dheight_x * norm_inv, 0.0, norm_inv; + if (normal.has_value()) + { + normal.value() << -dheight_x * norm_inv, 0.0, norm_inv; + } } else if (!is_edge[0] && is_edge[1]) { @@ -702,7 +605,10 @@ namespace jiminy std::tie(height, dheight_y) = tile2dInterp1d( posIndices, posRel, 1, size, sparsity, heightMax, interpThr, seed); const double norm_inv = 1.0 / std::sqrt(dheight_y * dheight_y + 1.0); - normal << 0.0, -dheight_y * norm_inv, norm_inv; + if (normal.has_value()) + { + normal.value() << 0.0, -dheight_y * norm_inv, norm_inv; + } } else if (is_edge[0] && is_edge[1]) { @@ -719,8 +625,11 @@ namespace jiminy const double dheight_x = dheight_x_0 + (dheight_x_m - dheight_x_0) * ratio; const double dheight_y = (height_0 - height_m) / (2.0 * size[1] * interpThr[1]); - normal << -dheight_x, -dheight_y, 1.0; - normal.normalize(); + if (normal.has_value()) + { + normal.value() << -dheight_x, -dheight_y, 1.0; + normal->normalize(); + } } else { @@ -733,14 +642,20 @@ namespace jiminy const double dheight_x = dheight_x_0 + (dheight_x_p - dheight_x_0) * ratio; const double dheight_y = (height_p - height_0) / (2.0 * size[1] * interpThr[1]); - normal << -dheight_x, -dheight_y, 1.0; - normal.normalize(); + if (normal.has_value()) + { + normal.value() << -dheight_x, -dheight_y, 1.0; + normal->normalize(); + } } } else { height = heightMax * uniformSparseFromState(posIndices, sparsity, seed); - normal = Eigen::Vector3d::UnitZ(); + if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } } }; } diff --git a/core/unit/CMakeLists.txt b/core/unit/CMakeLists.txt index d68c0c0c0..9f33e99f3 100755 --- a/core/unit/CMakeLists.txt +++ b/core/unit/CMakeLists.txt @@ -11,6 +11,7 @@ find_package(Threads) set(UNIT_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/engine_sanity_check.cc" "${CMAKE_CURRENT_SOURCE_DIR}/model_test.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/random_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/miscellaneous.cc" ) diff --git a/core/unit/miscellaneous.cc b/core/unit/miscellaneous.cc index 770d153d7..6648517aa 100644 --- a/core/unit/miscellaneous.cc +++ b/core/unit/miscellaneous.cc @@ -1,4 +1,3 @@ -#include #include #include "jiminy/core/fwd.h" @@ -71,66 +70,3 @@ TEST(Miscellaneous, swapMatrixRows) matrixOut); } } - - -TEST(Miscellaneous, MatrixRandom) -{ - using generator_t = - std::independent_bits_engine::digits, uint32_t>; - - generator_t gen32{0}; - uniform_random_bit_generator_ref gen32_ref = gen32; - - float mean = 5.0; - float stddev = 2.0; - - auto mean_vec = 
Eigen::MatrixXf::Constant(1, 2, mean); - auto stddev_vec = Eigen::MatrixXf::Constant(1, 2, stddev); - float value1 = normal(gen32, mean, stddev); - float value2 = normal(gen32, mean, stddev); - - { - gen32.seed(0); - scalar_random_op op{ - [](auto & g, float _mean, float _stddev) -> float - { return normal(g, _mean, _stddev); }, - gen32, - mean_vec, - stddev_vec}; - ASSERT_FLOAT_EQ(op(0, 0), value1); - ASSERT_FLOAT_EQ(op(0, 1), value2); - ASSERT_THAT(op(0, 0), testing::Not(testing::FloatEq(value1))); - } - { - gen32.seed(0); - scalar_random_op &, float, float), - uniform_random_bit_generator_ref, - float, - float> - op{normal, gen32, mean, stddev}; - ASSERT_FLOAT_EQ(op(0, 0), value1); - ASSERT_FLOAT_EQ(op(10, 10), value2); - ASSERT_THAT(op(0, 0), testing::Not(testing::FloatEq(value1))); - } - { - gen32.seed(0); - auto mat_expr = normal(gen32_ref, mean_vec, stddev_vec); - ASSERT_FLOAT_EQ(mat_expr(0, 0), value1); - ASSERT_FLOAT_EQ(mat_expr(0, 1), value2); - ASSERT_THAT(mat_expr(0, 0), testing::Not(testing::FloatEq(value1))); - } - { - gen32.seed(0); - auto mat_expr = normal(1, 2, gen32, mean, stddev); - ASSERT_FLOAT_EQ(mat_expr(0, 0), value1); - ASSERT_FLOAT_EQ(mat_expr(0, 1), value2); - ASSERT_THAT(mat_expr(0, 0), testing::Not(testing::FloatEq(value1))); - } - { - gen32.seed(0); - auto mat_expr = normal(gen32_ref, mean, stddev_vec.transpose()); - ASSERT_FLOAT_EQ(mat_expr(0, 0), value1); - ASSERT_FLOAT_EQ(mat_expr(1, 0), value2); - ASSERT_THAT(mat_expr(0, 0), testing::Not(testing::FloatEq(value1))); - } -} diff --git a/core/unit/random_test.cc b/core/unit/random_test.cc new file mode 100644 index 000000000..ae9c7a82e --- /dev/null +++ b/core/unit/random_test.cc @@ -0,0 +1,166 @@ +#include // `testing::*` +#include + +#include "jiminy/core/utilities/random.h" + + +using namespace jiminy; + +static inline constexpr double DELTA{1e-6}; +static inline constexpr double TOL{1e-4}; + + +TEST(Miscellaneous, MatrixRandom) +{ + using generator_t = + std::independent_bits_engine::digits, uint32_t>; + + generator_t gen32{0}; + uniform_random_bit_generator_ref gen32_ref = gen32; + + float mean = 5.0; + float stddev = 2.0; + + auto mean_vec = Eigen::MatrixXf::Constant(1, 2, mean); + auto stddev_vec = Eigen::MatrixXf::Constant(1, 2, stddev); + float value1 = normal(gen32, mean, stddev); + float value2 = normal(gen32, mean, stddev); + + { + gen32.seed(0); + scalar_random_op op{ + [](auto & g, float _mean, float _stddev) -> float + { return normal(g, _mean, _stddev); }, + gen32, + mean_vec, + stddev_vec}; + ASSERT_FLOAT_EQ(op(0, 0), value1); + ASSERT_FLOAT_EQ(op(0, 1), value2); + ASSERT_THAT(op(0, 0), testing::Not(testing::FloatEq(value1))); + } + { + gen32.seed(0); + scalar_random_op &, float, float), + uniform_random_bit_generator_ref, + float, + float> + op{normal, gen32, mean, stddev}; + ASSERT_FLOAT_EQ(op(0, 0), value1); + ASSERT_FLOAT_EQ(op(10, 10), value2); + ASSERT_THAT(op(0, 0), testing::Not(testing::FloatEq(value1))); + } + { + gen32.seed(0); + auto mat_expr = normal(gen32_ref, mean_vec, stddev_vec); + ASSERT_FLOAT_EQ(mat_expr(0, 0), value1); + ASSERT_FLOAT_EQ(mat_expr(0, 1), value2); + ASSERT_THAT(mat_expr(0, 0), testing::Not(testing::FloatEq(value1))); + } + { + gen32.seed(0); + auto mat_expr = normal(1, 2, gen32, mean, stddev); + ASSERT_FLOAT_EQ(mat_expr(0, 0), value1); + ASSERT_FLOAT_EQ(mat_expr(0, 1), value2); + ASSERT_THAT(mat_expr(0, 0), testing::Not(testing::FloatEq(value1))); + } + { + gen32.seed(0); + auto mat_expr = normal(gen32_ref, mean, stddev_vec.transpose()); + 
ASSERT_FLOAT_EQ(mat_expr(0, 0), value1); + ASSERT_FLOAT_EQ(mat_expr(1, 0), value2); + ASSERT_THAT(mat_expr(0, 0), testing::Not(testing::FloatEq(value1))); + } +} + + +TEST(PerlinNoiseTest, RandomPerlinNoiseOctaveInitialization) +{ + double wavelength = 10.0; + RandomPerlinNoiseOctave<1> octave(wavelength); + octave.reset(PCG32{std::seed_seq{0}}); + + EXPECT_DOUBLE_EQ(octave.getWavelength(), wavelength); +} + +TEST(PerlinNoiseTest, PeriodicPerlinNoiseOctaveInitialization) +{ + double wavelength = 10.0; + double period = 20.0; + PeriodicPerlinNoiseOctave<1> octave(wavelength, period); + octave.reset(PCG32{std::seed_seq{0}}); + + EXPECT_DOUBLE_EQ(octave.getWavelength(), wavelength); + EXPECT_DOUBLE_EQ(octave.getPeriod(), period); +} + +TEST(PerlinNoiseTest, RandomGradientCalculation1D) +{ + { + double wavelength = 10.0; + RandomPerlinNoiseOctave<1> octave(wavelength); + octave.reset(PCG32{std::seed_seq{0}}); + + Eigen::Array t{5.43}; + Eigen::Matrix finiteDiffGrad{(octave(t + DELTA) - octave(t - DELTA)) / + (2 * DELTA)}; + ASSERT_TRUE(finiteDiffGrad.isApprox(octave.grad(t), TOL)); + } + { + double wavelength = 3.41; + RandomPerlinNoiseOctave<1> octave(wavelength); + octave.reset(PCG32{std::seed_seq{0}}); + + Eigen::Array t{17.0}; + Eigen::Matrix finiteDiffGrad{(octave(t + DELTA) - octave(t - DELTA)) / + (2 * DELTA)}; + ASSERT_TRUE(finiteDiffGrad.isApprox(octave.grad(t), TOL)); + } +} + +TEST(PerlinNoiseTest, PeriodicGradientCalculation1D) +{ + double wavelength = 10.0; + double period = 30.0; + PeriodicPerlinNoiseOctave<1> octave(wavelength, period); + octave.reset(PCG32{std::seed_seq{0}}); + + Eigen::Array t{5.43}; + Eigen::Matrix finiteDiffGrad{(octave(t + DELTA) - octave(t - DELTA)) / + (2 * DELTA)}; + ASSERT_TRUE(finiteDiffGrad.isApprox(octave.grad(t), TOL)); +} + + +TEST(PerlinNoiseTest, RandomGradientCalculation2D) +{ + double wavelength = 10.0; + RandomPerlinNoiseOctave<2> octave(wavelength); + octave.reset(PCG32{std::seed_seq{0}}); + + Eigen::Vector2d pos{5.43, 7.12}; + Eigen::Vector2d finiteDiffGrad{(octave(pos + DELTA * Eigen::Vector2d::UnitX()) - + octave(pos - DELTA * Eigen::Vector2d::UnitX())) / + (2 * DELTA), + (octave(pos + DELTA * Eigen::Vector2d::UnitY()) - + octave(pos - DELTA * Eigen::Vector2d::UnitY())) / + (2 * DELTA)}; + ASSERT_TRUE(finiteDiffGrad.isApprox(octave.grad(pos), TOL)); +} + + +TEST(PerlinNoiseTest, PeriodicGradientCalculation2D) +{ + double wavelength = 10.0; + double period = 30.0; + PeriodicPerlinNoiseOctave<2> octave(wavelength, period); + octave.reset(PCG32{std::seed_seq{0}}); + + Eigen::Vector2d pos{5.43, 7.12}; + Eigen::Vector2d finiteDiffGrad{(octave(pos + DELTA * Eigen::Vector2d::UnitX()) - + octave(pos - DELTA * Eigen::Vector2d::UnitX())) / + (2 * DELTA), + (octave(pos + DELTA * Eigen::Vector2d::UnitY()) - + octave(pos - DELTA * Eigen::Vector2d::UnitY())) / + (2 * DELTA)}; + ASSERT_TRUE(finiteDiffGrad.isApprox(octave.grad(pos), TOL)); +} diff --git a/docs/api/gym_jiminy/common/index.rst b/docs/api/gym_jiminy/common/index.rst index a71e72016..1ab5a1e2f 100644 --- a/docs/api/gym_jiminy/common/index.rst +++ b/docs/api/gym_jiminy/common/index.rst @@ -8,6 +8,6 @@ Gym Jiminy API blocks/index envs/index quantities/index - rewards/index + compositions/index wrappers/index utils/index diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py index 2925c1d14..b62a1981a 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py +++ 
b/python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py @@ -19,8 +19,11 @@ StateQuantity, DatasetTrajectoryQuantity) from .compositions import (AbstractReward, - BaseQuantityReward, - BaseMixtureReward) + QuantityReward, + MixtureReward, + AbstractTerminationCondition, + QuantityTermination, + EpisodeState) from .blocks import (BlockStateT, InterfaceBlock, BaseObserverBlock, @@ -53,8 +56,10 @@ 'InterfaceQuantity', 'AbstractQuantity', 'AbstractReward', - 'BaseQuantityReward', - 'BaseMixtureReward', + 'AbstractTerminationCondition', + 'QuantityReward', + 'MixtureReward', + 'QuantityTermination', 'BaseObserverBlock', 'BaseControllerBlock', 'BasePipelineWrapper', @@ -65,6 +70,7 @@ 'ControlledJiminyEnv', 'QuantityEvalMode', 'QuantityCreator', + 'EpisodeState', 'StateQuantity', 'DatasetTrajectoryQuantity' ] diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py b/python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py index 8e5c09d09..14af723ea 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py @@ -1,10 +1,14 @@ -"""This module promotes reward components as first-class objects. +"""This module promotes reward components and termination conditions as +first-class objects. These building blocks can be plugged onto an existing +pipeline by composition to keep everything modular, from the task definition to +the low-level observers and controllers. -Defining rewards this way allows for standardization of usual metrics. Overall, -it greatly reduces code duplication and bugs. +This modular approach allows for standardization of usual metrics. Overall, it +greatly reduces code duplication and bugs. """ from abc import ABC, abstractmethod -from typing import Sequence, Callable, Optional, Tuple, TypeVar +from enum import IntEnum +from typing import Tuple, Sequence, Callable, Union, Optional, Generic, TypeVar import numpy as np @@ -14,6 +18,10 @@ ValueT = TypeVar('ValueT') +Number = Union[float, int, bool, complex] +ArrayOrScalar = Union[np.ndarray, np.number, Number] +ArrayLikeOrScalar = Union[ArrayOrScalar, Sequence[Union[Number, np.number]]] + class AbstractReward(ABC): """Abstract class from which all reward components must derive. @@ -24,7 +32,7 @@ class AbstractReward(ABC): indefinite (aka. objective). Defining cost is allowed but not recommended. Although it encourages the - agent to achieve the task at hands as quickly as possible if success if the + agent to achieve the task at hand as quickly as possible if success is the only termination condition, it has the side-effect to give the opportunity to the agent to maximize the return by killing itself whenever this is an option, which is rarely the desired behavior. No restriction is enforced as @@ -47,8 +55,8 @@ def name(self) -> str: """Name uniquely identifying a given reward component. This name will be used as key for storing reward-specific monitoring - information in 'info' if key is missing, otherwise it will raise an - exception. + and debugging information in 'info' if the key does not already exist, + otherwise it will raise an exception. """ return self._name @@ -95,6 +103,12 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]: method to honor flags 'is_terminated' (if not indefinite) and 'is_normalized'. Failing this, an exception will be raised. + :param terminated: Whether the episode has reached a terminal state of + the MDP at the current step.
+ :param info: Dictionary of extra information for monitoring. It will be + updated in-place for storing current value of the reward + in 'info' if it was truly evaluated. + :returns: Scalar value if the reward was evaluated, `None` otherwise. """ @@ -109,8 +123,9 @@ def __call__(self, terminated: bool, info: InfoType) -> float: This method is a lightweight wrapper around `compute` to skip evaluation depending on whether the current state and the reward are terminal. If the reward was truly evaluated, then 'info' is - updated to store either reward-specific 'info' if any or its value - otherwise. If not, then 'info' is left as-is and 0.0 is returned. + updated to store either custom debugging information if any or its + value otherwise. If the reward is not evaluated, then 'info' is + left as-is and 0.0 is returned. .. warning:: This method is not meant to be overloaded. @@ -142,7 +157,7 @@ def __call__(self, terminated: bool, info: InfoType) -> float: "Reward not normalized in range [0.0, 1.0] as it ought to be.") # Store its value as info - if self.name is info.keys(): + if self.name in info.keys(): raise KeyError( f"Key '{self.name}' already reserved in 'info'. Impossible to " "store value of reward component.") @@ -155,9 +170,9 @@ def __call__(self, terminated: bool, info: InfoType) -> float: return value -class BaseQuantityReward(AbstractReward): - """Base class that makes easy easy to derive reward components from generic - quantities. +class QuantityReward(AbstractReward, Generic[ValueT]): + """Convenience class making it easy to derive reward components from + generic quantities. All this class does is applying some user-specified post-processing to the value of a given multi-variate quantity to return a floating-point scalar @@ -179,9 +194,9 @@ def __init__(self, quantities by the environment. As a result, it must be unique otherwise an exception will be raised. :param quantity: Tuple gathering the class of the underlying quantity - to use as reward after some post-processing, plus all - its constructor keyword-arguments except environment - 'env' and parent 'parent. + to use as reward after some post-processing, plus any + keyword-arguments of its constructor except 'env', + and 'parent'. :param transform_fn: Transform function responsible for aggregating a multi-variate quantity as floating-point scalar value to maximize. Typical examples are `np.min`, @@ -219,7 +234,7 @@ def __init__(self, self.env.quantities[self.name] = quantity # Keep track of the underlying quantity - self.quantity = self.env.quantities.registry[self.name] + self.data = self.env.quantities.registry[self.name] def __del__(self) -> None: try: @@ -251,7 +266,7 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]: return None # Evaluate raw quantity - value = self.env.quantities[self.name] + value = self.data.get() # Early return if quantity is None if value is None: @@ -265,17 +280,17 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]: return value -BaseQuantityReward.name.__doc__ = \ +QuantityReward.name.__doc__ = \ """Name uniquely identifying every reward. It will be used as key not only for storing reward-specific monitoring - information in 'info', but also for adding the underlying quantity to - the ones already managed by the environment. + and debugging information in 'info', but also for adding the underlying + quantity to the ones already managed by the environment. 
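Stripped of environment plumbing, the contract of `QuantityReward.compute` reduces to post-processing a (possibly ill-defined) quantity value into a scalar. A self-contained sketch under assumed names, for illustration only:

    import numpy as np

    def compute_reward(value, transform_fn=None):
        """Post-process a multi-variate quantity into a scalar reward."""
        if value is None:          # quantity ill-defined for the current state
            return None
        if transform_fn is not None:
            value = transform_fn(value)
        return float(value)

    print(compute_reward(np.array([0.3, 0.4]), np.linalg.norm))  # 0.5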
""" -class BaseMixtureReward(AbstractReward): - """Base class for aggregating multiple independent reward components in a +class MixtureReward(AbstractReward): + """Base class for aggregating multiple independent reward components as a single one. """ @@ -288,7 +303,7 @@ def __init__(self, name: str, components: Sequence[AbstractReward], reduce_fn: Callable[ - [Sequence[Optional[float]]], Optional[float]], + [Tuple[Optional[float], ...]], Optional[float]], is_normalized: bool) -> None: """ :param env: Base or wrapped jiminy environment. @@ -363,6 +378,258 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]: values.append(value) # Aggregate all reward components in one - reward_total = self._reduce_fn(values) + reward_total = self._reduce_fn(tuple(values)) return reward_total + + +class EpisodeState(IntEnum): + """Specify the current state of the ongoing episode. + """ + + CONTINUED = 0 + """No termination condition has been triggered this step. + """ + + TERMINATED = 1 + """The terminal state has been reached. + """ + + TRUNCATED = 2 + """A truncation condition has been triggered. + """ + + +class AbstractTerminationCondition(ABC): + """Abstract class from which all termination conditions must derived. + + Request the ongoing episode to stop immediately as soon as a termination + condition is triggered. + + There are two cases: truncating the episode or reaching the terminal state. + In the former case, the agent is instructed to stop collecting samples from + the ongoing episode and move to the next one, without considering this as a + failure. As such, the reward-to-go that has not been observed will be + estimated via a value function estimator. This is already what happens + when collecting sample batches in the infinite horizon RL framework, except + that the episode is not resumed to collect the rest of the episode in the + following sample batched. In the case of a termination condition, the agent + is just as much instructed to move to the next episode, but also to + consider that it was an actual failure. This means that, unlike truncation + conditions, the reward-to-go is known to be exactly zero. This is usually + dramatic for the agent in the perspective of an infinite horizon reward, + even more as the maximum discounted reward grows larger as the discount + factor gets closer to one. As a result, the agent will avoid at all cost + triggering terminal conditions, to the point of becoming risk averse by + taking extra security margins lowering the average reward if necessary. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + name: str, + grace_period: float = 0.0, + *, + is_truncation: bool = False, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param name: Desired name of the termination condition. This name will + be used as key for storing the current episode state from + the perspective of this specific condition in 'info', and + to add the underlying quantity to the set of already + managed quantities by the environment. As a result, it + must be unique otherwise an exception will be raised. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_truncation: Whether the episode should be considered + terminated or truncated whenever the termination + condition is triggered. + Optional: False by default. 
+ :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + self.env = env + self._name = name + self.grace_period = grace_period + self.is_truncation = is_truncation + self.is_training_only = is_training_only + + @property + def name(self) -> str: + """Name uniquely identifying a given termination condition. + + This name will be used as key for storing termination + condition-specific monitoring information in 'info' if the key does + not already exist, otherwise it will raise an exception. + """ + return self._name + + @abstractmethod + def compute(self, info: InfoType) -> bool: + """Evaluate the termination condition at hand. + + :param info: Dictionary of extra information for monitoring. It will be + updated in-place for storing terminated and truncated + flags in 'info' as a tri-state `EpisodeState` value. + """ + + def __call__(self, info: InfoType) -> Tuple[bool, bool]: + """Return whether the termination condition has been triggered. + + For the corresponding MDP to be stationary, the condition to trigger + termination is supposed to involve only the transition from the + previous to the current state of the environment under the ongoing + action. + + .. note:: + This method is a lightweight wrapper around `compute` to return two + boolean flags 'terminated', 'truncated' complying with Gym API. + 'info' will be updated to store either custom debug information if + any, or a tri-state `EpisodeState` value otherwise. + + .. warning:: + This method is not meant to be overloaded. + + :param info: Dictionary of extra information for monitoring. It will be + updated in-place for storing terminated and truncated + flags in 'info' as a tri-state `EpisodeState` value. + + :returns: terminated and truncated flags. + """ + # Skip termination condition in eval mode or during grace period + termination_info: InfoType = {} + if (self.is_training_only and not self.env.is_training) or ( + self.env.stepper_state.t < self.grace_period): + # Always continue + is_terminated, is_truncated = False, False + else: + # Evaluate the termination condition and store extra information + is_done = self.compute(termination_info) + is_terminated = is_done and not self.is_truncation + is_truncated = is_done and self.is_truncation + + # Store episode state as info + if self.name in info.keys(): + raise KeyError( + f"Key '{self.name}' already reserved in 'info'. Impossible to " + "store value of termination condition.") + if termination_info: + info[self.name] = termination_info + else: + if is_terminated: + episode_state = EpisodeState.TERMINATED + elif is_truncated: + episode_state = EpisodeState.TRUNCATED + else: + episode_state = EpisodeState.CONTINUED + info[self.name] = episode_state + + # Return terminated and truncated flags + return is_terminated, is_truncated + + +class QuantityTermination(AbstractTerminationCondition, Generic[ValueT]): + """Convenience class making it easy to derive termination conditions from + generic quantities. + + All this class does is check that all elements of a given quantity are + within bounds. If so, then the episode continues, otherwise it is either + truncated or terminated according to the 'is_truncation' constructor + argument. This only applies after the end of a grace period. Before that, + the episode continues no matter what.
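The dispatch performed by `AbstractTerminationCondition.__call__` can be condensed into a standalone helper; this sketch keeps only the grace-period guard and the terminated/truncated split, leaving out evaluation mode and 'info' bookkeeping:

    def evaluate_termination(is_done, is_truncation, t, grace_period=0.0):
        """Map a raw trigger to Gym-style (terminated, truncated) flags."""
        if t < grace_period:   # episode always continues during grace period
            return False, False
        return is_done and not is_truncation, is_done and is_truncation

    assert evaluate_termination(True, False, 1.0) == (True, False)
    assert evaluate_termination(True, True, 1.0) == (False, True)
    assert evaluate_termination(True, True, 0.5, grace_period=1.0) == (False, False)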
+ """ + + def __init__(self, + env: InterfaceJiminyEnv, + name: str, + quantity: QuantityCreator[Optional[ArrayOrScalar]], + low: Optional[ArrayLikeOrScalar], + high: Optional[ArrayLikeOrScalar], + grace_period: float = 0.0, + *, + is_truncation: bool = False, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param name: Desired name of the termination condition. This name will + be used as key for storing the current episode state from + the perspective of this specific condition in 'info', and + to add the underlying quantity to the set of already + managed quantities by the environment. As a result, it + must be unique otherwise an exception will be raised. + :param quantity: Tuple gathering the class of the underlying quantity + to use as termination condition, plus any + keyword-arguments of its constructor except 'env', + and 'parent'. + :param low: Lower bound below which termination is triggered. + :param high: Upper bound above which termination is triggered. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_truncation: Whether the episode should be considered + terminated or truncated whenever the termination + condition is triggered. + Optional: False by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + # Backup user argument(s) + self.low = low + self.high = high + + # Call base implementation + super().__init__( + env, + name, + grace_period, + is_truncation=is_truncation, + is_training_only=is_training_only) + + # Add quantity to the set of quantities managed by the environment + self.env.quantities[self.name] = quantity + + # Keep track of the underlying quantity + self.data = self.env.quantities.registry[self.name] + + def __del__(self) -> None: + try: + del self.env.quantities[self.name] + except Exception: # pylint: disable=broad-except + # This method must not fail under any circumstances + pass + + def compute(self, info: InfoType) -> bool: + """Evaluate the termination condition. + + The underlying quantity is first evaluated. The episode continues if + all the elements of its value are within bounds, otherwise the episode + is either truncated or terminated according to 'is_truncation'. + + .. warning:: + This method is not meant to be overloaded. + """ + # Evaluate the quantity + value = self.data.get() + + # Check if the quantity is out-of-bounds bound. + # Note that it may be `None` if the quantity is ill-defined for the + # current simulation state, which triggers termination unconditionally. + is_done = value is None + is_done |= self.low is not None and bool(np.any(self.low > value)) + is_done |= self.high is not None and bool(np.any(value > self.high)) + return is_done + + +QuantityTermination.name.__doc__ = \ + """Name uniquely identifying every termination condition. + + It will be used as key not only for storing termination condition-specific + monitoring and debugging information in 'info', but also for adding the + underlying quantity to the ones already managed by the environment. 
+ """ diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py b/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py index 1d25bf655..aafc05e5c 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py @@ -5,7 +5,8 @@ from abc import abstractmethod, ABC from collections import OrderedDict from typing import ( - Dict, Any, TypeVar, Generic, no_type_check, TypedDict, TYPE_CHECKING) + Dict, Any, Tuple, TypeVar, Generic, TypedDict, no_type_check, + TYPE_CHECKING) import numpy as np import numpy.typing as npt @@ -273,7 +274,7 @@ def _observer_handle(self, # they are always updated before the controller gets called, no matter # if either one or the other is time-continuous. Hacking the internal # dynamics to clear quantities does not address this issue either. - self.quantities.clear() + # self.quantities.clear() # Refresh the observation if not already done but only if a simulation # is already running. It would be pointless to refresh the observation @@ -353,23 +354,19 @@ def stop(self) -> None: """ self.simulator.stop() - @property @abstractmethod - def unwrapped(self) -> "BaseJiminyEnv": - """The "underlying environment at the basis of the pipeline from which - this environment is part of. - """ + def has_terminated(self, info: InfoType) -> Tuple[bool, bool]: + """Determine whether the episode is over, because a terminal state of + the underlying MDP has been reached or an aborting condition outside + the scope of the MDP has been triggered. - @property - @abstractmethod - def step_dt(self) -> float: - """Get timestep of a single 'step'. - """ + .. note:: + This method is called after `refresh_observation`, so that the + internal buffer 'observation' is up-to-date. - @property - @abstractmethod - def is_training(self) -> bool: - """Check whether the environment is in 'train' or 'eval' mode. + :param info: Dictionary of extra information for monitoring. + + :returns: terminated and truncated flags. """ @abstractmethod @@ -386,3 +383,22 @@ def eval(self) -> None: time specifically. See documentations of a given environment for details about their behaviors in training and evaluation modes. """ + + @property + @abstractmethod + def unwrapped(self) -> "BaseJiminyEnv": + """The "underlying environment at the basis of the pipeline from which + this environment is part of. + """ + + @property + @abstractmethod + def step_dt(self) -> float: + """Get timestep of a single 'step'. + """ + + @property + @abstractmethod + def is_training(self) -> bool: + """Check whether the environment is in 'train' or 'eval' mode. 
+ """ diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py index 4d7164878..dfaebb6dc 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py @@ -17,8 +17,8 @@ from abc import abstractmethod from collections import OrderedDict from typing import ( - Dict, Any, List, Optional, Tuple, Union, Generic, TypeVar, SupportsFloat, - Callable, cast, TYPE_CHECKING) + Dict, Any, List, Sequence, Optional, Tuple, Union, Generic, TypeVar, + SupportsFloat, Callable, cast, TYPE_CHECKING) import numpy as np @@ -35,7 +35,7 @@ InfoType, EngineObsType, InterfaceJiminyEnv) -from .compositions import AbstractReward +from .compositions import AbstractReward, AbstractTerminationCondition from .blocks import BaseControllerBlock, BaseObserverBlock from ..utils import DataNested, is_breakpoint, zeros, build_copyto, copy @@ -115,7 +115,8 @@ def __getattr__(self, name: str) -> Any: Calling this method in script mode while a simulation is already running would trigger a warning to avoid relying on it by mistake. """ - if self.is_simulation_running and not hasattr(sys, 'ps1'): + if (self.is_simulation_running and self.env.is_training and + not hasattr(sys, 'ps1')): # `hasattr(sys, 'ps1')` is used to detect whether the method was # called from an interpreter or within a script. For details, see: # https://stackoverflow.com/a/64523765/4820605 @@ -301,6 +302,24 @@ def _setup(self) -> None: # Initialize specialized operator(s) for efficiency self._copyto_action = build_copyto(self.action) + def has_terminated(self, info: InfoType) -> Tuple[bool, bool]: + """Determine whether the episode is over, because a terminal state of + the underlying MDP has been reached or an aborting condition outside + the scope of the MDP has been triggered. + + By default, it does nothing but forwarding the request to the base + environment. This behavior can be overwritten by the user. + + .. note:: + This method is called after `refresh_observation`, so that the + internal buffer 'observation' is up-to-date. + + :param info: Dictionary of extra information for monitoring. + + :returns: terminated and truncated flags. + """ + return self.env.has_terminated(info) + def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: """Render the unified environment. @@ -334,32 +353,46 @@ class ComposedJiminyEnv( considered as internal unlike `gym.Wrapper`. This means that it will be taken into account when calling `evaluate` or `play_interactive` on the wrapped environment. + + .. warning:: + This class is final, ie not meant to be derived. """ def __init__(self, env: InterfaceJiminyEnv[ObsT, ActT], *, reward: Optional[AbstractReward] = None, + terminations: Sequence[AbstractTerminationCondition] = (), trajectories: Optional[Dict[str, Trajectory]] = None) -> None: """ :param env: Environment to extend, eventually already wrapped. :param reward: Reward object deriving from `AbstractReward`. It will be evaluated at each step of the environment and summed up - with one returned by the wrapped environment. This + with the one returned by the wrapped environment. This reward must be already instantiated and associated with the provided environment. `None` for not considering any reward. Optional: `None` by default. + :param terminations: Sequence of termination condition objects deriving + from `AbstractTerminationCondition`. 
They will be + checked along with the one enforced by the wrapped + environment. If provided, these termination + conditions must be already instantiated and + associated with the environment at hand. + Optional: Empty sequence by default. :param trajectories: Set of named trajectories as a dictionary whose (key, value) pairs are respectively the name of each trajectory and the trajectory itself. `None` for not considering any trajectory. Optional: `None` by default. """ - # Make sure that the unwrapped environment matches the reward one + # Make sure that the unwrapped environment of compositions matches assert reward is None or env.unwrapped is reward.env.unwrapped + assert all(env.unwrapped is termination.env.unwrapped + for termination in terminations) # Backup user argument(s) self.reward = reward + self.terminations = tuple(terminations) # Initialize base class super().__init__(env) @@ -414,6 +447,43 @@ def refresh_observation(self, measurement: EngineObsType) -> None: """ self.env.refresh_observation(measurement) + def has_terminated(self, info: InfoType) -> Tuple[bool, bool]: + """Determine whether the practitioner is instructed to stop the ongoing + episode on the spot because a termination condition has been triggered, + either coming from the base environment or from the ad-hoc termination + conditions that have been plugged on top of it. + + At each step of the wrapped environment, all its termination conditions + will be evaluated sequentially until one of them eventually gets + triggered. If this happens, evaluation is skipped for the remaining + ones and the reward is evaluated straight away. Ultimately, the + practitioner is instructed to stop the ongoing episode, but it is their + own responsibility to honor this request. The first condition being + evaluated is the one of the underlying environment, then come the ones + of this composition layer, following the original sequence ordering. + + .. note:: + This method is called after `refresh_observation`, so that the + internal buffer 'observation' is up-to-date. + + .. seealso:: + See `InterfaceJiminyEnv.has_terminated` documentation for details. + + :param info: Dictionary of extra information for monitoring. + + :returns: terminated and truncated flags. + """ + # Call unwrapped environment implementation + terminated, truncated = self.env.has_terminated(info) + + # Evaluate conditions one-by-one as long as none has been triggered + for termination in self.terminations: + if terminated or truncated: + break + terminated, truncated = termination(info) + + return terminated, truncated + def compute_command(self, action: ActT, command: np.ndarray) -> None: """Compute the motor efforts to apply on the robot. @@ -426,9 +496,29 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None: self.env.compute_command(action, command) def compute_reward(self, terminated: bool, info: InfoType) -> float: - if self.reward is None: - return 0.0 - return self.reward(terminated, info) + """Compute the total reward, ie the sum of the original reward from the + wrapped environment with the ad-hoc reward components that have been + plugged on top of it. + + .. seealso:: + See `InterfaceController.compute_reward` documentation for details. + + :param terminated: Whether the episode has reached the terminal state + of the MDP at the current step. This flag can be + used to compute a specific terminal reward. + :param info: Dictionary of extra information for monitoring. + + :returns: Aggregated reward for the current step.
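The evaluation order described in the docstring amounts to a short-circuiting fold over conditions; a simplified standalone rendition (the real method also polls the wrapped environment through `self.env.has_terminated`):

    def chain_terminations(base_result, conditions, info):
        terminated, truncated = base_result   # wrapped environment goes first
        for condition in conditions:
            if terminated or truncated:
                break                         # skip remaining conditions
            terminated, truncated = condition(info)
        return terminated, truncated

    assert chain_terminations(
        (False, False), [lambda info: (True, False)], {}) == (True, False)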
+ """ + # Compute base reward + reward = self.env.compute_reward(terminated, info) + + # Add composed reward if any + if self.reward is not None: + reward += self.reward(terminated, info) + + # Return total reward + return reward class ObservedJiminyEnv( diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py b/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py index 352b188a7..c07234e27 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py @@ -14,16 +14,16 @@ """ import re import weakref -from enum import Enum +from enum import IntEnum from weakref import ReferenceType from abc import ABC, abstractmethod from collections import OrderedDict from collections.abc import MutableSet from dataclasses import dataclass, replace -from functools import partial, wraps +from functools import wraps from typing import ( Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator, - Callable, Literal, ClassVar, cast) + Collection, Callable, Literal, ClassVar, TYPE_CHECKING) import numpy as np @@ -47,6 +47,9 @@ class WeakMutableCollection(MutableSet, Generic[ValueT]): Internally, it is implemented as a set for which uniqueness is characterized by identity instead of equality operator. """ + + __slots__ = ("_callback", "_weakrefs") + def __init__(self, callback: Optional[Callable[[ "WeakMutableCollection[ValueT]", ReferenceType ], None]] = None) -> None: @@ -56,7 +59,7 @@ def __init__(self, callback: Optional[Callable[[ Optional: None by default. """ self._callback = callback - self._ref_list: List[ReferenceType] = [] + self._weakrefs: List[ReferenceType] = [] def __callback__(self, ref: ReferenceType) -> None: """Internal method that will be called every time an element must be @@ -73,9 +76,9 @@ def __callback__(self, ref: ReferenceType) -> None: # actually the right weak reference since all of them will be removed # in the end, so it is not a big deal. value = ref() - for i, ref_i in enumerate(self._ref_list): + for i, ref_i in enumerate(self._weakrefs): if value is ref_i(): - del self._ref_list[i] + del self._weakrefs[i] break if self._callback is not None: self._callback(self, ref) @@ -87,13 +90,13 @@ def __contains__(self, obj: Any) -> bool: :param obj: Object to look for in the container. """ - return any(ref() is obj for ref in self._ref_list) + return any(ref() is obj for ref in self._weakrefs) def __iter__(self) -> Iterator[ValueT]: """Dunder method that returns an iterator over the objects of the - container for which a string reference still exist. + container for which a reference still exist. """ - for ref in self._ref_list: + for ref in self._weakrefs: obj = ref() if obj is not None: yield obj @@ -101,7 +104,7 @@ def __iter__(self) -> Iterator[ValueT]: def __len__(self) -> int: """Dunder method that returns the length of the container. """ - return len(self._ref_list) + return len(self._weakrefs) def add(self, value: ValueT) -> None: """Add a new element to the container if not already contained. @@ -111,7 +114,7 @@ def add(self, value: ValueT) -> None: :param obj: Object to add to the container. """ if value not in self: - self._ref_list.append(weakref.ref(value, self.__callback__)) + self._weakrefs.append(weakref.ref(value, self.__callback__)) def discard(self, value: ValueT) -> None: """Remove an element from the container if stored in it. 
@@ -124,6 +127,34 @@ def discard(self, value: ValueT) -> None: self.__callback__(weakref.ref(value)) +class QuantityStateMachine(IntEnum): + """Specify the current state of a given (unique) quantity, which determines + the steps to perform for retrieving its current value. + """ + + IS_RESET = 0 + """The quantity at hand has just been reset. It must first be initialized, + then refreshed and finally stored in cache before its value can be + retrieved. + """ + + IS_INITIALIZED = 1 + """The quantity at hand has been initialized but never evaluated for the + current robot state. Its value must still be refreshed and stored in cache + before it can be retrieved. + """ + + IS_CACHED = 2 + """The quantity at hand has been evaluated and its value stored in cache. + As such, it can be retrieved from cache directly. + """ + + +# Define proxies for fast lookup +_IS_RESET, _IS_INITIALIZED, _IS_CACHED = ( # pylint: disable=invalid-name + QuantityStateMachine) + + class SharedCache(Generic[ValueT]): """Basic thread local shared cache. @@ -135,7 +166,10 @@ class SharedCache(Generic[ValueT]): This implementation is not thread safe. """ - owners: WeakMutableCollection["InterfaceQuantity[ValueT]"] + __slots__ = ( + "_value", "_weakrefs", "_owner", "_auto_refresh", "sm_state", "owners") + + owners: Collection["InterfaceQuantity[ValueT]"] """Owners of the shared buffer, ie quantities relying on it to store the result of their evaluation. This information may be useful for determining the most efficient computation path overall. @@ -157,72 +191,174 @@ def __init__(self) -> None: # Cached value if any self._value: Optional[ValueT] = None - # Whether a value is stored in cached - self._has_value: bool = False + # Whether auto-refresh is requested + self._auto_refresh = True + + # Basic state machine management + self.sm_state: QuantityStateMachine = QuantityStateMachine.IS_RESET # Initialize "owners" of the shared buffer. # Define callback to reset part of the computation graph whenever a # quantity owning the cache gets garbage collected, namely all # quantities that may assume at some point the existence of this - # deleted owner to find the adjust their computation path. - def _callback(self: WeakMutableCollection["InterfaceQuantity"], - ref: ReferenceType # pylint: disable=unused-argument - ) -> None: - owner: Optional["InterfaceQuantity"] + # deleted owner to adjust their computation path. + def _callback( + self: WeakMutableCollection["InterfaceQuantity[ValueT]"], + ref: ReferenceType) -> None: # pylint: disable=unused-argument + owner: Optional["InterfaceQuantity[ValueT]"] for owner in self: # Stop going up in parent chain if dynamic computation graph # update is disabled for efficiency. - while owner.allow_update_graph and owner.parent is not None: + while (owner.allow_update_graph and + owner.parent is not None and owner.parent.has_cache): owner = owner.parent - owner.reset(reset_tracking=True, - ignore_auto_refresh=True, - update_graph=True) + owner.reset(reset_tracking=True) - self.owners = WeakMutableCollection(_callback) + # Initialize weak reference to owning quantities + self._weakrefs = WeakMutableCollection(_callback) - @property - def has_value(self) -> bool: - """Whether a value is stored in cache.
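The module-level proxies `_IS_RESET, _IS_INITIALIZED, _IS_CACHED` work because unpacking an `IntEnum` yields its members in definition order, and a plain module-level name is cheaper to look up than an enum attribute in hot paths. A tiny demonstration with made-up names:

    from enum import IntEnum

    class State(IntEnum):
        RESET = 0
        INITIALIZED = 1
        CACHED = 2

    _RESET, _INITIALIZED, _CACHED = State   # members in definition order
    assert _CACHED is State.CACHED          # same singleton, cheaper lookup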
+ # Keep owning quantities alive upon reset + self.owners = self._weakrefs + self._owner: Optional["InterfaceQuantity[ValueT]"] = None + + def add(self, owner: "InterfaceQuantity[ValueT]") -> None: + """Add a given quantity instance to the set of co-owners associated + with the shared cache at hand. + + .. warning:: + All shared cache co-owners must be instances of the same unique + quantity. An exception will be thrown if an attempt is made to add + a quantity instance that does not satisfy this condition. + + :param owner: Quantity instance to add to the set of co-owners. """ - return self._has_value + # Make sure that the quantity is not already part of the co-owners + if id(owner) in map(id, self.owners): + raise ValueError( + "The specified quantity instance is already an owner of this " + "shared cache.") + + # Make sure that the new owner is consistent with the others if any + if any(owner != _owner for _owner in self._weakrefs): + raise ValueError( + "Quantity instance inconsistent with already existing shared " + "cache owners.") - def reset(self, *, ignore_auto_refresh: bool = False) -> None: + # Add quantity instance to shared cache owners + self._weakrefs.add(owner) + + # Refresh owners + if self.sm_state is not QuantityStateMachine.IS_RESET: + self.owners = tuple(self._weakrefs) + + def discard(self, owner: "InterfaceQuantity[ValueT]") -> None: + """Remove a given quantity instance from the set of co-owners + associated with the shared cache at hand. + + :param owner: Quantity instance to remove from the set of co-owners. + """ + # Make sure that the quantity is part of the co-owners + if id(owner) not in map(id, self.owners): + raise ValueError( + "The specified quantity instance is not an owner of this " + "shared cache.") + + # Restore "dynamic" owner list as it may be involved in quantity reset + self.owners = self._weakrefs + + # Remove quantity instance from shared cache owners + self._weakrefs.discard(owner) + + # Refresh owners. + # Note that one must keep tracking the quantity instance being used in + # computations, aka 'self._owner', even if it is no longer an actual + # shared cache owner. This is necessary because updating it would + # require resetting the state machine, which is not an option as it + # would mess up quantities storing history since initialization. + if self.sm_state is not QuantityStateMachine.IS_RESET: + self.owners = tuple(self._weakrefs) + + def reset(self, + ignore_auto_refresh: bool = False, + reset_state_machine: bool = False) -> None: """Clear value stored in cache if any. :param ignore_auto_refresh: Whether to skip automatic refresh of all co-owner quantities of this shared cache. Optional: False by default. + :param reset_state_machine: Whether to reset completely the state + machine of the underlying quantity, ie not + considering it initialized anymore. + Optional: False by default.
""" # Clear cache - self._value = None - self._has_value = False + if self.sm_state is _IS_CACHED: + self.sm_state = _IS_INITIALIZED + + # Special branch if case quantities must be reset on the way + if reset_state_machine: + # Reset the state machine completely + self.sm_state = _IS_RESET + + # Update list of owning quantities + self.owners = self._weakrefs + self._owner = None - # Refresh automatically if any cache owner requested it and not ignored - if not ignore_auto_refresh: + # Reset auto-refresh buffer + self._auto_refresh = True + + # Refresh automatically if not already proven useless and not ignored + if not ignore_auto_refresh and self._auto_refresh: for owner in self.owners: if owner.auto_refresh: owner.get() break + else: + self._auto_refresh = False - def set(self, value: ValueT) -> None: - """Set value in cache, silently overriding the existing value if any. + def get(self) -> ValueT: + """Return cached value if any, otherwise evaluate it and store it. + """ + # Get value already stored + if self.sm_state is _IS_CACHED: + # return cast(ValueT, self._value) + return self._value # type: ignore[return-value] - .. warning: - Beware the value is stored by reference for efficiency. It is up to - the user to copy it if necessary. + # Evaluate quantity + try: + if self.sm_state is _IS_RESET: + # Cache the list of owning quantities + self.owners = tuple(self._weakrefs) - :param value: Value to store in cache. - """ - self._value = value - self._has_value = True + # Stick to the first owning quantity systematically + owner = self.owners[0] + self._owner = owner - def get(self) -> ValueT: - """Return cached value if any, otherwise raises an exception. - """ - if self._has_value: - return cast(ValueT, self._value) - raise ValueError( - "No value has been stored. Please call 'set' before 'get'.") + # Initialize quantity if not already done manually + if not owner._is_initialized: + owner.initialize() + assert owner._is_initialized + + # Get first owning quantity systematically + # assert self._owner is not None + owner = self._owner # type: ignore[assignment] + + # Make sure that the state has been refreshed + if owner._force_update_state: + owner.state.get() + + # Refresh quantity + value = owner.refresh() + except RecursionError as e: + raise LookupError( + "Mutual dependency between quantities is disallowed.") from e + + # Update state machine + self.sm_state = _IS_CACHED + + # Return value after storing it + self._value = value + return value class InterfaceQuantity(ABC, Generic[ValueT]): @@ -249,7 +385,7 @@ class InterfaceQuantity(ABC, Generic[ValueT]): requirements: Dict[str, "InterfaceQuantity"] """Intermediary quantities on which this quantity may rely on for its evaluation at some point, depending on the optimal computation path at - runtime. There values will be exposed to the user as usual properties. + runtime. They will be exposed to the user as usual attributes. """ allow_update_graph: ClassVar[bool] = True @@ -257,8 +393,8 @@ class InterfaceQuantity(ABC, Generic[ValueT]): the quantity can be reset at any point in time to re-compute the optimal computation path, typically after deletion or addition of some other node to its dependent sub-graph. When this happens, the quantity gets reset on - the spot, which is not always acceptable, hence the capability to disable - this feature. + the spot, even if a simulation is already running. This is not always + acceptable, hence the capability to disable this feature at class-level. 
""" def __init__(self, @@ -274,8 +410,8 @@ def __init__(self, :param requirements: Intermediary quantities on which this quantity depends for its evaluation, as a dictionary whose keys are tuple gathering their respective - class and all their constructor keyword-arguments - except environment 'env' and parent 'parent. + class plus any keyword-arguments of its + constructor except 'env' and 'parent'. :param auto_refresh: Whether this quantity must be refreshed automatically as soon as its shared cache has been cleared if specified, otherwise this does nothing. @@ -296,13 +432,20 @@ class and all their constructor keyword-arguments name: cls(env, self, **kwargs) for name, (cls, kwargs) in requirements.items()} + # Define proxies for user-specified intermediary quantities. + # This approach is much faster than hidding quantities behind value + # getters. In particular, dynamically adding properties, which is hacky + # but which is the fastest alternative option, still adds 35% overhead + # on Python 3.11 compared to calling `get` directly. The "official" + # approaches are even slower, ie implementing custom `__getattribute__` + # method or worst custom `__getattr__` method. + for name, quantity in self.requirements.items(): + setattr(self, name, quantity) + # Update the state explicitly if available but auto-refresh not enabled - self._state: Optional[StateQuantity] = None + self._force_update_state = False if isinstance(self, AbstractQuantity): - quantity = self.requirements["state"] - if not quantity.auto_refresh: - assert isinstance(quantity, StateQuantity) - self._state = quantity + self._force_update_state = not self.state.auto_refresh # Shared cache handling self._cache: Optional[SharedCache[ValueT]] = None @@ -314,41 +457,23 @@ class and all their constructor keyword-arguments # Whether the quantity must be re-initialized self._is_initialized: bool = False - # Add getter dynamically for user-specified intermediary quantities. - # This approach is hacky but much faster than any of other official - # approach, ie implementing custom a `__getattribute__` method or even - # worst a custom `__getattr__` method. - def get_value(name: str, quantity: InterfaceQuantity) -> Any: - return quantity.requirements[name].get() + if TYPE_CHECKING: + def __getattr__(self, name: str) -> Any: + """Get access to intermediary quantities as first-class properties, + without having to do it through `requirements`. - for name in requirement_names: - setattr(type(self), name, property(partial(get_value, name))) - - def __getattr__(self, name: str) -> Any: - """Get access to intermediary quantities as first-class properties, - without having to do it through `requirements`. - - .. warning:: - Accessing quantities this way is convenient, but unfortunately - much slower than do it through `requirements` manually. As a - result, this approach is mainly intended for ease of use while - prototyping. + .. warning:: + Accessing quantities this way is convenient, but unfortunately + much slower than do it through dynamically added properties. As + a result, this approach is only used to fix typing issues. - :param name: Name of the requested quantity. - """ - try: - return self.__getattribute__('requirements')[name].get() - except KeyError as e: - raise AttributeError( - f"'{type(self)}' object has no attribute '{name}'") from e - - def __dir__(self) -> List[str]: - """Attribute lookup. - - It is mainly used by autocomplete feature of Ipython. It is overloaded - to get consistent autocompletion wrt `getattr`. 
- """ - return [*super().__dir__(), *self.requirements.keys()] + :param name: Name of the requested quantity. + """ + try: + return self.__getattribute__('requirements')[name].get() + except KeyError as e: + raise AttributeError( + f"'{type(self)}' object has no attribute '{name}'") from e @property def cache(self) -> SharedCache[ValueT]: @@ -381,7 +506,7 @@ def cache(self, cache: Optional[SharedCache[ValueT]]) -> None: # Withdraw this quantity from the owners of its current cache if any if self._cache is not None: try: - self._cache.owners.discard(self) + self._cache.discard(self) except ValueError: # This may fail if the quantity is already being garbage # collected when clearing cache. @@ -389,7 +514,7 @@ def cache(self, cache: Optional[SharedCache[ValueT]]) -> None: # Declare this quantity as owner of the cache if specified if cache is not None: - cache.owners.add(self) + cache.add(self) # Update internal cache attribute self._cache = cache @@ -419,43 +544,35 @@ def get(self) -> ValueT: .. warning:: This method is not meant to be overloaded. """ - # Get value in cache if available. - # Note that direct access to internal `_value` attribute is preferred - # over the public API `get` for speedup. The same cannot be done for - # `has_value` as it would prevent mocking it during running unit tests - # or benchmarks. - if (self.has_cache and - self._cache.has_value): # type: ignore[union-attr] + # Delegate getting value to shared cache if available + if self._cache is not None: + # Get value + value = self._cache.get() + + # This instance is not forceably considered active at this point. + # Note that it must be done AFTER getting the value, otherwise it + # would mess up with computation graph tracking at initialization. self._is_active = True - return self._cache._value # type: ignore[union-attr,return-value] + + # Return cached value + return value # Evaluate quantity try: # Initialize quantity if not self._is_initialized: self.initialize() - assert (self._is_initialized and - self._is_active) # type: ignore[unreachable] - - # Make sure that the state has been refreshed - if self._state is not None: - self._state.get() + assert self._is_initialized # Refresh quantity - value = self.refresh() + return self.refresh() except RecursionError as e: raise LookupError( "Mutual dependency between quantities is disallowed.") from e - # Return value after storing it in shared cache if available - if self.has_cache: - self._cache.set(value) # type: ignore[union-attr] - return value - def reset(self, reset_tracking: bool = False, - ignore_auto_refresh: bool = False, - update_graph: bool = False) -> None: + *, ignore_other_instances: bool = False) -> None: """Consider that the quantity must be re-initialized before being evaluated once again. @@ -473,53 +590,48 @@ def reset(self, :param reset_tracking: Do not consider this quantity as active anymore until the `get` method gets called once again. Optional: False by default. - :param ignore_auto_refresh: Whether to skip automatic refresh of all - co-owner quantities of this shared cache. - Optional: False by default. - :param update_graph: If true, then the quantity will be reset if and - only if dynamic computation graph update is - allowed as prescribed by class attribute - `allow_update_graph`. If false, then it will be - reset no matter what. + :param ignore_other_instances: + Whether to skip reset of intermediary quantities as well as any + shared cache co-owner quantity instances. + Optional: False by default. 
""" # Make sure that auto-refresh can be honored - if (not ignore_auto_refresh and self.auto_refresh and - not self.has_cache): + if self.auto_refresh and not self.has_cache: raise RuntimeError( "Automatic refresh enabled but no shared cache is available. " "Please add one before calling this method.") # Reset all requirements first - for quantity in self.requirements.values(): - quantity.reset(reset_tracking, ignore_auto_refresh, update_graph) + if not ignore_other_instances: + for quantity in self.requirements.values(): + quantity.reset(reset_tracking, ignore_other_instances=False) - # Skip reset if dynamic computation graph update if appropriate - if update_graph and not self.allow_update_graph: + # Skip reset if dynamic computation graph update is not allowed + if self.env.is_simulation_running and not self.allow_update_graph: return - # No longer consider this exact instance as active if requested + # No longer consider this exact instance as active if reset_tracking: self._is_active = False # No longer consider this exact instance as initialized self._is_initialized = False - # More work must to be done if shared cache is available and has value + # More work must to be done if shared cache if appropriate if self.has_cache: - # Early return if shared cache has no value - if not self.cache.has_value: - return - - # Invalidate cache before looping over all identical properties. - # Note that auto-refresh must be ignored to avoid infinite loop. - self.cache.reset(ignore_auto_refresh=True) - - # Reset all identical quantities - for owner in self.cache.owners: - owner.reset(ignore_auto_refresh=True) - - # Reset shared cache one last time but without ignore auto refresh - self.cache.reset(ignore_auto_refresh=ignore_auto_refresh) + # Reset all identical quantities. + # Note that auto-refresh will be done afterward if requested. + if not ignore_other_instances: + for owner in self.cache.owners: + if owner is not self: + owner.reset(reset_tracking=reset_tracking, + ignore_other_instances=True) + + # Reset shared cache + # pylint: disable=unexpected-keyword-arg + self.cache.reset( + ignore_auto_refresh=not self.env.is_simulation_running, + reset_state_machine=True) def initialize(self) -> None: """Initialize internal buffers. @@ -554,10 +666,12 @@ def refresh(self) -> ValueT: """ -QuantityCreator = Tuple[Type[InterfaceQuantity[ValueT]], Dict[str, Any]] +QuantityValueT_co = TypeVar('QuantityValueT_co', covariant=True) +QuantityCreator = Tuple[ + Type[InterfaceQuantity[QuantityValueT_co]], Dict[str, Any]] -class QuantityEvalMode(Enum): +class QuantityEvalMode(IntEnum): """Specify on which state to evaluate a given quantity. """ @@ -570,6 +684,10 @@ class QuantityEvalMode(Enum): """ +# Define proxies for fast lookup +_TRUE, _REFERENCE = QuantityEvalMode + + @dataclass(unsafe_hash=True) class AbstractQuantity(InterfaceQuantity, Generic[ValueT]): """Base class for generic quantities involved observer-controller blocks, @@ -578,16 +696,16 @@ class AbstractQuantity(InterfaceQuantity, Generic[ValueT]): .. note:: A dataset of trajectories made available through `self.trajectories`. The latter is synchronized because all quantities as long as shared - cached is available. Since the dataset is initially empty by default, - using `QuantityEvalMode.REFERENCE` evaluation mode requires manually - adding at least one trajectory to the dataset and selecting it. + cached is available. 
At least one trajectory must be added to the + dataset and selected prior to using `QuantityEvalMode.REFERENCE` + evaluation mode since the dataset is initially empty by default. .. seealso:: See `InterfaceQuantity` documentation for details. """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -609,14 +727,14 @@ def __init__(self, :param requirements: Intermediary quantities on which this quantity depends for its evaluation, as a dictionary whose keys are tuple gathering their respective - class and all their constructor keyword-arguments - except environment 'env' and parent 'parent. + class plus any keyword-arguments of its + constructor except 'env' and 'parent'. :param mode: Desired mode of evaluation for this quantity. If mode is set to `QuantityEvalMode.TRUE`, then current simulation state will be used in dynamics computations. If mode is - set to `QuantityEvalMode.REFERENCE`, then at the state of - some reference trajectory at the current simulation time - will be used instead. + set to `QuantityEvalMode.REFERENCE`, then the state at the + current simulation time of the selected reference + trajectory will be used instead. :param auto_refresh: Whether this quantity must be refreshed automatically as soon as its shared cache has been cleared if specified, otherwise this does nothing. @@ -648,7 +766,7 @@ class and all their constructor keyword-arguments super().__init__(env, parent, requirements, auto_refresh=auto_refresh) # Add trajectory quantity proxy - trajectory = self.requirements["state"].requirements["trajectory"] + trajectory = self.state.trajectory assert isinstance(trajectory, DatasetTrajectoryQuantity) self.trajectory = trajectory @@ -662,14 +780,13 @@ def initialize(self) -> None: super().initialize() # Force initializing state quantity - state = self.requirements["state"] - state.initialize() + self.state.initialize() # Refresh robot proxy - assert isinstance(state, StateQuantity) - self.robot = state.robot - self.pinocchio_model = state.pinocchio_model - self.pinocchio_data = state.pinocchio_data + assert isinstance(self.state, StateQuantity) + self.robot = self.state.robot + self.pinocchio_model = self.state.pinocchio_model + self.pinocchio_data = self.state.pinocchio_data def sync(fun: Callable[..., None]) -> Callable[..., None]: @@ -918,7 +1035,7 @@ class StateQuantity(InterfaceQuantity[State]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -1002,17 +1119,14 @@ def __init__(self, # only refresh the state when needed if the evaluation mode is TRAJ. # * Update state: 500ns (TRUE) | 5.0us (TRAJ) # * Check cache state: 70ns - auto_refresh = mode == QuantityEvalMode.TRUE + auto_refresh = mode is QuantityEvalMode.TRUE # Call base implementation. super().__init__( - env, parent, requirements={}, auto_refresh=auto_refresh) - - # Create empty trajectory database, manually added as a requirement. - # Note that it must be done after base initialization, otherwise a - # getter will be added for it as first-class property. 
- self.trajectory = DatasetTrajectoryQuantity(env, self) - self.requirements["trajectory"] = self.trajectory + env, + parent, + requirements=dict(trajectory=(DatasetTrajectoryQuantity, {})), + auto_refresh=auto_refresh) # Robot for which the quantity must be evaluated self.robot = env.robot @@ -1020,13 +1134,13 @@ def __init__(self, self.pinocchio_data = env.robot.pinocchio_data # State for which the quantity must be evaluated - self.state = State(t=np.nan, q=np.array([])) + self._state = State(t=np.nan, q=np.array([])) # Persistent buffer for storing body external forces if necessary self._f_external_vec = pin.StdVec_Force() self._f_external_list: List[np.ndarray] = [] self._f_external_batch = np.array([]) - self._f_external_slices: List[np.ndarray] = [] + self._f_external_slices: Tuple[np.ndarray, ...] = () # Persistent buffer storing all lambda multipliers for efficiency self._constraint_lambda_batch = np.array([]) @@ -1074,7 +1188,7 @@ def initialize(self) -> None: assert isinstance(owner, StateQuantity) if owner._is_initialized: continue - if owner.mode == QuantityEvalMode.TRUE: + if owner.mode is QuantityEvalMode.TRUE: owner.robot = owner.env.robot use_theoretical_model = False else: @@ -1092,7 +1206,7 @@ def initialize(self) -> None: super().initialize() # Refresh proxies and allocate memory for storing external forces - if self.mode == QuantityEvalMode.TRUE: + if self.mode is QuantityEvalMode.TRUE: self._f_external_vec = self.env.robot_state.f_external else: self._f_external_vec = pin.StdVec_Force() @@ -1101,7 +1215,7 @@ def initialize(self) -> None: self._f_external_list = [ f_ext.vector for f_ext in self._f_external_vec] self._f_external_batch = np.zeros((self.pinocchio_model.njoints, 6)) - self._f_external_slices = list(self._f_external_batch) + self._f_external_slices = tuple(self._f_external_batch) # Allocate memory for lambda vector self._constraint_lambda_batch = np.zeros( @@ -1134,11 +1248,11 @@ def initialize(self) -> None: i += constraint.size # Allocate state for which the quantity must be evaluated if needed - if self.mode == QuantityEvalMode.TRUE: + if self.mode is QuantityEvalMode.TRUE: if not self.env.is_simulation_running: raise RuntimeError("No simulation running. Impossible to " "initialize this quantity.") - self.state = State( + self._state = State( 0.0, self.env.robot_state.q, self.env.robot_state.v, @@ -1152,23 +1266,21 @@ def refresh(self) -> State: """Compute the current state depending on the mode of evaluation, and make sure that kinematics and dynamics quantities are up-to-date. 
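With the trajectory database now declared as a plain requirement, a typical way to evaluate a state quantity is through the environment's quantity manager; a hedged sketch, assuming `env` is any wrapped jiminy environment exposing the `quantities` manager used throughout this diff:

# Register a state quantity, then fetch its refreshed value.
env.quantities["state"] = (StateQuantity, dict(
    update_kinematics=True, mode=QuantityEvalMode.TRUE))
state = env.quantities.registry["state"].get()
print(state.t, state.q)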
""" - # Update state at which the quantity must be evaluated - if self.mode == QuantityEvalMode.TRUE: + if self.mode is _TRUE: # Update the current simulation time - self.state.t = self.env.stepper_state.t + self._state.t = self.env.stepper_state.t # Update external forces and constraint multipliers in state buffer multi_array_copyto(self._f_external_slices, self._f_external_list) multi_array_copyto( self._constraint_lambda_slices, self._constraint_lambda_list) else: - self.state = self.trajectory.get() + self._state = self.trajectory.get() - if self.mode == QuantityEvalMode.REFERENCE: # Copy body external forces from stacked buffer to force vector - has_forces = self.state.f_external is not None + has_forces = self._state.f_external is not None if has_forces: - array_copyto(self._f_external_batch, self.state.f_external) + array_copyto(self._f_external_batch, self._state.f_external) multi_array_copyto(self._f_external_list, self._f_external_slices) @@ -1176,9 +1288,9 @@ def refresh(self) -> State: if self.update_kinematics: update_quantities( self.robot, - self.state.q, - self.state.v, - self.state.a, + self._state.q, + self._state.v, + self._state.a, self._f_external_vec if has_forces else None, update_dynamics=self._update_dynamics, update_centroidal=self._update_centroidal, @@ -1189,12 +1301,11 @@ def refresh(self) -> State: self.trajectory.use_theoretical_model)) # Restore lagrangian multipliers of the constraints if available - has_constraints = self.state.lambda_c is not None - if has_constraints: + if self._state.lambda_c is not None: array_copyto( - self._constraint_lambda_batch, self.state.lambda_c) + self._constraint_lambda_batch, self._state.lambda_c) multi_array_copyto(self._constraint_lambda_list, self._constraint_lambda_slices) # Return state - return self.state + return self._state diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py index c82060f4c..cf0778b92 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py @@ -4,24 +4,41 @@ radial_basis_function, AdditiveMixtureReward, MultiplicativeMixtureReward) -from .generic import (BaseTrackingReward, +from .generic import (SurviveReward, + TrackingQuantityReward, TrackingActuatedJointPositionsReward, - SurviveReward) + DriftTrackingQuantityTermination, + ShiftTrackingQuantityTermination, + MechanicalSafetyTermination, + MechanicalPowerConsumptionTermination, + ShiftTrackingMotorPositionsTermination) from .locomotion import (TrackingBaseHeightReward, TrackingBaseOdometryVelocityReward, TrackingCapturePointReward, TrackingFootPositionsReward, TrackingFootOrientationsReward, TrackingFootForceDistributionReward, + DriftTrackingBaseOdometryPositionTermination, + DriftTrackingBaseOdometryOrientationTermination, + ShiftTrackingFootOdometryPositionsTermination, + ShiftTrackingFootOdometryOrientationsTermination, MinimizeAngularMomentumReward, - MinimizeFrictionReward) + MinimizeFrictionReward, + BaseRollPitchTermination, + FallingTermination, + FootCollisionTermination, + FlyingTermination, + ImpactForceTermination) __all__ = [ "CUTOFF_ESP", "radial_basis_function", "AdditiveMixtureReward", "MultiplicativeMixtureReward", - "BaseTrackingReward", + "SurviveReward", + "MinimizeFrictionReward", + "MinimizeAngularMomentumReward", + "TrackingQuantityReward", "TrackingActuatedJointPositionsReward", "TrackingBaseHeightReward", 
"TrackingBaseOdometryVelocityReward", @@ -29,7 +46,18 @@ "TrackingFootPositionsReward", "TrackingFootOrientationsReward", "TrackingFootForceDistributionReward", - "MinimizeAngularMomentumReward", - "MinimizeFrictionReward", - "SurviveReward" + "DriftTrackingQuantityTermination", + "DriftTrackingBaseOdometryPositionTermination", + "DriftTrackingBaseOdometryOrientationTermination", + "ShiftTrackingQuantityTermination", + "ShiftTrackingMotorPositionsTermination", + "ShiftTrackingFootOdometryPositionsTermination", + "ShiftTrackingFootOdometryOrientationsTermination", + "MechanicalSafetyTermination", + "MechanicalPowerConsumptionTermination", + "FlyingTermination", + "BaseRollPitchTermination", + "FallingTermination", + "FootCollisionTermination", + "ImpactForceTermination" ] diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py index 0de5a6b0d..7852688cd 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py @@ -4,12 +4,21 @@ """ from operator import sub from functools import partial -from typing import Optional, Callable, TypeVar +from dataclasses import dataclass +from typing import Optional, Callable, Tuple, TypeVar + +import numpy as np +import numba as nb +import pinocchio as pin from ..bases import ( - InfoType, QuantityCreator, InterfaceJiminyEnv, AbstractReward, - BaseQuantityReward, QuantityEvalMode) -from ..quantities import BinaryOpQuantity, ActuatedJointsPosition + InfoType, QuantityCreator, InterfaceJiminyEnv, InterfaceQuantity, + QuantityEvalMode, AbstractReward, QuantityReward, + AbstractTerminationCondition, QuantityTermination) +from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar +from ..quantities import ( + EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity, + MultiActuatedJointKinematic, AverageMechanicalPowerConsumption) from .mixin import radial_basis_function @@ -46,7 +55,7 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]: return 1.0 -class BaseTrackingReward(BaseQuantityReward): +class TrackingQuantityReward(QuantityReward): """Base class from which to derive reward defined as a difference between the current and reference value of a given quantity. @@ -78,17 +87,17 @@ def __init__(self, :param quantity_creator: Any callable taking a quantity evaluation mode as input argument and return a tuple gathering the class of the underlying quantity to use as - reward after some post-processing, plus all - its constructor keyword-arguments except - environment 'env', parent 'parent. + reward after some post-processing, plus any + keyword-arguments of its constructor except + 'env' and 'parent'. :param cutoff: Cutoff threshold for the RBF kernel transform. :param op: Any callable taking the true and reference values of the quantity as input argument and returning the difference between them, considering the algebra defined by their Lie Group. The basic subtraction operator `operator.sub` is - appropriate for Euclidean. + appropriate for the Euclidean space. Optional: `operator.sub` by default. - :param order: Order of Lp-Norm that will be used as distance metric. + :param order: Order of L^p-norm that will be used as distance metric. Optional: 2 by default. 
""" # Backup some user argument(s) @@ -107,12 +116,12 @@ def __init__(self, is_terminal=False) -class TrackingActuatedJointPositionsReward(BaseTrackingReward): +class TrackingActuatedJointPositionsReward(TrackingQuantityReward): """Reward the agent for tracking the position of all the actuated joints of the robot wrt some reference trajectory. .. seealso:: - See `BaseTrackingReward` documentation for technical details. + See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -128,5 +137,529 @@ def __init__(self, super().__init__( env, "reward_actuated_joint_positions", - lambda mode: (ActuatedJointsPosition, dict(mode=mode)), + lambda mode: (MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.POSITION, + is_motor_side=False, + mode=mode)), cutoff) + + +class DriftTrackingQuantityTermination(QuantityTermination): + """Base class to derive termination condition from the difference between + the current and reference drift of a given quantity. + + The drift is defined as the difference between the most recent and oldest + values of a time series. In this case, a variable-length horizon bounded by + 'max_stack' is considered. + + All elements must be within bounds for at least one time step in the fixed + horizon. If so, then the episode continues, otherwise it is either + truncated or terminated according to 'is_truncation' constructor argument. + This only applies after the end of a grace period. Before that, the episode + continues no matter what. + """ + def __init__(self, + env: InterfaceJiminyEnv, + name: str, + quantity_creator: Callable[ + [QuantityEvalMode], QuantityCreator[ArrayOrScalar]], + low: Optional[ArrayLikeOrScalar], + high: Optional[ArrayLikeOrScalar], + horizon: float, + grace_period: float = 0.0, + *, + op: Callable[ + [ArrayOrScalar, ArrayOrScalar], ArrayOrScalar] = sub, + post_fn: Optional[Callable[ + [ArrayOrScalar], ArrayOrScalar]] = None, + is_truncation: bool = False, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param name: Desired name of the termination condition. This name will + be used as key for storing the current episode state from + the perspective of this specific condition in 'info', and + to add the underlying quantity to the set of already + managed quantities by the environment. As a result, it + must be unique otherwise an exception will be raised. + :param quantity_creator: Any callable taking a quantity evaluation mode + as input argument and return a tuple gathering + the class of the underlying quantity to use as + reward after some post-processing, plus any + keyword-arguments of its constructor except + 'env' and 'parent'. + :param low: Lower bound below which termination is triggered. + :param high: Upper bound above which termination is triggered. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the drift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param op: Any callable taking the true and reference values of the + quantity as input argument and returning the difference + between them, considering the algebra defined by their Lie + Group. The basic subtraction operator `operator.sub` is + appropriate for Euclidean space. + Optional: `operator.sub` by default. 
+ :param post_fn: Optional callable taking the difference between the + true and reference drifts of the quantity as input + argument and returning some post-processed value to + which bound checking will be applied. None to skip + post-processing entirely. + Optional: None by default. + :param is_truncation: Whether the episode should be considered + terminated or truncated whenever the termination + condition is triggered. + Optional: False by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + # pylint: disable=unnecessary-lambda-assignment + + # Convert horizon into stack length, assuming constant env timestep + max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + + # Backup user argument(s) + self.max_stack = max_stack + self.op = op + self.post_fn = post_fn + + # Define drift of quantity + stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731 + quantity=quantity_creator(mode), + max_stack=max_stack)) + delta_creator = lambda mode: (BinaryOpQuantity, dict( # noqa: E731 + quantity_left=(UnaryOpQuantity, dict( + quantity=stack_creator(mode), + op=lambda stack: stack[-1])), + quantity_right=(UnaryOpQuantity, dict( + quantity=stack_creator(mode), + op=lambda stack: stack[0])), + op=op)) + + # Add drift quantity to the set of quantities managed by environment + drift_tracking_quantity = (BinaryOpQuantity, dict( + quantity_left=delta_creator(QuantityEvalMode.TRUE), + quantity_right=delta_creator(QuantityEvalMode.REFERENCE), + op=self._compute_drift_error)) + + # Call base implementation + super().__init__(env, + name, + drift_tracking_quantity, # type: ignore[arg-type] + low, + high, + grace_period, + is_truncation=is_truncation, + is_training_only=is_training_only) + + def _compute_drift_error(self, + left: np.ndarray, + right: np.ndarray) -> ArrayOrScalar: + """Compute the difference between the true and reference drift over + a given horizon, then apply some post-processing on it if requested. + + :param left: True value of the drift as an N-dimensional array. + :param right: Reference value of the drift as an N-dimensional array. + """ + diff = left - right + if self.post_fn is not None: + return self.post_fn(diff) + return diff + + +class ShiftTrackingQuantityTermination(QuantityTermination[np.ndarray]): + """Base class to derive termination condition from the shift between the + current and reference values of a given quantity. + + The shift is defined as the minimum time-aligned distance (L^2-norm of the + difference) between two multivariate time series. In this case, a + variable-length horizon bounded by 'max_stack' is considered. + + All elements must be within bounds for at least one time step in the fixed + horizon. If so, then the episode continues, otherwise it is either + truncated or terminated according to the 'is_truncation' constructor + argument. + This only applies after the end of a grace period. Before that, the episode + continues no matter what. + """ + def __init__(self, + env: InterfaceJiminyEnv, + name: str, + quantity_creator: Callable[ + [QuantityEvalMode], QuantityCreator[ArrayOrScalar]], + thr: float, + horizon: float, + grace_period: float = 0.0, + *, + op: Callable[[np.ndarray, np.ndarray], np.ndarray] = sub, + is_truncation: bool = False, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param name: Desired name of the termination condition.
This name will + be used as key for storing the current episode state from + the perspective of this specific condition in 'info', and + to add the underlying quantity to the set of already + managed quantities by the environment. As a result, it + must be unique, otherwise an exception will be raised. + :param quantity_creator: Any callable taking a quantity evaluation mode + as input argument and returning a tuple gathering + the class of the underlying quantity to monitor + after some post-processing, plus any + keyword-arguments of its constructor except + 'env' and 'parent'. + :param thr: Termination is triggered if the shift exceeds this + threshold. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the shift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param op: Any callable taking the true and reference stacked values of + the quantity as input argument and returning the difference + between them, considering the algebra defined by their Lie + Group. True and reference values are stacked in contiguous + N-dimensional arrays along the last axis, namely the last + dimension gathers individual timesteps. For instance, the + common subtraction operator `operator.sub` is appropriate + for Euclidean space. + Optional: `operator.sub` by default. + :param is_truncation: Whether the episode should be considered + terminated or truncated whenever the termination + condition is triggered. + Optional: False by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + # pylint: disable=unnecessary-lambda-assignment + + # Convert horizon into stack length, assuming constant env timestep + max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + + # Backup user argument(s) + self.max_stack = max_stack + self.op = op + + # Jit-able method computing minimum distance between two time series + @nb.jit(nopython=True, cache=True) + def min_norm(values: np.ndarray) -> float: + """Compute the minimum Euclidean norm over all timestamps of a + multivariate time series. + + :param values: Time series as an N-dimensional array whose last + dimension corresponds to individual timestamps over + a finite horizon. The value at each timestamp will + be regarded as a 1D vector for computing their + Euclidean norm.
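The reduction performed by `min_norm` can be checked in isolation with plain numpy; a tiny self-contained sketch:

import numpy as np

# Two 3D vectors tracked over 4 timestamps: time is the last axis, as in
# 'min_norm' above. The norm vanishes at the second timestamp, so the
# minimum over time is expected to be 0.
values = np.zeros((3, 2, 4))
values[0, 0, :] = (3.0, 0.0, 4.0, 5.0)
num_times = values.shape[-1]
flat = np.square(values).reshape((-1, num_times))
assert np.isclose(np.sqrt(np.min(np.sum(flat, axis=0))), 0.0)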
+ """ + num_times = values.shape[-1] + values_squared_flat = np.square(values).reshape((-1, num_times)) + return np.sqrt(np.min(np.sum(values_squared_flat, axis=0))) + + self._min_norm = min_norm + + # Define drift of quantity + stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731 + quantity=quantity_creator(mode), + max_stack=max_stack, + mode='slice', + as_array=True)) + + # Add drift quantity to the set of quantities managed by environment + shift_tracking_quantity = (BinaryOpQuantity, dict( + quantity_left=stack_creator(QuantityEvalMode.TRUE), + quantity_right=stack_creator(QuantityEvalMode.REFERENCE), + op=self._compute_min_distance)) + + # Call base implementation + super().__init__(env, + name, + shift_tracking_quantity, # type: ignore[arg-type] + None, + np.array(thr), + grace_period, + is_truncation=is_truncation, + is_training_only=is_training_only) + + def _compute_min_distance(self, + left: np.ndarray, + right: np.ndarray) -> float: + """Compute the minimum time-aligned Euclidean distance between two + multivariate time series kept in sync. + + Internally, the time-aligned difference between the two time series + will first be computed according to the user-specified binary operator + 'op'. The classical Euclidean norm of the difference is then computed + over all timestamps individually and the minimum value is returned. + + :param left: Time series as a N-dimensional array whose first dimension + corresponds to individual timestamps over a finite + horizon. The value at each timestamp will be regarded as a + 1D vector for computing their Euclidean norm. It will be + passed as left-hand side of the binary operator 'op'. + :param right: Time series as a N-dimensional array with the exact same + shape as 'left'. See 'left' for details. It will be + passed as right-hand side of the binary operator 'op'. + """ + return self._min_norm(self.op(left, right)) + + +@dataclass(unsafe_hash=True) +class _MultiActuatedJointBoundDistance( + InterfaceQuantity[Tuple[np.ndarray, np.ndarray]]): + """Distance of the actuated joints from their respective lower and upper + mechanical stops. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity]) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param mode: Desired mode of evaluation for this quantity. 
+ """ + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + position=(MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.POSITION, + is_motor_side=False, + mode=QuantityEvalMode.TRUE))), + auto_refresh=False) + + # Lower and upper bounds of the actuated joints + self.position_low, self.position_high = np.array([]), np.array([]) + + def initialize(self) -> None: + # Call base implementation + super().initialize() + + # Initialize the actuated joint position indices + self.position.initialize() + position_indices = self.position.kinematic_indices + + # Refresh mechanical joint position indices + position_limit_low = self.env.robot.pinocchio_model.lowerPositionLimit + self.position_low = position_limit_low[position_indices] + position_limit_high = self.env.robot.pinocchio_model.upperPositionLimit + self.position_high = position_limit_high[position_indices] + + def refresh(self) -> Tuple[np.ndarray, np.ndarray]: + position = self.position.get() + return (position - self.position_low, self.position_high - position) + + +class MechanicalSafetyTermination(AbstractTerminationCondition): + """Discouraging the agent from hitting the mechanical stops by immediately + terminating the episode if the articulated joints approach them at + excessive speed. + + Hitting the lower and upper mechanical stops is inconvenient but forbidding + it completely is not desirable as it induces safety margins that constrain + the problem too strictly. This is particularly true when the maximum motor + torque becomes increasingly limited and PD controllers are being used for + low-level motor control, which turns out to be the case in most instances. + Overall, such an hard constraint would impede performance while completing + the task successfully remains the highest priority. Still, the impact + velocity must be restricted to prevent destructive damage. It is + recommended to estimate an acceptable thresholdfrom real experimental data. + """ + def __init__(self, + env: InterfaceJiminyEnv, + position_margin: float, + velocity_max: float, + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param position_margin: Distance of actuated joints from their + respective mechanical bounds below which + their speed is being watched. + :param velocity_max: Maximum velocity above which further approaching + the mechanical stops triggers termination when + watched for being close from them. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. 
+ """ + # Backup user argument(s) + self.position_margin = position_margin + self.velocity_max = velocity_max + + # Call base implementation + super().__init__( + env, + "termination_mechanical_safety", + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + # Add quantity to the set of quantities managed by the environment + self.env.quantities["_".join((self.name, "position_delta"))] = ( + _MultiActuatedJointBoundDistance, {}) + self.env.quantities["_".join((self.name, "velocity"))] = ( + MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.VELOCITY, + is_motor_side=False)) + + # Keep track of the underlying quantities + registry = self.env.quantities.registry + self.position_delta = registry["_".join((self.name, "position_delta"))] + self.velocity = registry["_".join((self.name, "velocity"))] + + def __del__(self) -> None: + try: + for field in ("position_delta", "velocity"): + if hasattr(self, field): + del self.env.quantities["_".join((self.name, field))] + except Exception: # pylint: disable=broad-except + # This method must not fail under any circumstances + pass + + def compute(self, info: InfoType) -> bool: + """Evaluate the termination condition. + + The underlying quantity is first evaluated. The episode continues if + its value is within bounds, otherwise the episode is either truncated + or terminated according to 'is_truncation'. + + .. warning:: + This method is not meant to be overloaded. + """ + # Evaluate the quantity + position_delta_low, position_delta_high = self.position_delta.get() + velocity = self.velocity.get() + + # Check if the robot is going to hit the mechanical stops at high speed + is_done = any( + (position_delta_low < self.position_margin) & + (velocity < - self.velocity_max)) + is_done |= any( + (position_delta_high < self.position_margin) & + (velocity > self.velocity_max)) + return is_done + + +class MechanicalPowerConsumptionTermination(QuantityTermination): + """Terminate the episode immediately if the average mechanical power + consumption is too high. + + High power consumption is undesirable as it means that the motion is + suboptimal and probably unnatural and fragile. Moreover, it helps to + accommodate hardware capability to avoid motor overheating while increasing + battery autonomy and lifespan. Finally, it may be necessary to deal with + some hardware limitations on max power drain. + """ + def __init__( + self, + env: InterfaceJiminyEnv, + max_power: float, + horizon: float, + generator_mode: EnergyGenerationMode = EnergyGenerationMode.CHARGE, + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param max_power: Maximum average mechanical power consumption applied + on any of the contact points or collision bodies + above which termination is triggered. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the average. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. 
+ """ + # Backup user argument(s) + self.max_power = max_power + self.horizon = horizon + self.generator_mode = generator_mode + + # Call base implementation + super().__init__( + env, + "termination_power_consumption", + (AverageMechanicalPowerConsumption, dict( # type: ignore[arg-type] + horizon=self.horizon, + generator_mode=self.generator_mode)), + None, + self.max_power, + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + +class ShiftTrackingMotorPositionsTermination(ShiftTrackingQuantityTermination): + """Terminate the episode if the selected reference trajectory is not + tracked with expected accuracy regarding the actuated joint positions, + whatever the timestep being considered over some fixed-size sliding window. + + The robot must track the reference if there is no hazard, only applying + minor corrections to keep balance. Rewarding the agent for doing so is + not effective as favoring robustness remains more profitable. Indeed, it + would anticipate disturbances, lowering its current reward to maximize the + future return, primarily averting termination. Limiting the shift over a + given horizon allows for large deviations to handle strong pushes. + Moreover, assuming that the agent is not able to keep track of the time + flow, which means that only the observation at the current step is provided + to the agent and o stateful network architecture such as LSTM is being + used, restricting the shift also urges to do what it takes to get back to + normal as soon as possible for fear of triggering termination, as it may + happen any time the deviation is above the maximum acceptable shift, + irrespective of its scale. + """ + def __init__(self, + env: InterfaceJiminyEnv, + thr: float, + horizon: float, + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param thr: Maximum shift above which termination is triggered. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the shift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + # Call base implementation + super().__init__( + env, + "termination_tracking_motor_positions", + lambda mode: (MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.POSITION, + is_motor_side=False, + mode=mode)), + thr, + horizon, + grace_period, + is_truncation=False, + is_training_only=is_training_only) diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py index eaf10fde2..7e49596e4 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py @@ -1,30 +1,38 @@ """Rewards mainly relevant for locomotion tasks on floating-base robots. 
""" from functools import partial -from typing import Union, Sequence, Literal, Callable, cast +from dataclasses import dataclass +from typing import Optional, Union, Sequence, Literal, Callable, cast import numpy as np +import numba as nb + +import jiminy_py.core as jiminy import pinocchio as pin from ..bases import ( - InterfaceJiminyEnv, StateQuantity, QuantityEvalMode, BaseQuantityReward) + InterfaceJiminyEnv, StateQuantity, InterfaceQuantity, QuantityEvalMode, + QuantityReward) from ..quantities import ( - MaskedQuantity, UnaryOpQuantity, AverageBaseOdometryVelocity, CapturePoint, - MultiFootRelativeXYZQuat, MultiContactRelativeForceTangential, - MultiFootRelativeForceVertical, AverageBaseMomentum) -from ..quantities.locomotion import sanitize_foot_frame_names -from ..utils import quat_difference - -from .generic import BaseTrackingReward + OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation, + BaseRelativeHeight, BaseOdometryPose, BaseOdometryAverageVelocity, + CapturePoint, MultiFramePosition, MultiFootRelativeXYZQuat, + MultiContactNormalizedSpatialForce, MultiFootNormalizedForceVertical, + MultiFootCollisionDetection, AverageBaseMomentum) +from ..utils import quat_difference, quat_to_yaw + +from .generic import ( + ArrayLikeOrScalar, TrackingQuantityReward, QuantityTermination, + DriftTrackingQuantityTermination, ShiftTrackingQuantityTermination) from .mixin import radial_basis_function -class TrackingBaseHeightReward(BaseTrackingReward): +class TrackingBaseHeightReward(TrackingQuantityReward): """Reward the agent for tracking the height of the floating base of the robot wrt some reference trajectory. .. seealso:: - See `BaseTrackingReward` documentation for technical details. + See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -33,27 +41,26 @@ def __init__(self, :param env: Base or wrapped jiminy environment. :param cutoff: Cutoff threshold for the RBF kernel transform. """ - # Backup some user argument(s) - self.cutoff = cutoff - - # Call base implementation super().__init__( env, "reward_tracking_base_height", lambda mode: (MaskedQuantity, dict( quantity=(UnaryOpQuantity, dict( - quantity=(StateQuantity, dict(mode=mode)), + quantity=(StateQuantity, dict( + update_kinematics=False, + mode=mode)), op=lambda state: state.q)), + axis=0, keys=(2,))), cutoff) -class TrackingBaseOdometryVelocityReward(BaseTrackingReward): +class TrackingBaseOdometryVelocityReward(TrackingQuantityReward): """Reward the agent for tracking the odometry velocity wrt some reference trajectory. .. seealso:: - See `BaseTrackingReward` documentation for technical details. + See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -62,23 +69,19 @@ def __init__(self, :param env: Base or wrapped jiminy environment. :param cutoff: Cutoff threshold for the RBF kernel transform. """ - # Backup some user argument(s) - self.cutoff = cutoff - - # Call base implementation super().__init__( env, "reward_tracking_odometry_velocity", - lambda mode: (AverageBaseOdometryVelocity, dict(mode=mode)), + lambda mode: (BaseOdometryAverageVelocity, dict(mode=mode)), cutoff) -class TrackingCapturePointReward(BaseTrackingReward): +class TrackingCapturePointReward(TrackingQuantityReward): """Reward the agent for tracking the capture point wrt some reference trajectory. .. seealso:: - See `BaseTrackingReward` documentation for technical details. 
+ See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -87,25 +90,21 @@ def __init__(self, :param env: Base or wrapped jiminy environment. :param cutoff: Cutoff threshold for the RBF kernel transform. """ - # Backup some user argument(s) - self.cutoff = cutoff - - # Call base implementation super().__init__( env, "reward_tracking_capture_point", lambda mode: (CapturePoint, dict( - reference_frame=pin.LOCAL, + reference_frame=pin.ReferenceFrame.LOCAL, mode=mode)), cutoff) -class TrackingFootPositionsReward(BaseTrackingReward): +class TrackingFootPositionsReward(TrackingQuantityReward): """Reward the agent for tracking the relative position of the feet wrt each other. .. seealso:: - See `BaseTrackingReward` documentation for technical details. + See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -121,16 +120,6 @@ def __init__(self, set of contact and force sensors of the robot. Optional: 'auto' by default. """ - # Backup some user argument(s) - self.cutoff = cutoff - - # Sanitize frame names corresponding to the feet of the robot - frame_names = tuple(sanitize_foot_frame_names(env, frame_names)) - - # Buffer storing the difference before current and reference poses - self._spatial_velocities = np.zeros((6, len(frame_names))) - - # Call base implementation super().__init__( env, "reward_tracking_foot_positions", @@ -138,16 +127,17 @@ def __init__(self, quantity=(MultiFootRelativeXYZQuat, dict( frame_names=frame_names, mode=mode)), + axis=0, keys=(0, 1, 2))), cutoff) -class TrackingFootOrientationsReward(BaseTrackingReward): +class TrackingFootOrientationsReward(TrackingQuantityReward): """Reward the agent for tracking the relative orientation of the feet wrt each other. .. seealso:: - See `BaseTrackingReward` documentation for technical details. + See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -163,13 +153,6 @@ def __init__(self, set of contact and force sensors of the robot. Optional: 'auto' by default. """ - # Backup some user argument(s) - self.cutoff = cutoff - - # Sanitize frame names corresponding to the feet of the robot - frame_names = tuple(sanitize_foot_frame_names(env, frame_names)) - - # Call base implementation super().__init__( env, "reward_tracking_foot_orientations", @@ -184,7 +167,7 @@ def __init__(self, [np.ndarray, np.ndarray], np.ndarray], quat_difference)) -class TrackingFootForceDistributionReward(BaseTrackingReward): +class TrackingFootForceDistributionReward(TrackingQuantityReward): """Reward the agent for tracking the relative vertical force in world frame applied on each foot. @@ -199,7 +182,7 @@ class TrackingFootForceDistributionReward(BaseTrackingReward): the flying phase of running. .. seealso:: - See `BaseTrackingReward` documentation for technical details. + See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -215,28 +198,21 @@ def __init__(self, set of contact and force sensors of the robot. Optional: 'auto' by default. 
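These tracking rewards are typically combined into one composite objective; a hedged sketch using `AdditiveMixtureReward` from this package, whose exact constructor signature (`components`, `weights`) is assumed here for illustration only:

reward = AdditiveMixtureReward(
    env,
    "reward_total",
    components=(TrackingFootPositionsReward(env, cutoff=0.2),
                TrackingFootOrientationsReward(env, cutoff=0.5)),
    weights=(0.6, 0.4))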
""" - # Backup some user argument(s) - self.cutoff = cutoff - - # Sanitize frame names corresponding to the feet of the robot - frame_names = tuple(sanitize_foot_frame_names(env, frame_names)) - - # Call base implementation super().__init__( env, "reward_tracking_foot_force_distribution", - lambda mode: (MultiFootRelativeForceVertical, dict( + lambda mode: (MultiFootNormalizedForceVertical, dict( frame_names=frame_names, mode=mode)), cutoff) -class MinimizeAngularMomentumReward(BaseQuantityReward): +class MinimizeAngularMomentumReward(QuantityReward): """Reward the agent for minimizing the angular momentum in world plane. The angular momentum along x- and y-axes in local odometry frame is transform in a normalized reward to maximize by applying RBF kernel on the - error. See `BaseTrackingReward` documentation for technical details. + error. See `TrackingQuantityReward` documentation for technical details. """ def __init__(self, env: InterfaceJiminyEnv, @@ -258,16 +234,16 @@ def __init__(self, is_terminal=False) -class MinimizeFrictionReward(BaseQuantityReward): +class MinimizeFrictionReward(QuantityReward): """Reward the agent for minimizing the tangential forces at all the contact points and collision bodies, and to avoid jerky intermittent contact state. - The L2-norm is used to aggregate all the local tangential forces. While the - L1-norm would be more natural in this specific cases, using the L2-norm is - preferable as it promotes space-time regularity, ie balancing the force - distribution evenly between all the candidate contact points and avoiding - jerky contact forces over time (high-frequency vibrations), phenomena to - which the L1-norm is completely insensitive. + The L^2-norm is used to aggregate all the local tangential forces. While + the L^1-norm would be more natural in this specific cases, using the L-2 + norm is preferable as it promotes space-time regularity, ie balancing the + force distribution evenly between all the candidate contact points and + avoiding jerky contact forces over time (high-frequency vibrations), + phenomena to which the L^1-norm is completely insensitive. """ def __init__(self, env: InterfaceJiminyEnv, @@ -283,7 +259,540 @@ def __init__(self, super().__init__( env, "reward_friction", - (MultiContactRelativeForceTangential, dict()), + (MaskedQuantity, dict( + quantity=(MultiContactNormalizedSpatialForce, dict()), + axis=0, + keys=(0, 1))), partial(radial_basis_function, cutoff=self.cutoff, order=2), is_normalized=True, is_terminal=False) + + +class BaseRollPitchTermination(QuantityTermination): + """Encourages the agent to keep the floating base straight, ie its torso in + case of a humanoid robot, by prohibiting excessive roll and pitch angles. + """ + def __init__(self, + env: InterfaceJiminyEnv, + low: Optional[ArrayLikeOrScalar], + high: Optional[ArrayLikeOrScalar], + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param low: Lower bound below which termination is triggered. + :param high: Upper bound above which termination is triggered. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. 
+ """ + super().__init__( + env, + "termination_base_roll_pitch", + (MaskedQuantity, dict( # type: ignore[arg-type] + quantity=(FrameOrientation, dict( + frame_name="root_joint", + type=OrientationType.EULER)), + axis=0, + keys=(0, 1))), + low, + high, + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + +class FallingTermination(QuantityTermination): + """Terminate the episode immediately if the floating base of the robot + gets too close from the ground. + + It is assumed that the state is no longer recoverable when its condition + is triggered. As such, the episode is terminated on the spot as the + situation is hopeless. Generally speaking, aborting an epsiode in + anticipation of catastrophic failure is beneficial. Assuming the condition + is on point, doing this improves the signal to noice ratio when estimating + the gradient by avoiding cluterring the training batches with irrelevant + information. + """ + def __init__(self, + env: InterfaceJiminyEnv, + min_base_height: float, + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param min_base_height: Minimum height of the floating base of the + robot below which termination is triggered. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + super().__init__( + env, + "termination_base_height", + (BaseRelativeHeight, {}), # type: ignore[arg-type] + min_base_height, + None, + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + +class FootCollisionTermination(QuantityTermination): + """Terminate the episode immediately if some of the feet of the robot are + getting too close from each other. + + Self-collision must be avoided at all cost, as it can damage the hardware. + Considering this condition as a dramatically failure urges the agent to do + his best in this matter, to the point of becoming risk averse. + """ + def __init__(self, + env: InterfaceJiminyEnv, + security_margin: float = 0.0, + grace_period: float = 0.0, + frame_names: Union[Sequence[str], Literal['auto']] = 'auto', + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param security_margin: + Minimum signed distance below which termination is triggered. This + can be interpreted as inflating or deflating the geometry objects + by the safety margin depending on whether it is positive or + negative. See `MultiFootCollisionDetection` for details. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param frame_names: Name of the frames corresponding to the feet of the + robot. 'auto' to automatically detect them from the + set of contact and force sensors of the robot. + Optional: 'auto' by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. 
+ """ + super().__init__( + env, + "termination_foot_collision", + (MultiFootCollisionDetection, dict( # type: ignore[arg-type] + frame_names=frame_names, + security_margin=security_margin)), + False, + False, + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + +@dataclass(unsafe_hash=True) +class _MultiContactMinGroundDistance(InterfaceQuantity[float]): + """Minimum distance from the ground profile among all the contact points. + + .. note:: + Internally, it does not compute the exact shortest distance from the + ground profile because it would be computionally too demanding for now. + As a surrogate, it relies on a first order approximation assuming zero + local curvature around all the contact points individually. + + .. warning:: + The set of contact points must not change over episodes. In addition, + collision bodies are not supported for now. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity]) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + """ + # Get the name of all the contact points + contact_frame_names = env.robot.contact_frame_names + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + positions=(MultiFramePosition, dict( + frame_names=contact_frame_names, + mode=QuantityEvalMode.TRUE + ))), + auto_refresh=False) + + # Jit-able method computing the minimum first-order depth + @nb.jit(nopython=True, cache=True, fastmath=True) + def min_depth(positions: np.ndarray, + heights: np.ndarray, + normals: np.ndarray) -> float: + """Approximate minimum distance from the ground profile among a set + of the query points. + + Internally, it uses a first order approximation assuming zero local + curvature around each query point. + + :param positions: Position of all the query points from which to + compute from the ground profile, as a 2D array + whose first dimension gathers the 3 position + coordinates (X, Y, Z) while the second correponds + to the N individual query points. + :param heights: Vertical height wrt the ground profile of the N + individual query points in world frame as 1D array. + :param normals: Normal of the ground profile for the projection in + world plane of all the query points, as a 2D array + whose first dimension gathers the 3 position + coordinates (X, Y, Z) while the second correponds + to the N individual query points. + """ + return np.min((positions[2] - heights) * normals[2]) + + self._min_depth = min_depth + + # Reference to the heightmap function for the ongoing epsiode + self._heightmap = jiminy.HeightmapFunction(lambda: None) + + # Allocate memory for the height and normal of all the contact points + self._heights = np.zeros((len(contact_frame_names),)) + self._normals = np.zeros((3, len(contact_frame_names)), order="F") + + def initialize(self) -> None: + # Call base implementation + super().initialize() + + # Refresh the heighmap function + engine_options = self.env.unwrapped.engine.get_options() + self._heightmap = engine_options["world"]["groundProfile"] + + def refresh(self) -> float: + # Query the height and normal to the ground profile for the position in + # world plane of all the contact points. 
+
+        # Make sure the ground normal is normalized
+        # self._normals /= np.linalg.norm(self._normals, axis=0)
+
+        # First-order distance estimation assuming no curvature
+        return self._min_depth(positions, self._heights, self._normals)
+
+
+class FlyingTermination(QuantityTermination):
+    """Discourage the agent from jumping by terminating the episode
+    immediately if the robot is flying too high above the ground.
+
+    This kind of behavior is usually undesirable because it may frighten
+    people nearby, damage the hardware, be difficult to predict, and be hardly
+    repeatable. Moreover, such dynamic motions tend to transfer poorly to
+    reality because they exacerbate the simulation-to-reality gap.
+    """
+    def __init__(self,
+                 env: InterfaceJiminyEnv,
+                 max_height: float,
+                 grace_period: float = 0.0,
+                 *,
+                 is_training_only: bool = False) -> None:
+        """
+        :param env: Base or wrapped jiminy environment.
+        :param max_height: Maximum height of the lowest contact points wrt
+                           the ground above which termination is triggered.
+        :param grace_period: Grace period effective only at the very beginning
+                             of the episode, during which the latter is bound
+                             to continue whatever happens.
+                             Optional: 0.0 by default.
+        :param is_training_only: Whether the termination condition should be
+                                 completely bypassed if the environment is in
+                                 evaluation mode.
+                                 Optional: False by default.
+        """
+        super().__init__(
+            env,
+            "termination_flying",
+            (_MultiContactMinGroundDistance, {}),  # type: ignore[arg-type]
+            None,
+            max_height,
+            grace_period,
+            is_truncation=False,
+            is_training_only=is_training_only)
+
+
+class ImpactForceTermination(QuantityTermination):
+    """Terminate the episode immediately in case of violent impact on the
+    ground.
+
+    Similar to jumping, this kind of behavior is usually undesirable. See
+    `FlyingTermination` documentation for details.
+    """
+    def __init__(self,
+                 env: InterfaceJiminyEnv,
+                 max_force_rel: float,
+                 grace_period: float = 0.0,
+                 *,
+                 is_training_only: bool = False) -> None:
+        """
+        :param env: Base or wrapped jiminy environment.
+        :param max_force_rel: Maximum vertical force applied to any of the
+                              contact points or collision bodies above which
+                              termination is triggered.
+        :param grace_period: Grace period effective only at the very beginning
+                             of the episode, during which the latter is bound
+                             to continue whatever happens.
+                             Optional: 0.0 by default.
+        :param is_training_only: Whether the termination condition should be
+                                 completely bypassed if the environment is in
+                                 evaluation mode.
+                                 Optional: False by default.
+        """
+        super().__init__(
+            env,
+            "termination_impact_force",
+            (MaskedQuantity, dict(  # type: ignore[arg-type]
+                quantity=(MultiContactNormalizedSpatialForce, dict()),
+                axis=0,
+                keys=(2,))),
+            None,
+            max_force_rel,
+            grace_period,
+            is_truncation=False,
+            is_training_only=is_training_only)
+
+
+class DriftTrackingBaseOdometryPositionTermination(
+        DriftTrackingQuantityTermination):
+    """Terminate the episode if the current base odometry position is
+    drifting too much wrt some reference trajectory that is being tracked.
+
+    It is generally important to make sure that the robot is not deviating
+    too much from some reference trajectory. It sounds appealing to make sure
+    that the absolute error between the current and reference trajectory is
+    bounded at all times. However, such a condition is very restrictive,
+    especially for robots dealing with external disturbances or evolving on
+    an uneven terrain. Moreover, when it comes to infinite-horizon
+    trajectories in particular, eg periodic motions, avoiding drifting away
+    over time involves being able to sense the absolute position of the robot
+    in world frame via exteroceptive navigation sensors such as depth cameras
+    or LIDARs. Such advanced sensors may not be available, thereby putting the
+    objective out of reach. Still, in the case of legged locomotion, what
+    really matters is tracking accurately a nominal limit cycle as long as
+    doing so does not compromise local stability. If it does, then the agent
+    is expected to make every effort to recover balance as fast as possible
+    before going back to the nominal limit cycle, without trying to catch up
+    with the ensuing drift, since the exact absolute odometry pose in world
+    frame is of little interest. See `BaseOdometryPose` and
+    `DriftTrackingQuantityTermination` documentations for details.
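+
+    For instance, with `horizon=2.0` and `max_position_err=0.3` (arbitrary
+    values, for illustration only), the episode ends as soon as the odometry
+    position drift accumulated over the past 2 seconds deviates from that of
+    the reference trajectory by more than 0.3 m.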
+    """
+    def __init__(self,
+                 env: InterfaceJiminyEnv,
+                 max_position_err: float,
+                 horizon: float,
+                 grace_period: float = 0.0,
+                 *,
+                 is_training_only: bool = False) -> None:
+        """
+        :param env: Base or wrapped jiminy environment.
+        :param max_position_err:
+            Maximum drift error in translation (X, Y) in world plane above
+            which termination is triggered.
+        :param horizon: Horizon over which values of the quantity will be
+                        stacked before computing the drift.
+        :param grace_period: Grace period effective only at the very beginning
+                             of the episode, during which the latter is bound
+                             to continue whatever happens.
+                             Optional: 0.0 by default.
+        :param is_training_only: Whether the termination condition should be
+                                 completely bypassed if the environment is in
+                                 evaluation mode.
+                                 Optional: False by default.
+        """
+        super().__init__(
+            env,
+            "termination_tracking_base_odom_position",
+            lambda mode: (  # type: ignore[arg-type, return-value]
+                MaskedQuantity, dict(
+                    quantity=(BaseOdometryPose, dict(
+                        mode=mode)),
+                    axis=0,
+                    keys=(0, 1))),
+            None,
+            max_position_err,
+            horizon,
+            grace_period,
+            post_fn=np.linalg.norm,
+            is_truncation=False,
+            is_training_only=is_training_only)
+
+
+class DriftTrackingBaseOdometryOrientationTermination(
+        DriftTrackingQuantityTermination):
+    """Terminate the episode if the current base odometry orientation is
+    drifting too much wrt some reference trajectory that is being tracked.
+
+    See `BaseOdometryPose` and `DriftTrackingBaseOdometryPositionTermination`
+    documentations for details.
+    """
+    def __init__(self,
+                 env: InterfaceJiminyEnv,
+                 max_orientation_err: float,
+                 horizon: float,
+                 grace_period: float = 0.0,
+                 *,
+                 is_training_only: bool = False) -> None:
+        """
+        :param env: Base or wrapped jiminy environment.
+        :param max_orientation_err:
+            Maximum drift error in orientation (yaw,) in world plane above
+            which termination is triggered.
+        :param horizon: Horizon over which values of the quantity will be
+                        stacked before computing the drift.
+        :param grace_period: Grace period effective only at the very beginning
+                             of the episode, during which the latter is bound
+                             to continue whatever happens.
+                             Optional: 0.0 by default.
+        :param is_training_only: Whether the termination condition should be
+                                 completely bypassed if the environment is in
+                                 evaluation mode.
+                                 Optional: False by default.
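+
+        Example, assuming a yaw drift budget of 0.5 rad over a 2.0 s horizon
+        (arbitrary values, for illustration only)::
+
+            termination = DriftTrackingBaseOdometryOrientationTermination(
+                env, max_orientation_err=0.5, horizon=2.0)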
+ """ + super().__init__( + env, + "termination_tracking_base_odom_orientation", + lambda mode: ( # type: ignore[arg-type, return-value] + MaskedQuantity, dict( + quantity=(BaseOdometryPose, dict( + mode=mode)), + axis=0, + keys=(2,))), + -max_orientation_err, + max_orientation_err, + horizon, + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + +class ShiftTrackingFootOdometryPositionsTermination( + ShiftTrackingQuantityTermination): + """Terminate the episode if the selected reference trajectory is not + tracked with expected accuracy regarding the relative foot odometry + positions, whatever the timestep being considered over some fixed-size + sliding window. + + See `MultiFootRelativeXYZQuat` and `ShiftTrackingMotorPositionsTermination` + documentation for details. + """ + def __init__(self, + env: InterfaceJiminyEnv, + max_position_err: float, + horizon: float, + grace_period: float = 0.0, + frame_names: Union[Sequence[str], Literal['auto']] = 'auto', + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param max_position_err: + Maximum drift error in translation (X, Y) in world plane above + which termination is triggered. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the shift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param frame_names: Name of the frames corresponding to the feet of the + robot. 'auto' to automatically detect them from the + set of contact and force sensors of the robot. + Optional: 'auto' by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + super().__init__( + env, + "termination_tracking_foot_odom_positions", + lambda mode: ( # type: ignore[arg-type, return-value] + MaskedQuantity, dict( + quantity=(MultiFootRelativeXYZQuat, dict( + frame_names=frame_names, + mode=mode)), + axis=0, + keys=(0, 1))), + max_position_err, + horizon, + grace_period, + is_truncation=False, + is_training_only=is_training_only) + + +class ShiftTrackingFootOdometryOrientationsTermination( + ShiftTrackingQuantityTermination): + """Terminate the episode if the selected reference trajectory is not + tracked with expected accuracy regarding the relative foot odometry + orientations, whatever the timestep being considered over some fixed-size + sliding window. + + See `MultiFootRelativeXYZQuat` and `ShiftTrackingMotorPositionsTermination` + documentation for details. + """ + def __init__(self, + env: InterfaceJiminyEnv, + max_orientation_err: float, + horizon: float, + grace_period: float = 0.0, + frame_names: Union[Sequence[str], Literal['auto']] = 'auto', + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + Maximum shift error in orientation (yaw,) in world plane above + which termination is triggered. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the shift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param frame_names: Name of the frames corresponding to the feet of the + robot. 
+                            'auto' to automatically detect them from the set
+                            of contact and force sensors of the robot.
+                            Optional: 'auto' by default.
+        :param is_training_only: Whether the termination condition should be
+                                 completely bypassed if the environment is in
+                                 evaluation mode.
+                                 Optional: False by default.
+        """
+        # Call base implementation
+        super().__init__(
+            env,
+            "termination_tracking_foot_odom_orientations",
+            lambda mode: (UnaryOpQuantity, dict(
+                quantity=(MaskedQuantity, dict(
+                    quantity=(MultiFootRelativeXYZQuat, dict(
+                        frame_names=frame_names,
+                        mode=mode)),
+                    axis=0,
+                    keys=(3, 4, 5, 6))),
+                op=quat_to_yaw)),
+            max_orientation_err,
+            horizon,
+            grace_period,
+            is_truncation=False,
+            is_training_only=is_training_only)
diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py
index 80bb0dab8..d3e5ea57b 100644
--- a/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py
+++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py
@@ -4,12 +4,13 @@
 """
 import math
 import logging
-from typing import Sequence, Optional, Union
+from functools import partial
+from typing import Sequence, Tuple, Optional, Union, Literal
 
 import numpy as np
 import numba as nb
 
-from ..bases import InterfaceJiminyEnv, AbstractReward, BaseMixtureReward
+from ..bases import InterfaceJiminyEnv, AbstractReward, MixtureReward
 
 
 # Reward value at cutoff threshold
@@ -35,7 +36,7 @@ def radial_basis_function(error: ArrayOrScalar,
 
     where :math:`dist(x, x_ref)` is some distance metric of the error between
    the observed (:math:`x`) and desired (:math:`x_ref`) values of a
-    multi-variate quantity. The L2-norm (Euclidean norm) was used when it was
+    multi-variate quantity. The L^2-norm (Euclidean norm) was used when it was
     first introduced as a non-linear kernel for Support Vector Machine (SVM)
     algorithm. Such restriction does not make sense in the context of reward
     normalization. The scaling parameter :math:`sigma` is derived from the
@@ -44,7 +45,7 @@ def radial_basis_function(error: ArrayOrScalar,
 
     :param error: Multi-variate error on some tangent space as a 1D array.
     :param cutoff: Cut-off threshold to consider.
-    :param order: Order of Lp-Norm that will be used as distance metric.
+    :param order: Order of L^p-norm that will be used as distance metric.
     """
     error_ = np.asarray(error)
     is_contiguous = error_.flags.f_contiguous or error_.flags.c_contiguous
@@ -65,42 +66,63 @@ def radial_basis_function(error: ArrayOrScalar,
     return math.pow(CUTOFF_ESP, squared_dist_rel)
 
 
-class AdditiveMixtureReward(BaseMixtureReward):
-    """Weighted sum of multiple independent reward components.
+class AdditiveMixtureReward(MixtureReward):
+    """Weighted L^p-norm of multiple independent reward components.
 
-    Aggregation of reward components using the addition operator is suitable
-    when improving the behavior for any of them without the others is equally
-    beneficial, and unbalanced performance for each reward component is
-    considered acceptable rather than detrimental. It especially makes sense
+    Aggregating the reward components using the L^p-norm progressively
+    transitions from promoting versatility for 0 < p < 1, to overall
+    competency for p = 1, and ultimately specialization for p > 1. In
+    particular, the L^1-norm is appropriate when improving the behavior for
+    any of them without the others is equally beneficial, and unbalanced
+    performance for each reward component is considered acceptable rather
+    than detrimental.
+    It especially makes sense for rewards that are not competing with each
+    other. For competing rewards, where improving one tends to impede the
+    others, the multiplicative operator is more appropriate. See
+    `MultiplicativeMixtureReward` documentation for details.
+
+    .. note::
+        Combining `AdditiveMixtureReward` for L^inf-norm with `SurviveReward`
+        ensures a minimum reward that the agent would obtain no matter what,
+        to encourage surviving as a last resort. This is usually useful to
+        bootstrap learning at the very beginning.
     """
 
     def __init__(self,
                  env: InterfaceJiminyEnv,
                  name: str,
                  components: Sequence[AbstractReward],
+                 order: Union[int, float, Literal['inf']] = 1,
                  weights: Optional[Sequence[float]] = None) -> None:
         """
         :param env: Base or wrapped jiminy environment.
         :param name: Desired name of the total reward.
         :param components: Sequence of reward components to aggregate.
-        :param weights: Sequence of weights associated with each reward
-                        components, with the same ordering as 'components'.
-                        Optional: 1.0 for all reward components by default.
+        :param order: Order of the L^p-norm used to aggregate the reward
+                      components.
+        :param weights: Optional sequence of weights associated with each
+                        reward component, with the same ordering as
+                        'components'.
+                        Optional: Identical weights preserving normalization
+                        by default, i.e.
+                        `(1.0 / len(components),) * len(components)`.
         """
         # Handling of default arguments
         if weights is None:
-            weights = (1.0,) * len(components)
+            weights = (1.0 / len(components),) * len(components)
+
+        # Make sure that the order is strictly positive
+        if not isinstance(order, str) and order <= 0.0:
+            raise ValueError("'order' must be strictly positive or 'inf'.")
 
         # Make sure that the weight sequence is consistent with the components
         if len(weights) != len(components):
             raise ValueError(
                 "Exactly one weight per reward component must be specified.")
 
+        # Filter out components whose weight is zero
+        weights, components = zip(*(
+            (weight, reward)
+            for weight, reward in zip(weights, components)
+            if weight > 0.0))
+
         # Determine whether the cumulative reward is normalized
-        weight_total = 0.0
+        scale = 0.0
         for weight, reward in zip(weights, components):
             if not reward.is_normalized:
                 LOGGER.warning(
@@ -109,36 +131,63 @@ def __init__(self,
                     "recommended.", reward.name)
                 is_normalized = False
                 break
-            weight_total += weight
+            if order == 'inf':
+                scale = max(scale, weight)
+            else:
+                scale += weight
         else:
-            is_normalized = abs(weight_total - 1.0) < 1e-4
 
+            is_normalized = abs(1.0 - scale) < 1e-4
+
         # Backup user-arguments
-        self.weights = weights
+        self.order = order
+        self.weights = tuple(weights)
+
+        # Jit-able method computing the weighted L^p-norm of the reward
+        # components
+        @nb.jit(nopython=True, cache=True, fastmath=True)
+        def weighted_norm(weights: Tuple[float, ...],
+                          order: Union[int, float, Literal['inf']],
+                          values: Tuple[Optional[float], ...]
+                          ) -> Optional[float]:
+            """Compute the weighted L^p-norm of all the reward components
+            that have been evaluated, filtering out the others.
+
+            This method returns `None` if no reward component has been
+            evaluated.
+
+            :param weights: Sequence of weights for each reward component,
+                            with the same ordering as 'components'.
+            :param order: Order of the L^p-norm.
+            :param values: Sequence of scalar values for the reward
+                           components that have been evaluated, `None`
+                           otherwise, with the same ordering as 'components'.
+
+            :returns: Scalar value if at least one of the reward components
+                      has been evaluated, `None` otherwise.
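+
+            For instance, with weights `(0.5, 0.5)`, order 2 and values
+            `(0.8, 0.6)`, the result is
+            `sqrt(0.5 * 0.8 ** 2 + 0.5 * 0.6 ** 2)`, i.e. about 0.71, while
+            order 'inf' yields `max(0.5 * 0.8, 0.5 * 0.6) = 0.4`
+            (illustrative numbers only).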
+ """ + total, any_value = 0.0, False + for value, weight in zip(values, weights): + if value is not None: + if isinstance(order, str): + if any_value: + total = max(total, weight * value) + else: + total = value + else: + total += weight * math.pow(value, order) + any_value = True + if any_value: + if isinstance(order, str): + return total + return math.pow(total, 1.0 / order) + return None # Call base implementation - super().__init__(env, name, components, self._reduce, is_normalized) - - def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]: - """Compute the weighted sum of all the reward components that has been - evaluated, filtering out the others. - - This method returns `None` if no reward component has been evaluated. - - :param values: Sequence of scalar value for reward components that has - been evaluated, `None` otherwise, with the same ordering - as 'components'. - - :returns: Scalar value if at least one of the reward component has been - evaluated, `None` otherwise. - """ - # TODO: x2 speedup can be expected with `nb.jit` - total, any_value = 0.0, False - for weight, value in zip(self.weights, values): - if value is not None: - total += weight * value - any_value = True - return total if any_value else None + super().__init__( + env, + name, + components, + partial(weighted_norm, self.weights, self.order), + is_normalized) AdditiveMixtureReward.is_normalized.__doc__ = \ @@ -150,10 +199,10 @@ def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]: """ -class MultiplicativeMixtureReward(BaseMixtureReward): - """Product of multiple independent reward components. +class MultiplicativeMixtureReward(MixtureReward): + """Geometric mean of independent reward components, to promote versatility. - Aggregation of reward components using multiplication operator is suitable + Aggregating the reward components using the geometric mean is appropriate when maintaining balanced performance between all reward components is essential, and having poor performance for any of them is unacceptable. This type of aggregation is especially useful when reward components are @@ -164,8 +213,7 @@ class MultiplicativeMixtureReward(BaseMixtureReward): def __init__(self, env: InterfaceJiminyEnv, name: str, - components: Sequence[AbstractReward] - ) -> None: + components: Sequence[AbstractReward]) -> None: """ :param env: Base or wrapped jiminy environment. :param name: Desired name of the reward. @@ -174,29 +222,33 @@ def __init__(self, # Determine whether the cumulative reward is normalized is_normalized = all(reward.is_normalized for reward in components) - # Call base implementation - super().__init__(env, name, components, self._reduce, is_normalized) - - def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]: - """Compute the product of all the reward components that has been - evaluated, filtering out the others. - - This method returns `None` if no reward component has been evaluated. + # Jit-able method computing the product of reward components + @nb.jit(nopython=True, cache=True, fastmath=True) + def geometric_mean( + values: Tuple[Optional[float], ...]) -> Optional[float]: + """Compute the product of all the reward components that has + been evaluated, filtering out the others. + + This method returns `None` if no reward component has been + evaluated. + + :param values: Sequence of scalar value for reward components that + has been evaluated, `None` otherwise, with the same + ordering as 'components'. 
+            """
+            total, any_value, n_values = 1.0, False, 0
+            for value in values:
+                if value is not None:
+                    total *= value
+                    any_value = True
+                    n_values += 1
+            return math.pow(total, 1.0 / n_values) if any_value else None
 
-        :param values: Sequence of scalar value for reward components that has
-                       been evaluated, `None` otherwise, with the same ordering
-                       as 'components'.
-
-        :returns: Scalar value if at least one of the reward component has been
-                  evaluated, `None` otherwise.
-        """
-        # TODO: x2 speedup can be expected with `nb.jit`
-        total, any_value = 1.0, False
-        for value in values:
-            if value is not None:
-                total *= value
-                any_value = True
-        return total if any_value else None
+        # Call base implementation
+        super().__init__(env, name, components, geometric_mean, is_normalized)
 
 
 MultiplicativeMixtureReward.is_normalized.__doc__ = \
diff --git a/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py b/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
index 7e3db84a7..0172309d7 100644
--- a/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
+++ b/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
@@ -109,6 +109,7 @@ class BaseJiminyEnv(InterfaceJiminyEnv[ObsT, ActT],
     def __init__(self,
                  simulator: Simulator,
                  step_dt: float,
+                 simulation_duration_max: float = 86400.0,
                  debug: bool = False,
                  render_mode: Optional[str] = None,
                  **kwargs: Any) -> None:
@@ -119,11 +120,15 @@ def __init__(self,
                         independent from the controller and observation update
                         periods. The latter are configured via
                         `engine.set_options`.
-        :param mode: Rendering mode. It can be either 'human' to display the
-                     current simulation state, or 'rgb_array' to return a
-                     snapshot as an RGB array without showing it on the screen.
-                     Optional: 'human' by default if available with the current
-                     backend (or default if none), 'rgb_array' otherwise.
+        :param simulation_duration_max:
+            Maximum duration of a simulation. If the current simulation time
+            exceeds this threshold, then it triggers `is_truncated=True`. It
+            cannot exceed the maximum possible duration before telemetry log
+            time overflow, which is extremely large (about 30 years). Beware
+            that log data are stored in RAM, which may cause an out-of-memory
+            error if the episode lasts too long without reset.
+            Optional: 86400.0 (one day) by default, which corresponds to about
+            4GB of log data assuming a 5ms control update period and telemetry
+            disabled for everything but the robot configuration.
         :param debug: Whether the debug mode must be enabled. Doing it enables
                       telemetry recording.
         :param render_mode: Desired rendering mode, ie "human" or "rgb_array".
@@ -174,7 +179,8 @@ def __init__(self,
             assert render_mode in self.metadata['render_modes']
 
         # Backup some user arguments
-        self.simulator: Simulator = simulator
+        self.simulator = simulator
+        self.simulation_duration_max = simulation_duration_max
         self._step_dt = step_dt
         self.render_mode = render_mode
         self.debug = debug
@@ -707,12 +713,6 @@ def reset(self,  # type: ignore[override]
         self.robot.compute_sensor_measurements(
             0.0, q_init, v_init, a_init, u_motor, f_external)
 
-        # Re-initialize the quantity manager.
-        # Note that computation graph tracking is never reset automatically.
-        # It is the responsibility of the practitioner implementing a derived
-        # environment whenever it makes sense for its specific use-case.
-        self.quantities.reset(reset_tracking=False)
-
         # Run the reset hook if any.
# Note that the reset hook must be called after `_setup` because it # expects that the robot is not going to change anymore at this point. @@ -727,6 +727,12 @@ def reset(self, # type: ignore[override] env = env_derived self.derived = env + # Re-initialize the quantity manager. + # Note that computation graph tracking is never reset automatically. + # It is the responsibility of the practitioner implementing a derived + # environment whenever it makes sense for its specific use-case. + self.quantities.reset(reset_tracking=False) + # Instantiate the actual controller. # Note that a weak reference must be used to avoid circular reference. self.robot.controller = jiminy.FunctionalController( @@ -751,6 +757,9 @@ def reset(self, # type: ignore[override] # Update shared buffers self._refresh_buffers() + # Clear cache and auto-refresh managed quantities + self.quantities.clear() + # Initialize the observation env._observer_handle( self.stepper_state.t, @@ -783,7 +792,7 @@ def reset(self, # type: ignore[override] self._info.clear() # The simulation cannot be done before doing a single step. - if any(self.has_terminated(self._info)): + if any(self.derived.has_terminated(self._info)): raise RuntimeError( "The simulation has already terminated at `reset`. Check the " "implementation of `has_terminated` if overloaded.") @@ -879,6 +888,9 @@ def step(self, # type: ignore[override] # Update shared buffers self._refresh_buffers() + # Clear cache and auto-refresh managed quantities + self.quantities.clear() + # Update the observer at the end of the step. # This is necessary because, internally, it is called at the beginning # of the every integration steps, during the controller update. @@ -894,14 +906,14 @@ def step(self, # type: ignore[override] # Check if the simulation is over. # Note that 'truncated' is forced to True if the integration failed or # if the maximum number of steps will be exceeded next step. - terminated, truncated = self.has_terminated(self._info) + terminated, truncated = self.derived.has_terminated(self._info) truncated = ( truncated or not self.is_simulation_running or self.stepper_state.t + DT_EPS > self.simulation_duration_max) # Check if stepping after done and if it is an undefined behavior if self._num_steps_beyond_terminate is None: - if terminated or truncated: + if terminated: self._num_steps_beyond_terminate = 0 else: if self.is_training and self._num_steps_beyond_terminate == 0: @@ -1033,7 +1045,7 @@ def replay(self, **kwargs: Any) -> None: kwargs['close_backend'] = not self.simulator.is_viewer_available # Stop any running simulation before replay if `has_terminated` is True - if any(self.has_terminated({})): + if any(self.derived.has_terminated({})): self.stop() with viewer_lock: @@ -1491,9 +1503,7 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None: assert isinstance(action, np.ndarray) array_copyto(command, action) - def has_terminated(self, - info: InfoType # pylint: disable=unused-argument - ) -> Tuple[bool, bool]: + def has_terminated(self, info: InfoType) -> Tuple[bool, bool]: """Determine whether the episode is over, because a terminal state of the underlying MDP has been reached or an aborting condition outside the scope of the MDP has been triggered. @@ -1504,9 +1514,7 @@ def has_terminated(self, .. warning:: No matter what, truncation will happen when reaching the maximum - simulation duration, i.e. 'self.simulation_duration_max'. 
Its - default value is extremely large, but it can be overwritten by the - user to terminate the simulation earlier. + simulation duration, i.e. 'self.simulation_duration_max'. .. note:: This method is called after `refresh_observation`, so that the diff --git a/python/gym_jiminy/common/gym_jiminy/common/envs/locomotion.py b/python/gym_jiminy/common/gym_jiminy/common/envs/locomotion.py index 661a854d9..c4e0b5165 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/envs/locomotion.py +++ b/python/gym_jiminy/common/gym_jiminy/common/envs/locomotion.py @@ -150,7 +150,6 @@ def __init__(self, std_ratio = {k: v for k, v in std_ratio.items() if v > 0.0} # Backup user arguments - self.simulation_duration_max = simulation_duration_max self.reward_mixture = reward_mixture self.urdf_path = urdf_path self.mesh_path_dir = mesh_path_dir @@ -202,7 +201,8 @@ def __init__(self, simulator.import_options(config_path) # Initialize base class - super().__init__(simulator, step_dt, debug, **kwargs) + super().__init__( + simulator, step_dt, simulation_duration_max, debug, **kwargs) def _setup(self) -> None: """Configure the environment. diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py index dbd4ae526..f0d961aaa 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py @@ -4,28 +4,34 @@ from .transform import (StackedQuantity, MaskedQuantity, UnaryOpQuantity, - BinaryOpQuantity) -from .generic import (OrientationType, + BinaryOpQuantity, + MultiAryOpQuantity) +from .generic import (EnergyGenerationMode, + OrientationType, FrameOrientation, FramePosition, FrameXYZQuat, - MultiFramesOrientation, - MultiFramesPosition, - MultiFramesXYZQuat, - MultiFramesMeanXYZQuat, + MultiFrameOrientation, + MultiFramePosition, + MultiFrameXYZQuat, + MultiFrameMeanXYZQuat, + MultiFrameCollisionDetection, AverageFrameXYZQuat, AverageFrameRollPitch, - AverageFrameSpatialVelocity, - ActuatedJointsPosition) + AverageMechanicalPowerConsumption, + FrameSpatialAverageVelocity, + MultiActuatedJointKinematic) from .locomotion import (BaseOdometryPose, - MultiFootMeanOdometryPose, - AverageBaseSpatialVelocity, - AverageBaseOdometryVelocity, + BaseRelativeHeight, + BaseSpatialAverageVelocity, + BaseOdometryAverageVelocity, AverageBaseMomentum, MultiFootMeanXYZQuat, MultiFootRelativeXYZQuat, - MultiFootRelativeForceVertical, - MultiContactRelativeForceTangential, + MultiFootMeanOdometryPose, + MultiFootNormalizedForceVertical, + MultiContactNormalizedSpatialForce, + MultiFootCollisionDetection, CenterOfMass, CapturePoint, ZeroMomentPoint, @@ -33,32 +39,38 @@ __all__ = [ + 'EnergyGenerationMode', 'OrientationType', 'QuantityManager', 'StackedQuantity', 'MaskedQuantity', 'UnaryOpQuantity', 'BinaryOpQuantity', + 'MultiAryOpQuantity', 'FrameOrientation', 'FramePosition', 'FrameXYZQuat', - 'MultiFramesPosition', - 'MultiFramesOrientation', - 'MultiFramesXYZQuat', - 'MultiFramesMeanXYZQuat', + 'MultiFramePosition', + 'MultiFrameOrientation', + 'MultiFrameXYZQuat', + 'MultiFrameMeanXYZQuat', + 'MultiFrameCollisionDetection', 'MultiFootMeanXYZQuat', 'MultiFootRelativeXYZQuat', 'MultiFootMeanOdometryPose', - 'MultiFootRelativeForceVertical', - 'MultiContactRelativeForceTangential', + 'MultiFootNormalizedForceVertical', + 'MultiFootCollisionDetection', + 'MultiContactNormalizedSpatialForce', 'AverageFrameXYZQuat', 'AverageFrameRollPitch', - 
'AverageFrameSpatialVelocity', - 'AverageBaseSpatialVelocity', - 'AverageBaseOdometryVelocity', + 'AverageMechanicalPowerConsumption', + 'FrameSpatialAverageVelocity', + 'BaseSpatialAverageVelocity', + 'BaseOdometryAverageVelocity', 'AverageBaseMomentum', 'BaseOdometryPose', - 'ActuatedJointsPosition', + 'BaseRelativeHeight', + 'MultiActuatedJointKinematic', 'CenterOfMass', 'CapturePoint', 'ZeroMomentPoint', diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py index 3e6a8699a..9def585df 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py @@ -4,10 +4,12 @@ application (locomotion, grasping...). """ import warnings -from enum import Enum +from enum import IntEnum +from functools import partial from dataclasses import dataclass from typing import ( - List, Dict, Optional, Protocol, Sequence, Tuple, Union, runtime_checkable) + List, Dict, Optional, Protocol, Sequence, Tuple, Union, Callable, + runtime_checkable) import numpy as np import numba as nb @@ -16,6 +18,7 @@ from jiminy_py.core import ( # pylint: disable=no-name-in-module array_copyto, multi_array_copyto) import pinocchio as pin +import hppfcl as fcl from ..bases import ( InterfaceJiminyEnv, InterfaceQuantity, AbstractQuantity, StateQuantity, @@ -24,7 +27,8 @@ matrix_to_rpy, matrix_to_quat, quat_apply, remove_yaw_from_quat, quat_interpolate_middle) -from .transform import StackedQuantity, MaskedQuantity +from .transform import ( + StackedQuantity, MaskedQuantity, UnaryOpQuantity, BinaryOpQuantity) @runtime_checkable @@ -42,7 +46,7 @@ class FrameQuantity(Protocol): @runtime_checkable -class MultiFramesQuantity(Protocol): +class MultiFrameQuantity(Protocol): """Protocol that must be satisfied by all quantities associated with a particular set of frames for which the same batched intermediary quantities must be computed. @@ -76,19 +80,21 @@ def aggregate_frame_names(quantity: InterfaceQuantity) -> Tuple[ active set must be decided by cache owners. :param quantity: Quantity whose parent implements either `FrameQuantity` or - `MultiFramesQuantity` protocol. All the parents of all its + `MultiFrameQuantity` protocol. All the parents of all its cache owners must also implement one of these protocol. """ # Make sure that parent quantity implement multi- or single-frame protocol - assert isinstance(quantity.parent, (FrameQuantity, MultiFramesQuantity)) + assert isinstance(quantity.parent, (FrameQuantity, MultiFrameQuantity)) quantities = (quantity.cache.owners if quantity.has_cache else (quantity,)) # First, order all multi-frame quantities by decreasing length frame_names_chunks: List[Tuple[str, ...]] = [] for owner in quantities: - if owner.parent.is_active(any_cache_owner=False): - if isinstance(owner.parent, MultiFramesQuantity): - frame_names_chunks.append(owner.parent.frame_names) + parent = owner.parent + assert parent is not None + if parent.is_active(any_cache_owner=False): + if isinstance(parent, MultiFrameQuantity): + frame_names_chunks.append(parent.frame_names) # Next, process ordered multi-frame quantities sequentially. # For each of them, we first check if its set of frames is completely @@ -128,9 +134,11 @@ def aggregate_frame_names(quantity: InterfaceQuantity) -> Tuple[ # Otherwise, we just move to the next quantity. 
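+    # For instance (hypothetical frame names): if one active quantity gathers
+    # frames ('foot_l', 'foot_r', 'torso') and another gathers
+    # ('foot_l', 'foot_r'), the second set is fully contained in the first,
+    # so both are served by slices of the same batched buffer rather than
+    # duplicating computations.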
frame_name_chunks: List[str] = [] for owner in quantities: - if owner.parent.is_active(any_cache_owner=False): - if isinstance(owner.parent, FrameQuantity): - frame_name_chunks.append(owner.parent.frame_name) + parent = owner.parent + assert parent is not None + if parent.is_active(any_cache_owner=False): + if isinstance(parent, FrameQuantity): + frame_name_chunks.append(parent.frame_name) frame_name = frame_name_chunks[-1] if frame_name not in frame_names: frame_names.append(frame_name) @@ -196,7 +204,7 @@ def __init__(self, :param mode: Desired mode of evaluation for this quantity. """ # Make sure that a parent has been specified - assert isinstance(parent, (FrameQuantity, MultiFramesQuantity)) + assert isinstance(parent, (FrameQuantity, MultiFrameQuantity)) # Call base implementation super().__init__( @@ -259,7 +267,7 @@ def refresh(self) -> Dict[Union[str, Tuple[str, ...]], np.ndarray]: return self._rot_mat_map -class OrientationType(Enum): +class OrientationType(IntEnum): """Specify the desired vector representation of the frame orientations. """ @@ -280,6 +288,11 @@ class OrientationType(Enum): """ +# Define proxies for fast lookup +_MATRIX, _EULER, _QUATERNION, _ANGLE_AXIS = ( # pylint: disable=invalid-name + OrientationType) + + @dataclass(unsafe_hash=True) class _BatchedFramesOrientation( InterfaceQuantity[Dict[Union[str, Tuple[str, ...]], np.ndarray]]): @@ -312,7 +325,7 @@ class _BatchedFramesOrientation( """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -322,12 +335,12 @@ class _BatchedFramesOrientation( def __init__(self, env: InterfaceJiminyEnv, - parent: Union["FrameOrientation", "MultiFramesOrientation"], + parent: Union["FrameOrientation", "MultiFrameOrientation"], type: OrientationType, mode: QuantityEvalMode) -> None: """ :param env: Base or wrapped jiminy environment. - :param parent: `FrameOrientation` or `MultiFramesOrientation` instance + :param parent: `FrameOrientation` or `MultiFrameOrientation` instance from which this quantity is a requirement. :param type: Desired vector representation of the orientation for all frames. Note that `OrientationType.ANGLE_AXIS` is not @@ -335,7 +348,7 @@ def __init__(self, :param mode: Desired mode of evaluation for this quantity. """ # Make sure that a suitable parent has been provided - assert isinstance(parent, (FrameOrientation, MultiFramesOrientation)) + assert isinstance(parent, (FrameOrientation, MultiFrameOrientation)) # Make sure that the specified orientation representation is supported if type not in (OrientationType.MATRIX, @@ -351,7 +364,7 @@ def __init__(self, # Initialize the ordered list of frame names. # Note that this must be done BEFORE calling base `__init__`, otherwise - # `isinstance(..., (FrameQuantity, MultiFramesQuantity))` will fail. + # `isinstance(..., (FrameQuantity, MultiFrameQuantity))` will fail. self.frame_names: Tuple[str, ...] 
= () # Call base implementation @@ -403,12 +416,13 @@ def initialize(self) -> None: def refresh(self) -> Dict[Union[str, Tuple[str, ...]], np.ndarray]: # Get the complete batch of rotation matrices managed by this instance - rot_mat_batch = self.rot_mat_map[self.frame_names] + value = self.rot_mat_map.get() + rot_mat_batch = value[self.frame_names] # Convert all rotation matrices at once to the desired representation - if self.type == OrientationType.EULER: + if self.type == _EULER: matrix_to_rpy(rot_mat_batch, self._data_batch) - elif self.type == OrientationType.QUATERNION: + elif self.type == _QUATERNION: matrix_to_quat(rot_mat_batch, self._data_batch) else: # Slice data. @@ -437,7 +451,7 @@ class FrameOrientation(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -486,14 +500,18 @@ def initialize(self) -> None: # Force re-initializing shared data if the active set has changed if not was_active: - self.requirements["data"].reset(reset_tracking=True) + # Must reset the tracking for shared computation systematically, + # just in case the optimal computation path has changed to the + # point that relying on batched quantity is no longer relevant. + self.data.reset(reset_tracking=True) def refresh(self) -> np.ndarray: - return self.data[self.frame_name] + value = self.data.get() + return value[self.frame_name] @dataclass(unsafe_hash=True) -class MultiFramesOrientation(InterfaceQuantity[np.ndarray]): +class MultiFrameOrientation(InterfaceQuantity[np.ndarray]): """Vector representation of the orientation of a given set of frames in world reference frame at the end of the agent step. @@ -512,7 +530,7 @@ class MultiFramesOrientation(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -566,16 +584,14 @@ def initialize(self) -> None: # Force re-initializing shared data if the active set has changed if not was_active: - # Must reset the tracking for shared computation systematically, - # just in case the optimal computation path has changed to the - # point that relying on batched quantity is no longer relevant. - self.requirements["data"].reset(reset_tracking=True) + self.data.reset(reset_tracking=True) def refresh(self) -> np.ndarray: # Return a slice of batched data. # Note that mapping from frame names to frame index in batched data # cannot be pre-computed as it may changed dynamically. - return self.data[self.frame_names] + value = self.data.get() + return value[self.frame_names] @dataclass(unsafe_hash=True) @@ -597,16 +613,16 @@ class _BatchedFramesPosition( def __init__(self, env: InterfaceJiminyEnv, - parent: Union["FramePosition", "MultiFramesPosition"], + parent: Union["FramePosition", "MultiFramePosition"], mode: QuantityEvalMode) -> None: """ :param env: Base or wrapped jiminy environment. - :param parent: `FramePosition` or `MultiFramesPosition` instance from + :param parent: `FramePosition` or `MultiFramePosition` instance from which this quantity is a requirement. :param mode: Desired mode of evaluation for this quantity. 
""" # Make sure that a suitable parent has been provided - assert isinstance(parent, (FramePosition, MultiFramesPosition)) + assert isinstance(parent, (FramePosition, MultiFramePosition)) # Initialize the ordered list of frame names self.frame_names: Tuple[str, ...] = () @@ -679,7 +695,7 @@ class FramePosition(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -722,14 +738,15 @@ def initialize(self) -> None: # Force re-initializing shared data if the active set has changed if not was_active: - self.requirements["data"].reset(reset_tracking=True) + self.data.reset(reset_tracking=True) def refresh(self) -> np.ndarray: - return self.data[self.frame_name] + value = self.data.get() + return value[self.frame_name] @dataclass(unsafe_hash=True) -class MultiFramesPosition(InterfaceQuantity[np.ndarray]): +class MultiFramePosition(InterfaceQuantity[np.ndarray]): """Position vector (X, Y, Z) of a given set of frames in world reference frame at the end of the agent step. """ @@ -739,7 +756,7 @@ class MultiFramesPosition(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -785,10 +802,11 @@ def initialize(self) -> None: # Force re-initializing shared data if the active set has changed if not was_active: - self.requirements["data"].reset(reset_tracking=True) + self.data.reset(reset_tracking=True) def refresh(self) -> np.ndarray: - return self.data[self.frame_names] + value = self.data.get() + return value[self.frame_names] @dataclass(unsafe_hash=True) @@ -803,7 +821,7 @@ class FrameXYZQuat(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -848,16 +866,16 @@ def __init__(self, def refresh(self) -> np.ndarray: # Copy the position of all frames at once in contiguous buffer - array_copyto(self._xyzquat[:3], self.position) + array_copyto(self._xyzquat[:3], self.position.get()) # Copy the quaternion of all frames at once in contiguous buffer - array_copyto(self._xyzquat[-4:], self.quat) + array_copyto(self._xyzquat[-4:], self.quat.get()) return self._xyzquat @dataclass(unsafe_hash=True) -class MultiFramesXYZQuat(InterfaceQuantity[np.ndarray]): +class MultiFrameXYZQuat(InterfaceQuantity[np.ndarray]): """Spatial vector representation (X, Y, Z, QuatX, QuatY, QuatZ, QuatW) of the transform of a given set of frames in world reference frame at the end of the agent step. @@ -868,7 +886,7 @@ class MultiFramesXYZQuat(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. 
warning::
@@ -902,30 +920,30 @@
             env,
             parent,
             requirements=dict(
-                positions=(MultiFramesPosition, dict(
+                positions=(MultiFramePosition, dict(
                     frame_names=frame_names,
                     mode=mode)),
-                quats=(MultiFramesOrientation, dict(
+                quats=(MultiFrameOrientation, dict(
                     frame_names=frame_names,
                     type=OrientationType.QUATERNION,
                     mode=mode))),
             auto_refresh=False)
 
         # Pre-allocate memory for storing the pose XYZQuat of all frames
-        self._xyzquats = np.zeros((7, len(frame_names)), order='F')
+        self._xyzquats = np.zeros((7, len(frame_names)), order='C')
 
     def refresh(self) -> np.ndarray:
         # Copy the position of all frames at once in contiguous buffer
-        array_copyto(self._xyzquats[:3], self.positions)
+        array_copyto(self._xyzquats[:3], self.positions.get())
 
         # Copy the quaternion of all frames at once in contiguous buffer
-        array_copyto(self._xyzquats[-4:], self.quats)
+        array_copyto(self._xyzquats[-4:], self.quats.get())
 
         return self._xyzquats
 
 
 @dataclass(unsafe_hash=True)
-class MultiFramesMeanXYZQuat(InterfaceQuantity[np.ndarray]):
+class MultiFrameMeanXYZQuat(InterfaceQuantity[np.ndarray]):
     """Spatial vector representation (X, Y, Z, QuatX, QuatY, QuatZ, QuatW) of
     the average transform of a given set of frames in world reference frame at
     the end of the agent step.
@@ -946,7 +964,7 @@
     """
 
     mode: QuantityEvalMode
-    """Specify on which state to evaluate this quantity. See `Mode`
+    """Specify on which state to evaluate this quantity. See `QuantityEvalMode`
     documentation for details about each mode.
 
     .. warning::
@@ -980,16 +998,16 @@
             env,
             parent,
             requirements=dict(
-                positions=(MultiFramesPosition, dict(
+                positions=(MultiFramePosition, dict(
                     frame_names=frame_names,
                     mode=mode)),
-                quats=(MultiFramesOrientation, dict(
+                quats=(MultiFrameOrientation, dict(
                     frame_names=frame_names,
                     type=OrientationType.QUATERNION,
                     mode=mode))),
             auto_refresh=False)
 
-        # Define jit-able specialization of `np.mean` for `axis=-1`
+        # Jit-able specialization of `np.mean` for `axis=-1`
         @nb.jit(nopython=True, cache=True, fastmath=True)
         def position_average(value: np.ndarray, out: np.ndarray) -> None:
             """Compute the mean of an array over its last axis only.
@@ -1002,7 +1020,7 @@
 
         self._position_average = position_average
 
-        # Define jit-able specialization of `quat_average` for 2D matrices
+        # Jit-able specialization of `quat_average` for 2D matrices
         @nb.jit(nopython=True, cache=True, fastmath=True)
         def quat_average_2d(quat: np.ndarray, out: np.ndarray) -> None:
             """Compute the average of a batch of quaternions [qx, qy, qz, qw].
@@ -1036,14 +1054,156 @@
 
     def refresh(self) -> np.ndarray:
         # Compute the mean translation
-        self._position_average(self.positions, self._position_mean_view)
+        self._position_average(self.positions.get(), self._position_mean_view)
 
         # Compute the mean quaternion
-        self._quat_average(self.quats, self._quat_mean_view)
+        self._quat_average(self.quats.get(), self._quat_mean_view)
 
         return self._xyzquat_mean
 
 
+@dataclass(unsafe_hash=True)
+class MultiFrameCollisionDetection(InterfaceQuantity[bool]):
+    """Check if some geometry objects are colliding with each other.
+
+    It takes into account some safety margins by which their volumes will be
+    inflated or deflated.
+
+    .. note::
+        Jiminy enforces all collision geometries to be either primitive shapes
+        or convex polyhedra for efficiency. In practice, if meshes were
+        specified in the original URDF file, then they will be converted into
+        their respective convex hulls.
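+
+    For instance, assuming hypothetical body frames 'l_thigh' and 'r_thigh'
+    (a sketch, not tied to any particular robot)::
+
+        quantity = MultiFrameCollisionDetection(
+            env, None, ("l_thigh", "r_thigh"), security_margin=0.02)
+
+    It would return `True` whenever some geometries attached to the parent
+    joints of these two frames come within 2 cm of each other.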
In practice, tf meshes where + specified in the original URDF file, then they will be converted into + their respective convex hull. + """ + + frame_names: Tuple[str, ...] + """Name of the bodies of the robot to consider for collision detection. + + All the geometry objects sharing with them the same parent joint will be + taking into account. + """ + + security_margin: float + """Signed distance below which a pair of geometry objects is stated in + collision. + + This can be interpreted as inflating or deflating the geometry objects by + the safety margin depending on whether it is positive or negative + respectively. Therefore, the actual geometry objects do no have to be in + contact to be stated in collision if the satefy margin is positive. On the + contrary, the penetration depth must be large enough if the security margin + is positive. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + frame_names: Sequence[str], + *, + security_margin: float = 0.0) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param frame_names: Name of the bodies of the robot to consider for + collision detection. All the geometry objects + sharing with them the same parent joint will be + taking into account. + :param security_margin: Signed distance below which a pair of geometry + objects is stated in collision. + Optional: 0.0 by default. + """ + # Backup some user-arguments + self.frame_names = tuple(frame_names) + self.security_margin = security_margin + + # Call base implementation + super().__init__( + env, + parent, + requirements={}, + auto_refresh=False) + + # Initialize a broadphase manager for each collision group + self._collision_groups = [ + fcl.DynamicAABBTreeCollisionManager() for _ in frame_names] + + # Initialize pair-wise collision requests between groups of bodies + self._requests: List[Tuple[ + fcl.BroadPhaseCollisionManager, + fcl.BroadPhaseCollisionManager, + fcl.CollisionCallBackBase]] = [] + for i in range(len(frame_names)): + for j in range(i + 1, len(frame_names)): + manager_1 = self._collision_groups[i] + manager_2 = self._collision_groups[j] + callback = fcl.CollisionCallBackDefault() + request: fcl.CollisionRequest = ( + callback.data.request) # pylint: disable=no-member + request.gjk_initial_guess = jiminy.GJKInitialGuess.CachedGuess + # request.gjk_variant = fcl.GJKVariant.NesterovAcceleration + # request.break_distance = 0.1 + request.gjk_tolerance = 1e-6 + request.distance_upper_bound = 1e-6 + request.num_max_contacts = 1 + request.security_margin = security_margin + self._requests.append((manager_1, manager_2, callback)) + + # Store callable responsible to updating transform of colision objects + self._transform_updates: List[Callable[[], None]] = [] + + def initialize(self) -> None: + # Call base implementation + super().initialize() + + # Define robot proxy for convenience + robot = self.env.robot + + # Clear all collision managers + for manager in self._collision_groups: + manager.clear() + + # Get the list of parent joint indices mapping + frame_indices_map: Dict[int, int] = {} + for i, frame_name in enumerate(self.frame_names): + frame_index = robot.pinocchio_model.getFrameId(frame_name) + frame = robot.pinocchio_model.frames[frame_index] + frame_indices_map[frame.parent] = i + + # Add collision objects to their corresponding manager + self._transform_updates.clear() + for i, geom in 
enumerate(robot.collision_model.geometryObjects): + j = frame_indices_map.get(geom.parentJoint) + if j is not None: + obj = fcl.CollisionObject(geom.geometry) + self._collision_groups[j].registerObject(obj) + pose = robot.collision_data.oMg[i] + translation, rotation = pose.translation, pose.rotation + self._transform_updates += ( + partial(obj.setTranslation, translation), + partial(obj.setRotation, rotation)) + + # Initialize collision detection facilities + for manager in self._collision_groups: + manager.setup() + + def refresh(self) -> bool: + # Update collision object placement + for transform_update in self._transform_updates: + transform_update() + + # Update all collision managers + # for manager in self._collision_groups: + # manager.update() + + # Check collision for all candidate pairs + for manager_1, manager_2, callback in self._requests: + manager_1.collide(manager_2, callback) + if callback.data.result.isCollision(): + return True + return False + + @dataclass(unsafe_hash=True) class _DifferenceFrameXYZQuat(InterfaceQuantity[np.ndarray]): """Motion vector representation (VX, VY, VZ, WX, WY, WZ) of the finite @@ -1062,7 +1222,7 @@ class _DifferenceFrameXYZQuat(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -1102,7 +1262,7 @@ def __init__(self, quantity=(FrameXYZQuat, dict( frame_name=frame_name, mode=mode)), - num_stack=2))), + max_stack=2))), auto_refresh=False) # Define specialize difference operator on SE3 Lie group @@ -1118,7 +1278,7 @@ def refresh(self) -> np.ndarray: # point. This should never occur in practice as it will be fine at # the end of the first step already, before the reward and termination # conditions are evaluated. - xyzquat_prev, xyzquat = self.xyzquat_stack + xyzquat_prev, xyzquat = self.xyzquat_stack.get() # Compute average frame velocity in local frame since previous step self._data[:] = self._difference(xyzquat_prev, xyzquat) @@ -1151,7 +1311,7 @@ class AverageFrameXYZQuat(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -1196,7 +1356,8 @@ def __init__(self, def refresh(self) -> np.ndarray: # Interpolate the average spatial velocity at midpoint - return self._integrate(self.xyzquat_next, - 0.5 * self.xyzquat_diff) + return self._integrate( + self.xyzquat_next.get(), - 0.5 * self.xyzquat_diff.get()) @dataclass(unsafe_hash=True) @@ -1215,7 +1376,7 @@ class AverageFrameRollPitch(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. 
warning::
@@ -1259,13 +1420,13 @@ def refresh(self) -> np.ndarray:
         # Compute Yaw-free average orientation
-        remove_yaw_from_quat(self.quat_mean, self._quat_no_yaw_mean)
+        remove_yaw_from_quat(self.quat_mean.get(), self._quat_no_yaw_mean)
 
         return self._quat_no_yaw_mean
 
 
 @dataclass(unsafe_hash=True)
-class AverageFrameSpatialVelocity(InterfaceQuantity[np.ndarray]):
+class FrameSpatialAverageVelocity(InterfaceQuantity[np.ndarray]):
     """Average spatial velocity of a given frame at the end of the agent step.
 
     The average spatial velocity is obtained by finite difference. More
@@ -1295,7 +1456,7 @@
     """
 
     mode: QuantityEvalMode
-    """Specify on which state to evaluate this quantity. See `Mode`
+    """Specify on which state to evaluate this quantity. See `QuantityEvalMode`
     documentation for details about each mode.
 
     .. warning::
@@ -1360,47 +1521,70 @@
     def refresh(self) -> np.ndarray:
         # Compute average frame velocity in local frame since previous step
-        np.multiply(self.xyzquat_diff, self._inv_step_dt, self._v_spatial)
+        np.multiply(
+            self.xyzquat_diff.get(), self._inv_step_dt, self._v_spatial)
 
         # Translate local velocity to world frame
         if self.reference_frame == pin.LOCAL_WORLD_ALIGNED:
             # Define world frame as the "middle" between prev and next pose.
             # Here, we only care about the middle rotation, so we can consider
             # SO3 Lie Group algebra instead of SE3.
-            quat_apply(self.quat_mean, self._v_lin_ang, self._v_lin_ang)
+            quat_apply(self.quat_mean.get(), self._v_lin_ang, self._v_lin_ang)
 
         return self._v_spatial
 
 
 @dataclass(unsafe_hash=True)
-class ActuatedJointsPosition(AbstractQuantity[np.ndarray]):
-    """Concatenation of the current position of all the actuated joints
-    of the robot.
-
-    In practice, all actuated joints must be 1DoF for now. The principal angle
-    is used in case of revolute unbounded revolute joints.
+class MultiActuatedJointKinematic(AbstractQuantity[np.ndarray]):
+    """Current position, velocity or acceleration of all the actuated joints
+    of the robot before or after the mechanical transmissions.
 
-    .. warning::
-        Revolute unbounded joints are not supported for now.
+    In practice, all actuated joints must be 1DoF for now. In the case of
+    revolute unbounded joints, the principal angle 'theta' is used to encode
+    the position, not the polar coordinates `(cos(theta), sin(theta))`.
 
     .. warning::
         Data is extracted from the true configuration vector instead of using
         sensor data. As a result, this quantity is appropriate for computing
        reward components and termination conditions but must be avoided in
         observers and controllers.
+
+    .. warning::
+        Revolute unbounded joints are not supported for now.
+    """
+
+    kinematic_level: pin.KinematicLevel
+    """Kinematic level to consider, ie position, velocity or acceleration.
+    """
+
+    is_motor_side: bool
+    """Whether to compute the kinematic data on the motor or joint side, ie
+    before or after their respective mechanical transmissions.
     """
 
     def __init__(self,
                  env: InterfaceJiminyEnv,
                  parent: Optional[InterfaceQuantity],
                  *,
+                 kinematic_level: pin.KinematicLevel = pin.POSITION,
+                 is_motor_side: bool = False,
                  mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None:
         """
         :param env: Base or wrapped jiminy environment.
         :param parent: Higher-level quantity from which this quantity is a
                        requirement if any, `None` otherwise.
+        :param kinematic_level: Desired kinematic level, ie position, velocity
+                                or acceleration.
+        :param is_motor_side: Whether to compute the kinematic data on the
+                              motor or joint side, ie before or after the
+                              mechanical transmissions.
+                              Optional: False by default.
         :param mode: Desired mode of evaluation for this quantity.
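+
+        Example, reading motor-side velocities (a sketch; `pin` refers to the
+        `pinocchio` module already imported in this file)::
+
+            quantity = MultiActuatedJointKinematic(
+                env, None, kinematic_level=pin.VELOCITY, is_motor_side=True)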
+ :param is_motor_side: Whether to compute kinematic data on the motor or + joint side, ie before or after the mechanical + transmissions. + Optional: False by default. :param mode: Desired mode of evaluation for this quantity. """ + # Backup some of the user arguments + self.kinematic_level = kinematic_level + self.is_motor_side = is_motor_side + # Call base implementation super().__init__( env, @@ -1415,10 +1599,13 @@ def __init__(self, # Note that it will only be used in last resort if it can be written as # a slice. Indeed, "fancy" indexing returns a copy of the original data # instead of a view, which requires fetching data at every refresh. - self.position_indices: List[int] = [] + self.kinematic_indices: List[int] = [] + + # Keep track of the mechanical reduction ratio for all the motors + self._joint_to_motor_ratios = np.array([]) # Buffer storing mechanical joint positions - self.data = np.array([]) + self._data = np.array([]) # Whether mechanical joint positions must be updated at every refresh self._must_refresh = False @@ -1427,25 +1614,41 @@ def initialize(self) -> None: # Call base implementation super().initialize() - # Refresh mechanical joint position indices - self.position_indices.clear() - for motor in self.env.robot.motors: - joint_index = self.pinocchio_model.getJointId(motor.joint_name) - joint = self.pinocchio_model.joints[joint_index] + # Make sure that the state data meet requirements + state = self.state.get() + if ((self.kinematic_level == pin.ACCELERATION and state.a is None) or + (self.kinematic_level >= pin.VELOCITY and state.v is None)): + raise RuntimeError( + "Available state data do not meet requirements for kinematic " + f"level '{self.kinematic_level}'.") + + # Refresh mechanical joint position indices and reduction ratio + joint_to_motor_ratios = [] + self.kinematic_indices.clear() + for motor in self.robot.motors: + joint = self.pinocchio_model.joints[motor.joint_index] joint_type = jiminy.get_joint_type(joint) if joint_type == jiminy.JointModelType.ROTARY_UNBOUNDED: raise ValueError( "Revolute unbounded joints are not supported for now.") - self.position_indices += range(joint.idx_q, joint.idx_q + joint.nq) + if self.kinematic_level == pin.KinematicLevel.POSITION: + kin_first, kin_last = joint.idx_q, joint.idx_q + joint.nq + else: + kin_first, kin_last = joint.idx_v, joint.idx_v + joint.nv + motor_options = motor.get_options() + mechanical_reduction = motor_options["mechanicalReduction"] + joint_to_motor_ratios.append(mechanical_reduction) + self.kinematic_indices += range(kin_first, kin_last) + self._joint_to_motor_ratios = np.array(joint_to_motor_ratios) # Determine whether data can be extracted from state by reference - position_first = min(self.position_indices) - position_last = max(self.position_indices) + kin_first = min(self.kinematic_indices) + kin_last = max(self.kinematic_indices) self._must_refresh = True if self.mode == QuantityEvalMode.TRUE: try: - if (np.array(self.position_indices) == np.arange( - position_first, position_last + 1)).all(): + if np.all(np.array(self.kinematic_indices) == np.arange( + kin_first, kin_last + 1)): self._must_refresh = False else: warnings.warn( @@ -1456,13 +1659,165 @@ def initialize(self) -> None: # Try extracting mechanical joint positions by reference if possible if self._must_refresh: - self.data = np.full((len(self.position_indices),), float("nan")) + self._data = np.full((len(self.kinematic_indices),), float("nan")) else: - self.data = self.state.q[slice(position_first, position_last + 1)] + state =
self.state.get() + if self.kinematic_level == pin.KinematicLevel.POSITION: + self._data = state.q[slice(kin_first, kin_last + 1)] + elif self.kinematic_level == pin.KinematicLevel.VELOCITY: + self._data = state.v[slice(kin_first, kin_last + 1)] + else: + self._data = state.a[slice(kin_first, kin_last + 1)] def refresh(self) -> np.ndarray: # Update mechanical joint positions only if necessary + state = self.state.get() if self._must_refresh: - self.state.q.take(self.position_indices, None, self.data, "clip") + if self.kinematic_level == pin.KinematicLevel.POSITION: + data = state.q + elif self.kinematic_level == pin.KinematicLevel.VELOCITY: + data = state.v + else: + data = state.a + data.take(self.kinematic_indices, None, self._data, "clip") + + # Translate joint-side data to motor-side if requested + if self.is_motor_side: + self._data *= self._joint_to_motor_ratios + + return self._data + + +class EnergyGenerationMode(IntEnum): + """Specify what happens to the energy generated by motors when braking. """ + + CHARGE = 0 + """The energy flows back to the battery to charge it without any kind of + losses in the process if negative overall. """ + + LOST_EACH = 1 + """The energy generated by each motor individually is lost by thermal + dissipation, neither flowing back to the battery nor powering other motors + consuming energy if any. """ + + LOST_GLOBAL = 2 + """The energy is lost by thermal dissipation without flowing back to the + battery if negative overall. """ + + PENALIZE = 3 + """The energy generated by each motor individually is treated as consumed. """ + + +# Define proxies for fast lookup +_CHARGE, _LOST_EACH, _LOST_GLOBAL, _PENALIZE = map(int, EnergyGenerationMode) + + +@dataclass(unsafe_hash=True) +class AverageMechanicalPowerConsumption(InterfaceQuantity[float]): + """Average mechanical power consumption by all the motors over a sliding + time window. + """ + + max_stack: int + """Time horizon over which values of the instantaneous power consumption + will be stacked for computing the average. + """ + + generator_mode: EnergyGenerationMode + """Specify what happens to the energy generated by motors when braking. + See `EnergyGenerationMode` documentation for details. + """ + + mode: QuantityEvalMode + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` + documentation for details about each mode. + + .. warning:: + Mode `REFERENCE` requires a reference trajectory to be selected + manually prior to evaluating this quantity for the first time. + """ + + def __init__( + self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + *, + horizon: float, + generator_mode: EnergyGenerationMode = EnergyGenerationMode.CHARGE, + mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the average. + :param generator_mode: Specify what happens to the energy generated by + motors when braking. + Optional: `EnergyGenerationMode.CHARGE` by + default. + :param mode: Desired mode of evaluation for this quantity.
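To make the four accounting policies concrete, here is a tiny self-contained sketch (illustrative only, with made-up motor data; not part of the patch) that mirrors the `_compute_power` helper defined right below, for one motor consuming power and one generating power:

```python
import numpy as np

velocities = np.array([1.0, -2.0])  # hypothetical motor-side velocities (rad/s)
efforts = np.array([5.0, 10.0])     # hypothetical motor-side efforts (N.m)
powers = velocities * efforts       # [5.0, -20.0]: motor 1 consumes, motor 2 generates

total = float(np.dot(velocities, efforts))
print(total)                            # CHARGE: -15.0 (net generation charges the battery)
print(max(total, 0.0))                  # LOST_GLOBAL: 0.0 (net generation dissipated)
print(np.sum(np.maximum(powers, 0.0)))  # LOST_EACH: 5.0 (per-motor generation dissipated)
print(np.sum(np.abs(powers)))           # PENALIZE: 25.0 (generation counted as consumption)
```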
+ """ + # Convert horizon in stack length, assuming constant env timestep + max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + + # Backup some of the user-arguments + self.max_stack = max_stack + self.generator_mode = generator_mode + self.mode = mode + + # Jit-able method computing the total instantaneous power consumption + @nb.jit(nopython=True, cache=True, fastmath=True) + def _compute_power(generator_mode: int, # EnergyGenerationMode + motor_velocities: np.ndarray, + motor_efforts: np.ndarray) -> float: + """Compute the total instantaneous mechanical power consumption of + all motors. + + :param generator_mode: Specify what happens to the energy generated + by motors when breaking. + :param motor_velocities: Velocity of all the motors before + transmission as a 1D array. The order must + be consistent with the motor indices. + :param motor_efforts: Effort of all the motors before transmission + as a 1D array. The order must be consistent + with the motor indices. + """ + if generator_mode in (_CHARGE, _LOST_GLOBAL): + total_power = np.dot(motor_velocities, motor_efforts) + if generator_mode == _CHARGE: + return total_power + return max(total_power, 0.0) + motor_powers = motor_velocities * motor_efforts + if generator_mode == _LOST_EACH: + return np.sum(np.maximum(motor_powers, 0.0)) + return np.sum(np.abs(motor_powers)) + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + total_power_stack=(StackedQuantity, dict( + quantity=(BinaryOpQuantity, dict( + quantity_left=(UnaryOpQuantity, dict( + quantity=(StateQuantity, dict( + update_kinematics=False, + mode=self.mode)), + op=lambda state: state.command)), + quantity_right=(MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.VELOCITY, + is_motor_side=True, + mode=self.mode)), + op=partial(_compute_power, int(self.generator_mode)))), + max_stack=self.max_stack, + as_array=True, + mode='slice'))), + auto_refresh=False) - return self.data + def refresh(self) -> float: + return np.mean(self.total_power_stack.get()) diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py index 45db50258..1cb5f4b2e 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py @@ -16,13 +16,13 @@ from ..bases import ( InterfaceJiminyEnv, InterfaceQuantity, AbstractQuantity, StateQuantity, QuantityEvalMode) +from ..quantities import ( + MaskedQuantity, FramePosition, MultiFramePosition, MultiFrameXYZQuat, + MultiFrameMeanXYZQuat, MultiFrameCollisionDetection, + FrameSpatialAverageVelocity, AverageFrameRollPitch) from ..utils import ( matrix_to_yaw, quat_to_yaw, quat_to_matrix, quat_multiply, quat_apply) -from ..quantities import ( - MaskedQuantity, MultiFramesXYZQuat, MultiFramesMeanXYZQuat, - AverageFrameSpatialVelocity, AverageFrameRollPitch) - def sanitize_foot_frame_names( env: InterfaceJiminyEnv, @@ -104,6 +104,62 @@ def translate_position_odom(position: np.ndarray, out[1] = - sin_yaw * pos_rel_x + cos_yaw * pos_rel_y +@dataclass(unsafe_hash=True) +class BaseRelativeHeight(InterfaceQuantity[float]): + """Relative height of the floating base of the robot wrt lowest contact + point or collision body in world frame. + """ + + mode: QuantityEvalMode + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` + documentation for details about each mode. + + .. 
warning:: + Mode `REFERENCE` requires a reference trajectory to be selected + manually prior to evaluating this quantity for the first time. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + *, + mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param mode: Desired mode of evaluation for this quantity. + Optional: 'QuantityEvalMode.TRUE' by default. + """ + # Backup some user argument(s) + self.mode = mode + + # Get all frame constraints associated with contacts and collisions + frame_names: List[str] = [] + for constraint in env.robot.constraints.contact_frames.values(): + assert isinstance(constraint, jiminy.FrameConstraint) + frame_names.append(constraint.frame_name) + for constraints_body in env.robot.constraints.collision_bodies: + for constraint in constraints_body: + assert isinstance(constraint, jiminy.FrameConstraint) + frame_names.append(constraint.frame_name) + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + base_pos=(FramePosition, dict( + frame_name="root_joint")), + contacts_pos=(MultiFramePosition, dict( + frame_names=frame_names))), + auto_refresh=False) + + def refresh(self) -> float: + base_pos, contacts_pos = self.base_pos.get(), self.contacts_pos.get() + return base_pos[2] - np.min(contacts_pos[2]) + + @dataclass(unsafe_hash=True) class BaseOdometryPose(AbstractQuantity[np.ndarray]): """Odometry pose of the floating base of the robot at the end of the agent @@ -112,6 +168,11 @@ class BaseOdometryPose(AbstractQuantity[np.ndarray]): The odometry pose fully specifies the position and heading of the robot in 2D world plane. As such, it comprises the linear translation (X, Y) and the rotation around Z axis (namely rate of change of Yaw Euler angle). + Mathematically, one is supposed to rely on se2 Lie Algebra for performing + operations on odometry poses such as differentiation. In practice, the + double geodesic metric space is used instead to prevent coupling between + the linear and angular parts by considering them independently. Strictly + speaking, it corresponds to the cartesian space (R^2 x SO(2)). """ def __init__(self, @@ -165,12 +226,12 @@ def refresh(self) -> np.ndarray: @dataclass(unsafe_hash=True) -class AverageBaseSpatialVelocity(InterfaceQuantity[np.ndarray]): +class BaseSpatialAverageVelocity(InterfaceQuantity[np.ndarray]): """Average base spatial velocity of the floating base of the robot in local odometry frame at the end of the agent step. The average spatial velocity is obtained by finite difference. See - `AverageFrameSpatialVelocity` documentation for details. + `FrameSpatialAverageVelocity` documentation for details. Roughly speaking, the local odometry reference frame is half-way between `pinocchio.LOCAL` and `pinocchio.LOCAL_WORLD_ALIGNED`. The z-axis is @@ -180,7 +241,7 @@ class AverageBaseSpatialVelocity(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. 
warning:: @@ -208,7 +269,7 @@ def __init__(self, env, parent, requirements=dict( - v_spatial=(AverageFrameSpatialVelocity, dict( + v_spatial=(FrameSpatialAverageVelocity, dict( frame_name="root_joint", reference_frame=pin.LOCAL, mode=mode)), @@ -225,25 +286,26 @@ def __init__(self, def refresh(self) -> np.ndarray: # Translate spatial base velocity from local to odometry frame - quat_apply(self.quat_no_yaw_mean, - self.v_spatial.reshape((2, 3)).T, + v_spatial = self.v_spatial.get() + quat_apply(self.quat_no_yaw_mean.get(), + v_spatial.reshape((2, 3)).T, self._v_lin_ang) return self._v_spatial @dataclass(unsafe_hash=True) -class AverageBaseOdometryVelocity(InterfaceQuantity[np.ndarray]): +class BaseOdometryAverageVelocity(InterfaceQuantity[np.ndarray]): """Average odometry velocity of the floating base of the robot in local odometry frame at the end of the agent step. The odometry velocity fully specifies the linear and angular velocity of - the robot in 2D world plane. See `AverageBaseSpatialVelocity` and + the robot in 2D world plane. See `BaseSpatialAverageVelocity` and `BaseOdometryPose`, documentations for details. """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -272,13 +334,13 @@ def __init__(self, parent, requirements=dict( data=(MaskedQuantity, dict( - quantity=(AverageBaseSpatialVelocity, dict( + quantity=(BaseSpatialAverageVelocity, dict( mode=mode)), keys=(0, 1, 5)))), auto_refresh=False) def refresh(self) -> np.ndarray: - return self.data + return self.data.get() @dataclass(unsafe_hash=True) @@ -319,7 +381,7 @@ def __init__(self, parent, requirements=dict( v_angular=(MaskedQuantity, dict( - quantity=(AverageFrameSpatialVelocity, dict( + quantity=(FrameSpatialAverageVelocity, dict( frame_name="root_joint", reference_frame=pin.LOCAL, mode=mode)), @@ -344,10 +406,11 @@ def initialize(self) -> None: def refresh(self) -> np.ndarray: # Compute the local angular momentum of inertia - np.matmul(self._inertia_local, self.v_angular, self._h_angular) + np.matmul(self._inertia_local, self.v_angular.get(), self._h_angular) # Apply quaternion rotation of the local angular momentum of inertia - quat_apply(self.quat_no_yaw_mean, self._h_angular, self._h_angular) + quat_apply( + self.quat_no_yaw_mean.get(), self._h_angular, self._h_angular) return self._h_angular @@ -373,7 +436,7 @@ class MultiFootMeanXYZQuat(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -406,13 +469,13 @@ def __init__(self, env, parent, requirements=dict( - data=(MultiFramesMeanXYZQuat, dict( + data=(MultiFrameMeanXYZQuat, dict( frame_names=self.frame_names, mode=mode))), auto_refresh=False) def refresh(self) -> np.ndarray: - return self.data + return self.data.get() @dataclass(unsafe_hash=True) @@ -437,7 +500,7 @@ class MultiFootMeanOdometryPose(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. 
warning:: @@ -485,10 +548,11 @@ def __init__(self, def refresh(self) -> np.ndarray: # Copy translation part - array_copyto(self._xy_view, self.xyzquat_mean[:2]) + xyzquat_mean = self.xyzquat_mean.get() + array_copyto(self._xy_view, xyzquat_mean[:2]) # Compute Yaw angle - quat_to_yaw(self.xyzquat_mean[-4:], self._yaw_view) + quat_to_yaw(xyzquat_mean[-4:], self._yaw_view) return self._odom_pose @@ -505,7 +569,7 @@ class MultiFootRelativeXYZQuat(InterfaceQuantity[np.ndarray]): wrt the others. Notably, in the particular case where there are only two frames, one is the opposite of the other. As a result, the last relative pose is always dropped from the returned value, based on the same ordering as - 'self.frame_names'. As for `MultiFramesXYZQuat`, the data associated with + 'self.frame_names'. As for `MultiFrameXYZQuat`, the data associated with each frame are returned as a 2D contiguous array. The first dimension gathers the 7 components (X, Y, Z, QuatX, QuatY, QuatZ, QuatW), while the last one corresponds to individual relative frame poses. @@ -519,7 +583,7 @@ class MultiFootRelativeXYZQuat(InterfaceQuantity[np.ndarray]): """ mode: QuantityEvalMode - """Specify on which state to evaluate this quantity. See `Mode` + """Specify on which state to evaluate this quantity. See `QuantityEvalMode` documentation for details about each mode. .. warning:: @@ -556,12 +620,12 @@ def __init__(self, xyzquat_mean=(MultiFootMeanXYZQuat, dict( frame_names=self.frame_names, mode=mode)), - xyzquats=(MultiFramesXYZQuat, dict( + xyzquats=(MultiFrameXYZQuat, dict( frame_names=self.frame_names, mode=mode))), auto_refresh=False) - # Define jit-able method translating multiple positions to local frame + # Jit-able method translating multiple positions to local frame @nb.jit(nopython=True, cache=True, fastmath=True) def translate_positions(position: np.ndarray, position_ref: np.ndarray, @@ -595,7 +659,7 @@ def translate_positions(position: np.ndarray, def refresh(self) -> np.ndarray: # Extract mean and individual frame position and quaternion vectors - xyzquats, xyzquat_mean = self.xyzquats, self.xyzquat_mean + xyzquats, xyzquat_mean = self.xyzquats.get(), self.xyzquat_mean.get() positions, position_mean = xyzquats[:3, :-1], xyzquat_mean[:3] quats, quat_mean = xyzquats[-4:, :-1], xyzquat_mean[-4:] @@ -651,8 +715,8 @@ def __init__( :param env: Base or wrapped jiminy environment. :param parent: Higher-level quantity from which this quantity is a requirement if any, `None` otherwise. - :para kinematic_level: Desired kinematic level, ie position, velocity - or acceleration. + :param kinematic_level: Desired kinematic level, ie position, velocity + or acceleration. :param mode: Desired mode of evaluation for this quantity. Optional: 'QuantityEvalMode.TRUE' by default. """ @@ -677,7 +741,7 @@ def initialize(self) -> None: super().initialize() # Make sure that the state data meet requirements - state = self.state + state = self.state.get() if ((self.kinematic_level == pin.ACCELERATION and state.a is None) or (self.kinematic_level >= pin.VELOCITY and state.v is None)): raise RuntimeError( @@ -773,7 +837,8 @@ def initialize(self) -> None: super().initialize() # Make sure that the state data meet requirements - if self.state.v is None or self.state.a is None: + state = self.state.get() + if state.v is None or state.a is None: raise RuntimeError( "State data do not meet requirements.
Velocity and " "acceleration are missing.") @@ -788,7 +853,7 @@ def initialize(self) -> None: def refresh(self) -> np.ndarray: # Extract intermediary quantities for convenience - (dhg_linear, dhg_angular), com = self.dhg, self.com_position + (dhg_linear, dhg_angular), com = self.dhg, self.com_position.get() # Compute the vertical force applied by the robot f_z = dhg_linear[2] + self._robot_weight @@ -801,7 +866,7 @@ def refresh(self) -> np.ndarray: # Translate the ZMP from world to local odometry frame if requested if self.reference_frame == pin.LOCAL: - translate_position_odom(self._zmp, self.odom_pose, self._zmp) + translate_position_odom(self._zmp, self.odom_pose.get(), self._zmp) return self._zmp @@ -878,7 +943,8 @@ def initialize(self) -> None: super().initialize() # Make sure that the state data meet requirements - if self.state.v is None: + state = self.state.get() + if state.v is None: raise RuntimeError( "State data do not meet requirements. Velocity is missing.") @@ -902,29 +968,30 @@ def initialize(self) -> None: def refresh(self) -> np.ndarray: # Compute the DCM in world frame - com_position, com_velocity = self.com_position, self.com_velocity + com_position = self.com_position.get() + com_velocity = self.com_velocity.get() self._dcm[:] = com_position[:2] + com_velocity[:2] / self.omega # Translate the ZMP from world to local odometry frame if requested if self.reference_frame == pin.LOCAL: - translate_position_odom(self._dcm, self.odom_pose, self._dcm) + translate_position_odom(self._dcm, self.odom_pose.get(), self._dcm) return self._dcm @dataclass(unsafe_hash=True) -class MultiContactRelativeForceTangential(AbstractQuantity[np.ndarray]): - """Standardized tangential forces apply on all contact points and collision +class MultiContactNormalizedSpatialForce(AbstractQuantity[np.ndarray]): + """Standardized spatial forces apply on all contact points and collision bodies in their respective local contact frame. The local contact frame is defined as the frame having the normal of the ground as vertical axis, and the vector orthogonal to the x-axis in world frame as y-axis. - The tangential force is rescaled by the weight of the robot rather than the + The spatial force is rescaled by the weight of the robot rather than the actual vertical force. It has the advantage to guarantee that the resulting quantity is never poorly conditioned, which would be the case otherwise. - Moreover, the effect of the vertical force is not canceled out, which is + Moreover, the contribution of the vertical force is still present, which is interesting for deriving a reward, as it allows for indirectly penalize jerky contact states and violent impacts. The side effect is not being able to guarantee that this quantity is bounded. Indeed, only the ratio of the @@ -954,14 +1021,14 @@ def __init__(self, mode=mode, auto_refresh=False) - # Define jit-able method compute the normalized tangential forces + # Jit-able method computing the normalized spatial forces @nb.jit(nopython=True, cache=True, fastmath=True) - def normalize_tangential_forces(lambda_c: np.ndarray, - index_start: int, - index_end: int, - robot_weight: float, - out: np.ndarray) -> None: - """Compute the tangential forces of all the constraints associated + def normalize_spatial_forces(lambda_c: np.ndarray, + index_start: int, + index_end: int, + robot_weight: float, + out: np.ndarray) -> None: + """Compute the spatial forces of all the constraints associated with contact frames and collision bodies, normalized by the total weight of the robot. 
@@ -971,20 +1038,20 @@ def normalize_tangential_forces(lambda_c: np.ndarray, :param index_end: One-past-last index of the constraints associated with contact frames and collision bodies. :param robot_weight: Total weight of the robot which will be used - to rescale the tangential forces. + to rescale the spatial forces. :param out: Pre-allocated array in which to store the result. """ # Extract constraint lambdas of contacts and collisions from state lambda_ = lambda_c[index_start:index_end].reshape((-1, 4)).T - # Extract references to all the tangential forces - # f_lin, f_ang = lambda_[:3], np.array([0.0, 0.0, lambda_[3]]) - forces_tangential = lambda_[:2] + # Extract references to all the spatial forces + forces_linear, forces_angular_z = lambda_[:3], lambda_[3] - # Scale the tangential forces by the weight of the robot - np.divide(forces_tangential, robot_weight, out) + # Scale the spatial forces by the weight of the robot + out[:3] = forces_linear / robot_weight + out[5] = forces_angular_z / robot_weight - self._normalize_tangential_forces = normalize_tangential_forces + self._normalize_spatial_forces = normalize_spatial_forces # Weight of the robot self._robot_weight: float = float("nan") @@ -992,15 +1059,16 @@ def normalize_tangential_forces(lambda_c: np.ndarray, # Slice of constraint lambda multipliers for contacts and collisions self._contact_slice: Tuple[int, int] = (0, 0) - # Stacked tangential forces on all contact points and collision bodies - self._force_tangential_rel_batch = np.array([]) + # Stacked spatial forces on all contact points and collision bodies + self._force_spatial_rel_batch = np.empty((6, 0)) def initialize(self) -> None: # Call base implementation super().initialize() # Make sure that the state data meet requirements - if self.state.lambda_c is None: + state = self.state.get() + if state.lambda_c is None: raise RuntimeError("State data do not meet requirements. " "Constraints lambda multipliers are missing.") @@ -1042,22 +1110,23 @@ def initialize(self) -> None: map(len, self.robot.constraints.collision_bodies)) assert 4 * num_contraints == index_last - index_first - # Pre-allocated memory for stacked normalized tangential forces - self._force_tangential_rel_batch = np.zeros( - (2, num_contraints), order='F') + # Pre-allocated memory for stacked normalized spatial forces + self._force_spatial_rel_batch = np.zeros( + (6, num_contraints), order='C') def refresh(self) -> np.ndarray: - self._normalize_tangential_forces( - self.state.lambda_c, + state = self.state.get() + self._normalize_spatial_forces( + state.lambda_c, *self._contact_slice, self._robot_weight, - self._force_tangential_rel_batch) + self._force_spatial_rel_batch) - return self._force_tangential_rel_batch + return self._force_spatial_rel_batch @dataclass(unsafe_hash=True) -class MultiFootRelativeForceVertical(AbstractQuantity[np.ndarray]): +class MultiFootNormalizedForceVertical(AbstractQuantity[np.ndarray]): + """Standardized total vertical forces applied to each foot in world frame.
The lambda multipliers of the contact constraints are used to compute the @@ -1107,7 +1176,7 @@ def __init__(self, mode=mode, auto_refresh=False) - # Define jit-able method compute the normalized tangential forces + # Jit-able method computing the normalized vertical forces @nb.jit(nopython=True, cache=True, fastmath=True) def normalize_vertical_forces( lambda_c: np.ndarray, @@ -1132,7 +1201,7 @@ def normalize_vertical_forces( dimension gathers the 3 spatial coordinates while the second corresponds to the N individual constraints on each foot. :param robot_weight: Total weight of the robot which will be used - to rescale the tangential forces. + to rescale the vertical forces. :param out: Pre-allocated array in which to store the result. """ for i, ((index_start, index_end), vertical_transforms) in ( @@ -1141,11 +1210,11 @@ def normalize_vertical_forces( lambda_ = lambda_c[index_start:index_end].reshape((-1, 4)).T # Extract references to all the linear forces - # f_ang = np.array([0.0, 0.0, lambda_[3]]) - f_lin = lambda_[:3] + # forces_angular = np.array([0.0, 0.0, lambda_[3]]) + forces_linear = lambda_[:3] # Compute vertical forces in world frame and aggregate them - f_z_world = np.sum(vertical_transforms * f_lin) + f_z_world = np.sum(vertical_transforms * forces_linear) # Scale the vertical forces by the weight of the robot out[i] = f_z_world / robot_weight @@ -1175,7 +1244,8 @@ def initialize(self) -> None: super().initialize() # Make sure that the state data meet requirements - if self.state.lambda_c is None: + state = self.state.get() + if state.lambda_c is None: raise RuntimeError("State data do not meet requirements. " "Constraints lambda multipliers are missing.") @@ -1244,10 +1314,63 @@ def refresh(self) -> np.ndarray: self._vertical_transform_list) # Compute the normalized sum of the vertical forces in world frame - self._normalize_vertical_forces(self.state.lambda_c, + state = self.state.get() + self._normalize_vertical_forces(state.lambda_c, self._foot_slices, self._vertical_transform_batches, self._robot_weight, self._vertical_force_batch) return self._vertical_force_batch + + +@dataclass(unsafe_hash=True) +class MultiFootCollisionDetection(InterfaceQuantity[bool]): + """Check if some of the feet of the robot are colliding with each other. + + It takes into account some safety margins by which their volume will be + inflated / deflated. See `MultiFrameCollisionDetection` documentation for + details. + """ + + frame_names: Tuple[str, ...] + """Name of the frames corresponding to some feet of the robot. + + These frames must be part of the end-effectors, ie associated with a + leaf joint in the kinematic tree of the robot. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + frame_names: Union[Sequence[str], Literal['auto']] = 'auto', + *, + security_margin: float = 0.0) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param frame_names: Name of the frames corresponding to some feet of + the robot. 'auto' to automatically detect them from + the set of contact and force sensors of the robot. + Optional: 'auto' by default. + :param security_margin: Signed distance below which a pair of geometry + objects is considered to be in collision. + Optional: 0.0 by default.
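For reference, a hypothetical registration sketch (the manager variable, quantity key and dict-like read access are assumptions; the `(class, kwargs)` creator tuple follows the `__setitem__` convention documented in `manager.py` below):

```python
# Hypothetical usage: register the quantity with a 1 cm security margin,
# then read it back (assuming the manager exposes dict-like access).
quantity_manager["feet_collision"] = (MultiFootCollisionDetection, dict(
    frame_names='auto',       # detect feet from contact and force sensors
    security_margin=0.01))    # inflate foot volumes by 1 cm

is_colliding = quantity_manager["feet_collision"]
```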
+ """ + # Backup some user argument(s) + self.frame_names = tuple(sanitize_foot_frame_names(env, frame_names)) + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + is_colliding=(MultiFrameCollisionDetection, dict( + frame_names=self.frame_names, + security_margin=security_margin + ))), + auto_refresh=False) + + def refresh(self) -> bool: + return self.is_colliding.get() diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py index f6f880f2d..4f8d689e1 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py @@ -53,8 +53,13 @@ def __init__(self, env: InterfaceJiminyEnv) -> None: # This is necessary because using a quantity as key directly would # prevent its garbage collection, hence breaking automatic reset of # computation tracking for all quantities sharing its cache. - self._caches: Dict[ - Tuple[Type[InterfaceQuantity], int], SharedCache] = {} + # In case of dataclasses, their hash is the same as if it was obtained + # using `hash(dataclasses.astuple(quantity))`. This is clearly not + # unique, as all it requires to be the same is being built from the + # same nested ordered arguments. To get around this issue, we need to + # store (key, value) pairs in a list. + self._caches: List[Tuple[ + Tuple[Type[InterfaceQuantity], int], SharedCache]] = [] # Instantiate trajectory database. # Note that this quantity is not added to the global registry to avoid @@ -77,9 +82,7 @@ def reset(self, reset_tracking: bool = False) -> None: """ # Reset all quantities sequentially for quantity in self.registry.values(): - quantity.reset( - reset_tracking, - ignore_auto_refresh=not self.env.is_simulation_running) + quantity.reset(reset_tracking) def clear(self) -> None: """Clear internal cache of quantities to force re-evaluating them the @@ -90,8 +93,8 @@ def clear(self) -> None: environment has changed (ie either the agent or world itself), thereby invalidating the value currently stored in cache if any. """ - for cache in self._caches.values(): - cache.reset() + for _, cache in self._caches: + cache.reset(ignore_auto_refresh=not self.env.is_simulation_running) def add_trajectory(self, name: str, trajectory: Trajectory) -> None: """Add a new reference trajectory to the database synchronized between @@ -171,8 +174,9 @@ def _build_quantity( raise an exception if another quantity with the exact same name exists. :param quantity_creator: Tuple gathering the class of the new quantity - to manage plus its keyword-arguments except - environment and parent as a dictionary. + to manage plus any keyword-arguments of its + constructor as a dictionary except 'env' and + 'parent'. 
""" # Instantiate the new quantity quantity_cls, quantity_kwargs = quantity_creator @@ -181,9 +185,24 @@ def _build_quantity( # Set a shared cache entry for all quantities involved in computations quantities_all = [top_quantity] while quantities_all: + # Deal with the first quantity in the process queue quantity = quantities_all.pop() + + # Get already available cache entry if any, otherwise create it key = (type(quantity), hash(quantity)) - quantity.cache = self._caches.setdefault(key, SharedCache()) + for cache_key, cache in self._caches: + if key == cache_key: + owner, *_ = cache.owners + if quantity == owner: + break + else: + cache = SharedCache() + self._caches.append((key, cache)) + + # Set shared cache of the quantity + quantity.cache = cache + + # Add all the requirements of the new quantity in the process queue quantities_all += quantity.requirements.values() return top_quantity @@ -197,8 +216,9 @@ def __setitem__(self, raise an exception if another quantity with the exact same name exists. :param quantity_creator: Tuple gathering the class of the new quantity - to manage plus its keyword-arguments except - environment and parent as a dictionary. + to manage plus any keyword-arguments of its + constructor as a dictionary except 'env' and + 'parent'. """ # Make sure that no quantity with the same name is already managed to # avoid silently overriding quantities being managed in user's back. @@ -233,11 +253,21 @@ def __delitem__(self, name: str) -> None: :param name: Name of the managed quantity to be discarded. It will raise an exception if the specified name does not exists. """ - # Remove shared cache entry for all quantities involved in computations + # Remove shared cache entries for the quantity and its requirements. + # Note that done top-down rather than bottom-up, otherwise reset of + # required quantities no longer having shared cache will be triggered + # automatically by parent quantities following computation graph + # tracking reset whenever a shared cache co-owner is removed. quantities_all = [self.registry.pop(name)] while quantities_all: - quantity = quantities_all.pop() + quantity = quantities_all.pop(0) + cache = quantity.cache quantity.cache = None # type: ignore[assignment] + if len(cache.owners) == 0: + for i, (_, _cache) in enumerate(self._caches): + if cache is _cache: + del self._caches[i] + break quantities_all += quantity.requirements.values() def __iter__(self) -> Iterator[str]: diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py index f471af826..8168a3c1a 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py @@ -2,16 +2,19 @@ its topology (multiple or single branch, fixed or floating base...) and the application (locomotion, grasping...). 
""" +import sys +import warnings from copy import deepcopy -from collections import deque from dataclasses import dataclass from typing import ( Any, Optional, Sequence, Tuple, TypeVar, Union, Generic, ClassVar, - Callable) + Callable, Literal, List, overload, cast) from typing_extensions import TypeAlias import numpy as np +from jiminy_py.core import ( # pylint: disable=no-name-in-module + multi_array_copyto) from ..bases import InterfaceJiminyEnv, InterfaceQuantity, QuantityCreator @@ -24,7 +27,8 @@ @dataclass(unsafe_hash=True) -class StackedQuantity(InterfaceQuantity[Tuple[ValueT, ...]]): +class StackedQuantity( + InterfaceQuantity[OtherValueT], Generic[ValueT, OtherValueT]): """Keep track of a given quantity over time by automatically stacking its value once per environment step since last reset. @@ -34,50 +38,110 @@ class StackedQuantity(InterfaceQuantity[Tuple[ValueT, ...]]): controller updates are ignored. """ - quantity: InterfaceQuantity + quantity: InterfaceQuantity[ValueT] """Base quantity whose value must be stacked over time since last reset. """ - num_stack: Optional[int] + max_stack: int """Maximum number of values that keep in memory before starting to discard - the oldest one (FIFO). None if unlimited. + the oldest one (FIFO). `sys.maxsize` if unlimited. + """ + + as_array: bool + """Whether to return data as a tuple or a contiguous N-dimensional array + whose last dimension gathers the value of individual timesteps. + """ + + mode: Literal['slice', 'zeros'] + """Fallback strategy in case of incomplete stack. "slice" returns only + available data, "zeros" returns a zero-padded fixed-length stack. """ allow_update_graph: ClassVar[bool] = False """Disable dynamic computation graph update. """ - def __init__(self, + @overload + def __init__(self: "StackedQuantity[ValueT, List[ValueT]]", env: InterfaceJiminyEnv, parent: Optional[InterfaceQuantity], quantity: QuantityCreator[ValueT], *, - num_stack: Optional[int] = None) -> None: + max_stack: int, + as_array: Literal[False], + mode: Literal['slice', 'zeros']) -> None: + ... + + @overload + def __init__(self: "StackedQuantity[Union[np.ndarray, float], np.ndarray]", + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + quantity: QuantityCreator[Union[np.ndarray, float]], + *, + max_stack: int, + as_array: Literal[True], + mode: Literal['slice', 'zeros']) -> None: + ... + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + quantity: QuantityCreator[Any], + *, + max_stack: int = sys.maxsize, + as_array: bool = False, + mode: Literal['slice', 'zeros'] = 'slice') -> None: """ :param env: Base or wrapped jiminy environment. :param parent: Higher-level quantity from which this quantity is a requirement if any, `None` otherwise. :param quantity: Tuple gathering the class of the quantity whose values must be stacked, plus all its constructor keyword- - arguments except environment 'env' and parent 'parent. - :param num_stack: Maximum number of values that keep in memory before - starting to discard the oldest one (FIFO). None if - unlimited. + arguments except environment 'env' and 'parent'. + :param max_stack: Maximum number of values that keep in memory before + starting to discard the oldest one (FIFO). + Optional: The maxium sequence length by default, ie + `sys.maxsize` (2^63 - 1). + :param as_array: Whether to return data as a list or a contiguous + N-dimensional array whose last dimension gathers the + value of individual timesteps. 
+ :param mode: Fallback strategy in case of incomplete stack. + 'zeros' is only supported by quantities returning + fixed-size N-D array. + Optional: 'slice' by default. """ + # Make sure that the input arguments are valid + if max_stack > 10000 and (mode != 'slice' or as_array): + warnings.warn( + "A very large stack length is strongly discouraged for " + "`mode != 'slice'` or `as_array=True`.") + # Backup user arguments - self.num_stack = num_stack + self.max_stack = max_stack + self.as_array = as_array + self.mode = mode # Call base implementation super().__init__(env, parent, - requirements=dict(data=quantity), + requirements=dict(quantity=quantity), auto_refresh=True) - # Keep track of the quantity that must be stacked once instantiated - self.quantity = self.requirements["data"] + # Allocate stack buffer. + # Note that using a plain list is more efficient in practice. Although + # front deletion on a deque is very fast compared to a list, casting a + # deque to a tuple or a list is very slow, which ultimately prevails. + # The matter gets worse as the maximum length grows. + self._value_list: List[ValueT] = [] - # Allocate deque buffer - self._deque: deque = deque(maxlen=self.num_stack) + # Continuous memory to store the whole stack if requested. + # Note that it will be allocated lazily since the dimension of the + # quantity is not known in advance. + self._data = np.array([]) + self._data_views: Tuple[np.ndarray, ...] = () + + # Define proxy to number of steps of current episode for fast access + self._num_steps = np.array(-1) # Keep track of the last time the quantity has been stacked self._num_steps_prev = -1 @@ -86,26 +150,79 @@ def initialize(self) -> None: # Call base implementation super().initialize() - # Clear buffer - self._deque.clear() + # Refresh proxy + self._num_steps = self.env.num_steps + + # Clear stack buffer + self._value_list.clear() + + # Initialize buffers if necessary + if self.as_array or self.mode == 'zeros': + # Get current value of base quantity + value = self.quantity.get() + + # Make sure that the quantity is an array or a scalar + if not isinstance(value, (int, float, np.ndarray)): + raise ValueError( + "'as_array=True' is only supported by quantities " + "returning N-dimensional arrays as value.") + value = np.asarray(value) + + # Fill the queue with zeros if necessary + if self.mode == 'zeros': + for _ in range(self.max_stack): + self._value_list.append( + np.zeros_like(value)) # type: ignore[arg-type] + + # Allocate stack memory if necessary + if self.as_array: + self._data = np.zeros((*value.shape, self.max_stack), + order='F', + dtype=value.dtype) + self._data_views = tuple( + self._data[..., i] for i in range(self.max_stack)) # Reset step counter self._num_steps_prev = -1 - def refresh(self) -> Tuple[ValueT, ...]: + def refresh(self) -> OtherValueT: # Append value to the queue only once per step and only if a simulation # is running. Note that data must be deep-copied to make sure it does # not get altered afterward. + value_list = self._value_list if self.env.is_simulation_running: - num_steps = self.env.num_steps + num_steps = self._num_steps.item() if num_steps != self._num_steps_prev: - assert num_steps == self._num_steps_prev + 1 - self._deque.append(deepcopy(self.data)) + if num_steps != self._num_steps_prev + 1: + raise RuntimeError( + "Previous step missing in the stack.
Please reset the " + "environment after adding this quantity.") + value = self.quantity.get() + if isinstance(value, np.ndarray): + value_list.append(value.copy()) # type: ignore[arg-type] + else: + value_list.append(deepcopy(value)) + if len(value_list) > self.max_stack: + del value_list[0] self._num_steps_prev += 1 - # Return the whole stack as a tuple to preserve the integrity of the + # Aggregate data in contiguous array only if requested + if self.as_array: + is_padded = self.mode == 'zeros' + offset = - self._num_steps_prev - 1 + data, data_views = self._data, self._data_views + if offset > - self.max_stack: + if is_padded: + value_list = value_list[offset:] + else: + data = data[..., offset:] + data_views = self._data_views[offset:] + multi_array_copyto(data_views, value_list) + return cast(OtherValueT, data) + + # Return the whole stack as a list to preserve the integrity of the # underlying container and make the API robust to internal changes. - return tuple(self._deque) + return cast(OtherValueT, value_list) @dataclass(unsafe_hash=True) @@ -120,7 +237,7 @@ class MaskedQuantity(InterfaceQuantity[np.ndarray]): contiguous arrays. """ - quantity: InterfaceQuantity + quantity: InterfaceQuantity[np.ndarray] """Base quantity whose elements must be extracted. """ @@ -144,8 +261,8 @@ def __init__(self, :param parent: Higher-level quantity from which this quantity is a requirement if any, `None` otherwise. :param quantity: Tuple gathering the class of the quantity whose values - must be extracted, plus all its constructor keyword- - arguments except environment 'env' and parent 'parent. + must be extracted, plus any keyword-arguments of its + constructor except 'env' and 'parent'. :param keys: Sequence of indices or boolean mask that will be used to extract elements from the quantity along one axis. :param axis: Axis over which to extract elements. `None` to consider @@ -191,24 +308,24 @@ def __init__(self, # Call base implementation super().__init__(env, parent, - requirements=dict(data=quantity), + requirements=dict(quantity=quantity), auto_refresh=False) - # Keep track of the quantity from which data must be extracted - self.quantity = self.requirements["data"] - def refresh(self) -> np.ndarray: + # Get current value of base quantity + value = self.quantity.get() + # Extract elements from quantity if not self._slices: # Note that `take` is faster than classical advanced indexing via # `operator[]` (`__getitem__`) because the latter is more generic. # Notably, `operator[]` supports boolean mask but `take` does not. - return self.data.take(self.indices, self.axis) + return value.take(self.indices, self.axis) if self.axis is None: # `ravel` must be used instead of `flat` to get a view that can # be sliced without copy. - return self.data.ravel(order="K")[self._slices] - return self.data[self._slices] + return value.ravel(order="K")[self._slices] + return value[self._slices] @dataclass(unsafe_hash=True) @@ -218,7 +335,7 @@ class UnaryOpQuantity(InterfaceQuantity[ValueT], This quantity is useful to translate quantities from world frame to local odometry frame. It may also be used to convert multi-variate quantities as - scalar, typically by computing the Lp-norm. + scalar, typically by computing the L^p-norm. """ quantity: InterfaceQuantity[OtherValueT] @@ -240,8 +357,8 @@ def __init__(self, requirement if any, `None` otherwise. 
:param quantity: Tuple gathering the class of the quantity whose value must be passed as argument of the unary operator, plus - all its constructor keyword-arguments except - environment 'env' and parent 'parent. + any keyword-arguments of its constructor except 'env' + and 'parent'. :param op: Any callable taking any value of the quantity as input argument. For example `partial(np.linalg.norm, ord=2)` to compute the Euclidean norm. """ # Backup some user argument(s) self.op = op # Call base implementation super().__init__( env, parent, - requirements=dict(data=quantity), + requirements=dict(quantity=quantity), auto_refresh=False) - # Keep track of the left- and right-hand side quantities for hashing - self.quantity = self.requirements["data"] - def refresh(self) -> ValueT: - return self.op(self.data) + def refresh(self) -> ValueT: + return self.op(self.quantity.get()) @dataclass(unsafe_hash=True) @@ -282,7 +396,7 @@ class BinaryOpQuantity(InterfaceQuantity[ValueT], """ op: Callable[[OtherValueT, YetAnotherValueT], ValueT] - """Callable taking right- and left-hand side quantities as input argument. + """Callable taking left- and right-hand side quantities as input argument. """ def __init__(self, @@ -317,12 +431,60 @@ def __init__(self, env, parent, requirements=dict( - value_left=quantity_left, value_right=quantity_right), + quantity_left=quantity_left, quantity_right=quantity_right), + auto_refresh=False) + + def refresh(self) -> ValueT: + return self.op(self.quantity_left.get(), self.quantity_right.get()) + + +@dataclass(unsafe_hash=True) +class MultiAryOpQuantity(InterfaceQuantity[ValueT]): + """Apply a given n-ary operator to the values of a given set of quantities. + """ + + quantities: Tuple[InterfaceQuantity[Any], ...] + """Sequence of quantities that will be forwarded to the n-ary operator in + this exact order. + """ + + op: Callable[[Sequence[Any]], ValueT] + """Callable taking the packed sequence of values for all the specified + quantities as input argument. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + quantities: Sequence[QuantityCreator[Any]], + op: Callable[[Sequence[Any]], ValueT]) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param quantities: Ordered sequence of n pairs, each gathering the + class of a quantity whose value must be passed as + argument of the n-ary operator, plus any + keyword-arguments of its constructor except 'env' + and 'parent'. + :param op: Any callable taking the packed sequence of values for all + the quantities as input argument, in the exact order they + were originally specified.
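A hypothetical usage sketch of `MultiAryOpQuantity` (the manager variable, quantity key and frame names are made up; `FramePosition` and `MaskedQuantity` are the quantities appearing elsewhere in this patch):

```python
import numpy as np

# Hypothetical: average the height (z-coordinate) of two made-up frames
# with a single n-ary operator applied to the packed values.
quantity_manager["feet_mean_height"] = (MultiAryOpQuantity, dict(
    quantities=[
        (MaskedQuantity, dict(
            quantity=(FramePosition, dict(frame_name=name)),
            keys=(2,)))                          # keep only the z-coordinate
        for name in ("LeftSole", "RightSole")],  # made-up frame names
    op=lambda values: float(np.mean(values))))   # values packed in order
```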
+ """ + # Backup some user argument(s) + self.op = op + + # Call base implementation + super().__init__( + env, + parent, + requirements={ + f"quantity_{i}": quantity + for i, quantity in enumerate(quantities)}, auto_refresh=False) - # Keep track of the left- and right-hand side quantities for hashing - self.quantity_left = self.requirements["value_left"] - self.quantity_right = self.requirements["value_right"] + # Keep track of the instantiated quantities for identity check + self.quantities = tuple(self.requirements.values()) def refresh(self) -> ValueT: - return self.op(self.value_left, self.value_right) + return self.op([quantity.get() for quantity in self.quantities]) diff --git a/python/gym_jiminy/common/gym_jiminy/common/utils/math.py b/python/gym_jiminy/common/gym_jiminy/common/utils/math.py index f1ca42353..d425c610a 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/utils/math.py +++ b/python/gym_jiminy/common/gym_jiminy/common/utils/math.py @@ -1189,7 +1189,7 @@ def quat_average(quat: np.ndarray, assert len(axes) > 0 and 0 not in axes q_perm = quat.transpose(( - *(i for i in range(1, quat.ndim) if i not in axes), 0, *axes)) + *[i for i in range(1, quat.ndim) if i not in axes], 0, *axes)) q_flat = q_perm.reshape((*q_perm.shape[:-len(axes)], -1)) _, eigvec = np.linalg.eigh(q_flat @ np.swapaxes(q_flat, -1, -2)) return np.moveaxis(eigvec[..., -1], -1, 0) diff --git a/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py b/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py index 5875d6441..13f2664be 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py +++ b/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py @@ -12,9 +12,10 @@ from pydoc import locate from dataclasses import asdict from functools import partial +from collections.abc import Sequence from typing import ( - Dict, Any, Optional, Union, Type, Sequence, Callable, TypedDict, Literal, - cast) + Dict, Any, Optional, Union, Type, Sequence as SequenceT, Callable, + TypedDict, Literal, overload, cast) import h5py import toml @@ -33,34 +34,35 @@ ControlledJiminyEnv, ComposedJiminyEnv, AbstractReward, - BaseQuantityReward, - BaseMixtureReward) + MixtureReward, + AbstractTerminationCondition) from ..envs import BaseJiminyEnv -class RewardConfig(TypedDict, total=False): - """Store information required for instantiating a given reward. +class CompositionConfig(TypedDict, total=False): + """Store information required for instantiating a given composition, which + comprises reward components or a termination condition at the time being. - Specifically, it is a dictionary comprising the class of the reward, which - must derive from `AbstractReward`, and optionally some keyword-arguments - that must be passed to its corresponding constructor. + Specifically, it is a dictionary comprising the class of the composition + that must derive from `AbstractReward` or `AbstractTerminationCondition]`, + and optionally some keyword-arguments to pass to its constructor. """ - cls: Union[Type[AbstractReward], str] - """Reward class type. + cls: Union[Type[AbstractReward], Type[AbstractTerminationCondition], str] + """Composition class type. .. note:: Both class type or fully qualified dotted path are supported. """ kwargs: Dict[str, Any] - """Environment constructor keyword-arguments. + """Composition constructor keyword-arguments. This attribute can be omitted. 
""" -class TrajectoriesConfig(TypedDict, total=False): +class TrajectoryDatabaseConfig(TypedDict, total=False): """Store information required for adding a database of reference trajectories to the environment. @@ -115,14 +117,20 @@ class EnvConfig(TypedDict, total=False): This attribute can be omitted. """ - reward: RewardConfig + reward: CompositionConfig """Reward configuration. This attribute can be omitted. """ - trajectories: TrajectoriesConfig - """Reference trajectories configuration. + terminations: SequenceT[CompositionConfig] + """Sequence of configuration for every individual termination conditions. + + This attribute can be omitted. + """ + + trajectories: TrajectoryDatabaseConfig + """Reference trajectory database configuration. This attribute can be omitted. """ @@ -182,9 +190,9 @@ class LayerConfig(TypedDict, total=False): controller block with its corresponding wrapper. Specifically, it is a dictionary comprising the configuration of the block - if any, and optionally the configuration of the reward. It is generally - sufficient to specify either one or the other. See the documentation of the - both fields for details. + if any, and optionally the configuration of the reward and termination. It + is generally sufficient to specify either one or the other. See the + documentation of the both fields for details. """ block: BlockConfig @@ -206,7 +214,7 @@ class LayerConfig(TypedDict, total=False): def build_pipeline(env_config: EnvConfig, - layers_config: Sequence[LayerConfig], + layers_config: SequenceT[LayerConfig], *, root_path: Optional[Union[str, pathlib.Path]] = None ) -> Callable[..., InterfaceJiminyEnv]: @@ -222,55 +230,78 @@ def build_pipeline(env_config: EnvConfig, lowest level layer to the highest, each element corresponding to the configuration of a individual layer, as a dict of type `LayerConfig`. """ - # Define helper to sanitize reward configuration - def sanitize_reward_config(reward_config: RewardConfig) -> None: - """Sanitize reward configuration in-place. + # Define helper to sanitize composition configuration + def sanitize_composition_config(composition_config: CompositionConfig, + is_reward: bool) -> None: + """Sanitize composition configuration in-place. - :param reward_config: Configuration of the reward, as a dict of type - `RewardConfig`. + :param composition_config: Configuration of the composition, as a + dict of type `CompositionConfig`. 
""" - # Get reward class type - cls = reward_config["cls"] + # Get composition class type + cls = composition_config["cls"] if isinstance(cls, str): obj = locate(cls) if obj is None: raise RuntimeError(f"Class '{cls}' not found.") - assert isinstance(obj, type) and issubclass(obj, AbstractReward) - reward_config["cls"] = cls = obj + assert isinstance(obj, type) and ( + (is_reward and issubclass(obj, AbstractReward)) or + (not is_reward and issubclass( + obj, AbstractTerminationCondition))) + composition_config["cls"] = cls = obj - # Get reward constructor keyword-arguments - kwargs = reward_config.get("kwargs", {}) + # Get its constructor keyword-arguments + kwargs = composition_config.get("kwargs", {}) - # Special handling for `BaseMixtureReward` - if issubclass(cls, BaseMixtureReward): + # Special handling for `MixtureReward` + if is_reward and issubclass(cls, MixtureReward): for component_config in kwargs["components"]: - sanitize_reward_config(component_config) - - # Define helper to build the reward - def build_reward(env: InterfaceJiminyEnv, - reward_config: RewardConfig) -> AbstractReward: - """Instantiate a reward associated with a given environment provided - some reward configuration. + sanitize_composition_config(component_config, is_reward) + + @overload + def build_composition( + env: InterfaceJiminyEnv, + composition_config: CompositionConfig, + is_reward: Literal[True] + ) -> AbstractReward: + ... + + @overload + def build_composition( + env: InterfaceJiminyEnv, + composition_config: CompositionConfig, + is_reward: Literal[False] + ) -> AbstractTerminationCondition: + ... + + # Define helper to build the composition + def build_composition( + env: InterfaceJiminyEnv, + composition_config: CompositionConfig, + is_reward: bool + ) -> Union[AbstractReward, AbstractTerminationCondition]: + """Instantiate a composition associated with a given environment from + some composition configuration. :param env: Base environment or pipeline wrapper to wrap. - :param reward_config: Configuration of the reward, as a dict of type - `RewardConfig`. + :param composition_config: Configuration of the composition, as a + dict of type `CompositionConfig`. 
""" - # Get reward class type - cls = reward_config["cls"] - assert isinstance(cls, type) and issubclass(cls, AbstractReward) + # Get composition class type + cls = composition_config["cls"] + assert isinstance(cls, type) - # Get reward constructor keyword-arguments - kwargs = reward_config.get("kwargs", {}) + # Get its constructor keyword-arguments + kwargs = composition_config.get("kwargs", {}).copy() - # Special handling for `BaseMixtureReward` - if issubclass(cls, BaseMixtureReward): + # Special handling for `MixtureReward` + if is_reward and issubclass(cls, MixtureReward): kwargs["components"] = tuple( - build_reward(env, reward_config) + build_composition(env, reward_config, is_reward) for reward_config in kwargs["components"]) - # Special handling for `BaseQuantityReward` - if cls is BaseQuantityReward: + # Special handling for 'quantity' key + if "quantity" in kwargs: quantity_config = kwargs["quantity"] kwargs["quantity"] = ( quantity_config["cls"], quantity_config["kwargs"]) @@ -278,17 +309,22 @@ def build_reward(env: InterfaceJiminyEnv, return cls(env, **kwargs) # Define helper to build reward - def build_composition(env_creator: Callable[..., InterfaceJiminyEnv], - reward_config: Optional[RewardConfig], - trajectories_config: Optional[TrajectoriesConfig], - **env_kwargs: Any) -> InterfaceJiminyEnv: - """Helper adding reward on top of a base environment or a pipeline - using `ComposedJiminyEnv` wrapper. + def build_composition_layer( + env_creator: Callable[..., InterfaceJiminyEnv], + reward_config: Optional[CompositionConfig], + terminations_config: SequenceT[CompositionConfig], + trajectories_config: Optional[TrajectoryDatabaseConfig], + **env_kwargs: Any) -> InterfaceJiminyEnv: + """Helper adding reward components and/or termination conditions on top + of a base environment or a pipeline using `ComposedJiminyEnv` wrapper. :param env_creator: Callable that takes optional keyword arguments as input and returns an pipeline or base environment. :param reward_config: Configuration of the reward, as a dict of type - `RewardConfig`. + `CompositionConfig`. + :param termination_config: Configuration of the termination conditions, + as a sequence of dict of type + `CompositionConfig`. :param trajectories: Set of named trajectories as a dictionary. See `ComposedJiminyEnv` documentation for details. 
        :param env_kwargs: Keyword arguments to forward to the constructor of
@@ -304,7 +340,12 @@ def build_composition(env_creator: Callable[..., InterfaceJiminyEnv],
         # Instantiate the reward
         reward = None
         if reward_config is not None:
-            reward = build_reward(env, reward_config)
+            reward = build_composition(env, reward_config, True)
+
+        # Instantiate the termination conditions
+        terminations = tuple(
+            build_composition(env, termination_config, False)
+            for termination_config in terminations_config)
 
         # Get trajectory dataset
         trajectories: Dict[str, Trajectory] = {}
@@ -313,9 +354,11 @@ def build_composition(env_creator: Callable[..., InterfaceJiminyEnv],
                 Dict[str, Trajectory], trajectories_config["dataset"])
 
         # Instantiate the composition wrapper if necessary
-        if reward or trajectories:
-            env = ComposedJiminyEnv(
-                env, reward=reward, trajectories=trajectories)
+        if reward or terminations or trajectories:
+            env = ComposedJiminyEnv(env,
+                                    reward=reward,
+                                    terminations=terminations,
+                                    trajectories=trajectories)
 
         # Select the reference trajectory if specified
         if trajectories_config is not None:
@@ -327,15 +370,16 @@ def build_composition(env_creator: Callable[..., InterfaceJiminyEnv],
         return env
 
     # Define helper to wrap a single layer
-    def build_layer(env_creator: Callable[..., InterfaceJiminyEnv],
-                    wrapper_cls: Type[BasePipelineWrapper],
-                    wrapper_kwargs: Dict[str, Any],
-                    block_cls: Optional[Type[InterfaceBlock]],
-                    block_kwargs: Dict[str, Any],
-                    **env_kwargs: Any
-                    ) -> BasePipelineWrapper:
-        """Helper wrapping a base environment or a pipeline with additional
-        layer, typically an observer or a controller.
+    def build_controller_observer_layer(
+            env_creator: Callable[..., InterfaceJiminyEnv],
+            wrapper_cls: Type[BasePipelineWrapper],
+            wrapper_kwargs: Dict[str, Any],
+            block_cls: Optional[Type[InterfaceBlock]],
+            block_kwargs: Dict[str, Any],
+            **env_kwargs: Any
+            ) -> BasePipelineWrapper:
+        """Helper wrapping a base environment or a pipeline with an additional
+        observer-controller layer.
 
         :param env_creator: Callable that takes optional keyword arguments as
                             input and returns a pipeline or base environment.
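Editor's note, not part of the patch: a minimal sketch of how the new `terminations` entry of `EnvConfig` is meant to be consumed, mirroring the TOML fixture further below. The import path of `build_pipeline` and the chosen environment are assumptions for illustration; the composition class paths are taken from this diff.

```python
# Hypothetical usage sketch under the assumptions stated above.
from gym_jiminy.common.pipeline import build_pipeline  # assumed import path

env_config = {
    "cls": "gym_jiminy.envs.ANYmalJiminyEnv",
    "kwargs": {"step_dt": 0.04},
    # New with this patch: a sequence of `CompositionConfig` dicts. Each
    # entry is resolved by `sanitize_composition_config` (string -> class)
    # and instantiated by `build_composition` with the environment as the
    # first positional argument.
    "terminations": (
        {
            "cls": "gym_jiminy.common.compositions.BaseRollPitchTermination",
            "kwargs": {"low": [-0.2, -0.05], "high": [0.2, 0.3]},
        },
    ),
    "reward": {
        "cls": "gym_jiminy.common.compositions.SurviveReward",
    },
}

# `build_pipeline` returns a factory rather than an environment, so the same
# configuration can spawn any number of independent instances.
env_creator = build_pipeline(env_config, layers_config=())
env = env_creator()
```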
@@ -401,7 +445,13 @@ def build_layer(env_creator: Callable[..., InterfaceJiminyEnv],
     # Parse reward configuration
     reward_config = env_config.get("reward")
     if reward_config is not None:
-        sanitize_reward_config(reward_config)
+        sanitize_composition_config(reward_config, is_reward=True)
+
+    # Parse the configuration of every termination condition
+    terminations_config = env_config.get("terminations", ())
+    assert isinstance(terminations_config, Sequence)
+    for termination_config in terminations_config:
+        sanitize_composition_config(termination_config, is_reward=False)
 
     # Parse trajectory configuration
     trajectories_config = env_config.get("trajectories")
@@ -420,12 +470,6 @@ def build_layer(env_creator: Callable[..., InterfaceJiminyEnv],
                 path = pathlib.Path(root_path) / path
             trajectories[name] = load_trajectory_from_hdf5(path)
 
-    # Compose base environment with an extra user-specified reward
-    pipeline_creator = partial(build_composition,
-                               pipeline_creator,
-                               reward_config,
-                               trajectories_config)
-
     # Generate pipeline recursively
     for layer_config in layers_config:
         # Extract block and wrapper config
@@ -476,13 +520,20 @@ def build_layer(env_creator: Callable[..., InterfaceJiminyEnv],
                 "Either 'block.cls' or 'wrapper.cls' must be specified.")
 
         # Add layer on top of the existing pipeline
-        pipeline_creator = partial(build_layer,
+        pipeline_creator = partial(build_controller_observer_layer,
                                    pipeline_creator,
                                    wrapper_cls_,
                                    wrapper_kwargs,
                                    block_cls_,
                                    block_kwargs)
 
+    # Add extra user-specified reward, termination conditions, and trajectories
+    pipeline_creator = partial(build_composition_layer,
+                               pipeline_creator,
+                               reward_config,
+                               terminations_config,
+                               trajectories_config)
+
     return pipeline_creator
 
diff --git a/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py b/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
index dd8d61d43..f13219f35 100644
--- a/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
+++ b/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
@@ -237,9 +237,9 @@ def clip(data: DataNested, space: gym.Space[DataNested]) -> DataNested:
             field: clip(data[field], subspace)
             for field, subspace in space.spaces.items()})
     if tree.issubclass_sequence(data_type):
-        return data_type(tuple(
+        return data_type([
             clip(data[i], subspace)
-            for i, subspace in enumerate(space.spaces)))
+            for i, subspace in enumerate(space.spaces)])
     return _array_clip(data, *get_bounds(space))
 
diff --git a/python/gym_jiminy/common/gym_jiminy/common/wrappers/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/wrappers/__init__.py
index 1111b4863..3d450525c 100644
--- a/python/gym_jiminy/common/gym_jiminy/common/wrappers/__init__.py
+++ b/python/gym_jiminy/common/gym_jiminy/common/wrappers/__init__.py
@@ -9,8 +9,8 @@
 __all__ = [
     'FilterObservation',
     'StackObservation',
-    'NormalizeAction',
     'NormalizeObservation',
-    'FlattenAction',
-    'FlattenObservation'
+    'FlattenObservation',
+    'NormalizeAction',
+    'FlattenAction'
 ]
 
diff --git a/python/gym_jiminy/envs/gym_jiminy/envs/ant.py b/python/gym_jiminy/envs/gym_jiminy/envs/ant.py
index 503542345..5ee0cc7d1 100644
--- a/python/gym_jiminy/envs/gym_jiminy/envs/ant.py
+++ b/python/gym_jiminy/envs/gym_jiminy/envs/ant.py
@@ -239,9 +239,9 @@ def compute_reward(self, terminated: bool, info: InfoType) -> float:
           * survive_reward: Constant positive reward equal to 1.0 as long as
             no termination condition has been triggered.
           * ctrl_cost: Negative reward to penalize power consumption, defined
-            as the L2-norm of the action vector weighted by 0.5.
+            as the L^2-norm of the action vector weighted by 0.5.
           * contact_cost: Negative reward to penalize jerky, violent motions,
-            defined as the aggregated L2-norm of the external forces applied
+            defined as the aggregated L^2-norm of the external forces applied
             on all bodies weighted by 5e-4.
 
         The value of each individual reward is added to `info` for monitoring.
diff --git a/python/gym_jiminy/examples/quantity_benchmark.py b/python/gym_jiminy/examples/quantity_benchmark.py
index 9a2111c4e..61cb6d43c 100644
--- a/python/gym_jiminy/examples/quantity_benchmark.py
+++ b/python/gym_jiminy/examples/quantity_benchmark.py
@@ -11,10 +11,9 @@
 # Define number of samples for benchmarking
 N_SAMPLES = 50000
 
-# Disable caching by forcing `SharedCache.has_value` to always return `False`
-setattr(gym_jiminy.common.bases.quantities.SharedCache,
-        "has_value",
-        property(lambda self: False))
+# Disable caching by aliasing the "IS_CACHED" FSM state to "IS_INITIALIZED"
+gym_jiminy.common.bases.quantities._IS_CACHED = (
+    gym_jiminy.common.bases.quantities.QuantityStateMachine.IS_INITIALIZED)
 
 # Instantiate a dummy environment
 env = gym.make("gym_jiminy.envs:atlas")
@@ -40,7 +39,7 @@
         break
 
     # Extract batched data buffer of `FrameOrientation` quantities
-    shared_data = quantity.requirements['data']
+    shared_data = quantity.data
 
     # Benchmark computation of batched data buffer
     duration = timeit.timeit(
diff --git a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/__init__.py b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/__init__.py
index 9264558bd..4abecdb52 100644
--- a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/__init__.py
+++ b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/__init__.py
@@ -1,8 +1,8 @@
 # pylint: disable=missing-module-docstring
 
-from .locomotion import tanh_normalization, MaximizeStability
+from .locomotion import tanh_normalization, MaximizeRobustness
 
 __all__ = [
     "tanh_normalization",
-    "MaximizeStability"
+    "MaximizeRobustness"
 ]
diff --git a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/locomotion.py b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/locomotion.py
index 1d814f6cc..6e392cf1f 100644
--- a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/locomotion.py
+++ b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/compositions/locomotion.py
@@ -7,7 +7,7 @@
 from gym_jiminy.common.compositions import CUTOFF_ESP
 from gym_jiminy.common.bases import (
-    InterfaceJiminyEnv, QuantityEvalMode, BaseQuantityReward)
+    InterfaceJiminyEnv, QuantityEvalMode, QuantityReward)
 
 from ..quantities import StabilityMarginProjectedSupportPolygon
 
@@ -29,14 +29,14 @@ def tanh_normalization(value: float,
                   be bounded or unbounded, and signed or not, without
                   restrictions.
    :param cutoff: Cut-off threshold to consider.
-    :param order: Order of Lp-Norm that will be used as distance metric.
+    :param order: Order of the L^p-norm used as distance metric.
    """
    value_rel = (
        cutoff_high + cutoff_low - 2 * value) / (cutoff_high - cutoff_low)
    return 1.0 / (1.0 + math.pow(CUTOFF_ESP / (1.0 - CUTOFF_ESP), value_rel))
 
 
-class MaximizeStability(BaseQuantityReward):
+class MaximizeRobustness(QuantityReward):
     """Encourage the agent to maintain itself in postures as robust as
     possible to external disturbances.
 
@@ -66,21 +66,21 @@ class MaximizeStability(BaseQuantityReward):
     """
     def __init__(self,
                  env: InterfaceJiminyEnv,
-                 cutoff_inner: float,
-                 cutoff_outer: float) -> None:
+                 cutoff: float,
+                 cutoff_outer: float = 0.0) -> None:
         """
         :param env: Base or wrapped jiminy environment.
-        :param cutoff_inner: Cutoff threshold when the ZMP lies inside the
-                             support polygon. The reward will be larger than
-                             '1.0 - CUTOFF_ESP' if the distance from the border
-                             is larger than 'cutoff_inner'.
+        :param cutoff: Cutoff threshold when the ZMP lies inside the support
+                       polygon. The reward will be larger than
+                       '1.0 - CUTOFF_ESP' if the distance from the border is
+                       larger than 'cutoff'.
         :param cutoff_outer: Cutoff threshold when the ZMP lies outside the
                              support polygon. The reward will be smaller than
                              'CUTOFF_ESP' if the ZMP is further away from the
                              border of the support polygon than 'cutoff_outer'.
         """
         # Backup some user argument(s)
-        self.cutoff_inner = cutoff_inner
+        self.cutoff_inner = cutoff
         self.cutoff_outer = cutoff_outer
 
         # The cutoff thresholds must be positive
@@ -91,7 +91,7 @@ def __init__(self,
         # Call base implementation
         super().__init__(
             env,
-            "reward_momentum",
+            "reward_robustness",
             (StabilityMarginProjectedSupportPolygon, dict(
                 mode=QuantityEvalMode.TRUE
             )),
diff --git a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/math/qhull.py b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/math/qhull.py
index e5467b9ac..6572e4df6 100644
--- a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/math/qhull.py
+++ b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/math/qhull.py
@@ -354,7 +354,7 @@ def compute_convex_chebyshev_center(
     A = np.concatenate((
         equations[:, :-1], np.ones((len(equations), 1))), axis=1)
     b = - equations[:, -1:]
-    c = np.array([*(0.0,) * (num_dims - 1), -1.0])
+    c = np.array([*((0.0,) * (num_dims - 1)), -1.0])
     res = linprog(c, A_ub=A, b_ub=b, bounds=(None, None))
     return res.x[:-1], res.x[-1]
 
diff --git a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/quantities/locomotion.py b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/quantities/locomotion.py
index 838917c88..60e70f769 100644
--- a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/quantities/locomotion.py
+++ b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/quantities/locomotion.py
@@ -70,8 +70,7 @@ def __init__(self,
             env,
             parent,
             requirements=dict(
-                odom_pose=(BaseOdometryPose, dict(mode=mode))
-            ),
+                odom_pose=(BaseOdometryPose, dict(mode=mode))),
             mode=mode,
             auto_refresh=False)
 
@@ -137,7 +136,7 @@ def initialize(self) -> None:
         # because it preserves contiguity when copying frame data, and because
         # the `ConvexHull2D` would perform one extra copy otherwise.
         self._candidate_xy_batch = np.empty(
-            (len(self._candidate_xy_refs), 2), order="C")
+            (len(self._candidate_xy_refs), 2), order='C')
 
         # Refresh proxies
         self._candidate_xy_views = tuple(self._candidate_xy_batch)
@@ -149,7 +148,7 @@ def refresh(self) -> ConvexHull2D:
         # Translate candidate contact points from world to local odometry frame
         if self.reference_frame == pin.LOCAL:
             translate_position_odom(self._candidate_xy_batch,
-                                    self.odom_pose,
+                                    self.odom_pose.get(),
                                     self._candidate_xy_batch)
 
         # Compute the 2D convex hull in world plane
@@ -181,7 +180,7 @@ class StabilityMarginProjectedSupportPolygon(InterfaceQuantity[float]):
     """
 
     mode: QuantityEvalMode
-    """Specify on which state to evaluate this quantity. See `Mode`
+    """Specify on which state to evaluate this quantity. See `QuantityEvalMode`
     documentation for details about each mode.
 
     ..
warning:: @@ -217,4 +216,5 @@ def __init__(self, auto_refresh=False) def refresh(self) -> float: - return - self.support_polygon.get_distance_to_point(self.zmp).item() + support_polygon, zmp = self.support_polygon.get(), self.zmp.get() + return - support_polygon.get_distance_to_point(zmp).item() diff --git a/python/gym_jiminy/unit_py/data/anymal_pipeline.toml b/python/gym_jiminy/unit_py/data/anymal_pipeline.toml index 053d2beef..c1e667f57 100644 --- a/python/gym_jiminy/unit_py/data/anymal_pipeline.toml +++ b/python/gym_jiminy/unit_py/data/anymal_pipeline.toml @@ -2,12 +2,28 @@ cls = "gym_jiminy.envs.ANYmalJiminyEnv" [env_config.kwargs] step_dt = 0.04 -[env_config.reward] -cls = "gym_jiminy.common.compositions.AdditiveMixtureReward" + +# ======================= Reference trajectory database ======================= + [env_config.trajectories] mode = "raise" name = "reference" dataset.reference = "./anymal_trajectory.hdf5" + +# ======================= Ad-hoc termination conditions ======================= + +[[env_config.terminations]] +cls = "gym_jiminy.common.compositions.BaseRollPitchTermination" +[env_config.terminations.kwargs] +low = [-0.2, -0.05] +high = [-0.05, 0.3] +grace_period = 0.1 +is_training_only = false + +# ========================== Ad-hoc reward components ========================= + +[env_config.reward] +cls = "gym_jiminy.common.compositions.AdditiveMixtureReward" [env_config.reward.kwargs] name = "reward_total" weights = [0.6, 0.4] @@ -18,6 +34,8 @@ cutoff = 0.5 [[env_config.reward.kwargs.components]] cls = "gym_jiminy.common.compositions.SurviveReward" +# ========================= Observer-Controller blocks ======================== + [[layers_config]] block.cls = "gym_jiminy.common.blocks.PDController" [layers_config.block.kwargs] @@ -46,6 +64,8 @@ exact_init = false kp = 1.0 ki = 0.1 +# ========================= Policy interface wrappers ========================= + [[layers_config]] wrapper.cls = "gym_jiminy.common.wrappers.StackObservation" [layers_config.wrapper.kwargs] diff --git a/python/gym_jiminy/unit_py/data/cassie_standing_9.png b/python/gym_jiminy/unit_py/data/cassie_standing_9.png new file mode 100644 index 000000000..e0a3f0a20 Binary files /dev/null and b/python/gym_jiminy/unit_py/data/cassie_standing_9.png differ diff --git a/python/gym_jiminy/unit_py/test_pipeline_control.py b/python/gym_jiminy/unit_py/test_pipeline_control.py index 103d0131f..864733059 100644 --- a/python/gym_jiminy/unit_py/test_pipeline_control.py +++ b/python/gym_jiminy/unit_py/test_pipeline_control.py @@ -21,6 +21,9 @@ from gym_jiminy.common.blocks import PDController, PDAdapter, MahonyFilter from gym_jiminy.common.blocks.proportional_derivative_controller import ( integrate_zoh) +from gym_jiminy.common.wrappers import ( + FilterObservation, StackObservation, NormalizeObservation, + FlattenObservation) from gym_jiminy.common.utils import ( quat_to_rpy, matrix_to_rpy, matrix_to_quat, remove_twist_from_quat) @@ -298,6 +301,8 @@ def test_pd_controller(self): self.assertTrue(np.all(np.abs(target_vel) <= motor.velocity_limit)) def test_repeatability(self): + """ TODO: Write documentation. + """ # Instantiate the environment env = AtlasPDControlJiminyEnv() @@ -311,3 +316,30 @@ def test_repeatability(self): assert np.all(a_prev == env.robot_state.a) for _ in range(n_steps): env.step(env.action) + + def test_preserve_obs_key_order(self): + """ TODO: Write documentation. 
+ """ + env = AtlasPDControlJiminyEnv() + + env_stack = StackObservation( + env, skip_frames_ratio=-1, num_stack=2, nested_filter_keys=[["t"]]) + env_filter = FilterObservation( + env, nested_filter_keys=env.observation_space.keys()) + env_obs_norm = NormalizeObservation(env) + for env in (env, env_stack, env_filter, env_obs_norm): + env.reset(seed=0) + assert [*env.observation_space.keys()] == [*env.observation.keys()] + + env_flat = FlattenObservation(env) + env_flat.reset(seed=0) + all_values_flat = [] + obs_nodes = list(env.observation.values()) + while obs_nodes: + value = obs_nodes.pop() + if isinstance(value, dict): + obs_nodes += value.values() + else: + all_values_flat.append(value.flatten()) + obs_flat = np.concatenate(all_values_flat[::-1]) + np.testing.assert_allclose(env_flat.observation, obs_flat) diff --git a/python/gym_jiminy/unit_py/test_quantities.py b/python/gym_jiminy/unit_py/test_quantities.py index 95ef04bca..38beddcf7 100644 --- a/python/gym_jiminy/unit_py/test_quantities.py +++ b/python/gym_jiminy/unit_py/test_quantities.py @@ -1,5 +1,6 @@ """ TODO: Write documentation """ +import sys import math import unittest @@ -15,24 +16,29 @@ remove_yaw_from_quat) from gym_jiminy.common.bases import QuantityEvalMode, DatasetTrajectoryQuantity from gym_jiminy.common.quantities import ( + EnergyGenerationMode, OrientationType, QuantityManager, - FrameOrientation, - MultiFramesOrientation, - FrameXYZQuat, - MultiFramesMeanXYZQuat, + StackedQuantity, MaskedQuantity, + MultiFrameMeanXYZQuat, + MultiFrameOrientation, MultiFootMeanOdometryPose, MultiFootRelativeXYZQuat, - AverageFrameSpatialVelocity, - AverageBaseOdometryVelocity, + MultiFrameCollisionDetection, + MultiActuatedJointKinematic, + MultiContactNormalizedSpatialForce, + MultiFootNormalizedForceVertical, + FrameOrientation, + FrameXYZQuat, + FrameSpatialAverageVelocity, + BaseOdometryAverageVelocity, + BaseRelativeHeight, AverageBaseMomentum, - ActuatedJointsPosition, + AverageMechanicalPowerConsumption, CenterOfMass, CapturePoint, - ZeroMomentPoint, - MultiContactRelativeForceTangential, - MultiFootRelativeForceVertical) + ZeroMomentPoint) class Quantities(unittest.TestCase): @@ -57,24 +63,30 @@ def test_shared_cache(self): assert len(quantities["com"].cache.owners) == 2 zmp_0 = quantity_manager.zmp.copy() - assert quantities["com"].cache.has_value - assert not quantities["acom"].cache.has_value - assert not quantities["com"]._is_initialized - assert quantities["zmp"].requirements["com_position"]._is_initialized + assert quantities["com"].cache.sm_state == 2 + assert quantities["acom"].cache.sm_state == 0 + is_initialized_all = [ + owner._is_initialized for owner in quantities["com"].cache.owners] + assert len(is_initialized_all) == 2 + assert len(set(is_initialized_all)) == 2 env.step(env.action_space.sample()) zmp_1 = quantity_manager["zmp"].copy() assert np.all(zmp_0 == zmp_1) quantity_manager.clear() - assert quantities["zmp"].requirements["com_position"]._is_initialized - assert not quantities["com"].cache.has_value + is_initialized_all = [ + owner._is_initialized for owner in quantities["com"].cache.owners] + assert any(is_initialized_all) + assert quantities["com"].cache.sm_state == 1 zmp_1 = quantity_manager.zmp.copy() assert np.any(zmp_0 != zmp_1) env.step(env.action_space.sample()) quantity_manager.reset() - assert not quantities["zmp"].requirements[ - "com_position"]._is_initialized + assert quantities["com"].cache.sm_state == 0 + is_initialized_all = [ + owner._is_initialized for owner in 
quantities["com"].cache.owners] + assert not any(is_initialized_all) zmp_2 = quantity_manager.zmp.copy() assert np.any(zmp_1 != zmp_2) @@ -101,19 +113,19 @@ def test_dynamic_batching(self): ("rpy_2", FrameOrientation, dict( frame_name=frame_names[-1], type=OrientationType.EULER)), - ("rpy_batch_0", MultiFramesOrientation, dict( # Intersection + ("rpy_batch_0", MultiFrameOrientation, dict( # Intersection frame_names=(frame_names[-3], frame_names[1]), type=OrientationType.EULER)), - ("rpy_batch_1", MultiFramesOrientation, dict( # Inclusion + ("rpy_batch_1", MultiFrameOrientation, dict( # Inclusion frame_names=(frame_names[1], frame_names[-1]), type=OrientationType.EULER)), - ("rpy_batch_2", MultiFramesOrientation, dict( # Disjoint + ("rpy_batch_2", MultiFrameOrientation, dict( # Disjoint frame_names=(frame_names[1], frame_names[-4]), type=OrientationType.EULER)), - ("rot_mat_batch", MultiFramesOrientation, dict( + ("rot_mat_batch", MultiFrameOrientation, dict( frame_names=(frame_names[1], frame_names[-1]), type=OrientationType.MATRIX)), - ("quat_batch", MultiFramesOrientation, dict( + ("quat_batch", MultiFrameOrientation, dict( frame_names=(frame_names[1], frame_names[-4]), type=OrientationType.QUATERNION))): quantity_manager[name] = (cls, kwargs) @@ -121,43 +133,39 @@ def test_dynamic_batching(self): xyzquat_0 = quantity_manager.xyzquat_0.copy() rpy_0 = quantity_manager.rpy_0.copy() - assert len(quantities['rpy_0'].requirements['data'].frame_names) == 1 + assert len(quantities['rpy_0'].data.cache.owners[0].frame_names) == 1 assert np.all(rpy_0 == quantity_manager.rpy_1) rpy_2 = quantity_manager.rpy_2.copy() assert np.any(rpy_0 != rpy_2) - assert len(quantities['rpy_2'].requirements['data'].frame_names) == 2 + assert len(quantities['rpy_2'].data.cache.owners[0].frame_names) == 2 assert tuple(quantity_manager.rpy_batch_0.shape) == (3, 2) - assert len(quantities['rpy_batch_0'].requirements['data']. - frame_names) == 3 + assert len(quantities['rpy_batch_0'].data.cache.owners[0].frame_names) == 3 quantity_manager.rpy_batch_1 - assert len(quantities['rpy_batch_1'].requirements['data']. - frame_names) == 3 + assert len(quantities['rpy_batch_1'].data.cache.owners[0].frame_names) == 3 quantity_manager.rpy_batch_2 - assert len(quantities['rpy_batch_2'].requirements['data']. - frame_names) == 5 + assert len(quantities['rpy_batch_2'].data.cache.owners[0].frame_names) == 5 assert tuple(quantity_manager.rot_mat_batch.shape) == (3, 3, 2) assert tuple(quantity_manager.quat_batch.shape) == (4, 2) - assert len(quantities['quat_batch'].requirements['data']. 
- requirements['rot_mat_map'].frame_names) == 8 + assert len(quantities['quat_batch'].data.rot_mat_map.cache.owners[0].frame_names) == 8 env.step(env.action_space.sample()) quantity_manager.reset() rpy_0_next = quantity_manager.rpy_0 - xyzquat_0_next = quantity_manager.xyzquat_0.copy() + xyzquat_0_next = quantity_manager.xyzquat_0.copy() assert np.any(rpy_0 != rpy_0_next) assert np.any(xyzquat_0 != xyzquat_0_next) - assert len(quantities['rpy_2'].requirements['data'].frame_names) == 2 + assert len(quantities['rpy_2'].data.cache.owners[0].frame_names) == 5 - assert len(quantities['rpy_1'].requirements['data'].cache.owners) == 6 + assert len(quantities['rpy_1'].data.cache.owners) == 6 del quantity_manager['rpy_2'] - assert len(quantities['rpy_1'].requirements['data'].cache.owners) == 5 + assert len(quantities['rpy_1'].data.cache.owners) == 5 quantity_manager.rpy_1 - assert len(quantities['rpy_1'].requirements['data'].frame_names) == 1 + assert len(quantities['rpy_1'].data.cache.owners[0].frame_names) == 1 quantity_manager.reset(reset_tracking=True) assert np.all(rpy_0_next == quantity_manager.rpy_0) assert np.all(xyzquat_0_next == quantity_manager.xyzquat_0) - assert len(quantities['rpy_0'].requirements['data'].frame_names) == 1 + assert len(quantities['rpy_0'].data.cache.owners[0].frame_names) == 1 def test_discard(self): """ TODO: Write documentation @@ -177,17 +185,17 @@ def test_discard(self): quantities = quantity_manager.registry assert len(quantities['rpy_1'].cache.owners) == 2 - assert len(quantities['rpy_2'].requirements['data'].cache.owners) == 3 + assert len(quantities['rpy_2'].data.cache.owners) == 3 del quantity_manager['rpy_0'] assert len(quantities['rpy_1'].cache.owners) == 1 - assert len(quantities['rpy_2'].requirements['data'].cache.owners) == 2 + assert len(quantities['rpy_2'].data.cache.owners) == 2 del quantity_manager['rpy_1'] - assert len(quantities['rpy_2'].requirements['data'].cache.owners) == 1 + assert len(quantities['rpy_2'].data.cache.owners) == 1 del quantity_manager['rpy_2'] - for (cls, _), cache in quantity_manager._caches.items(): + for (cls, _), cache in quantity_manager._caches: assert len(cache.owners) == (cls is DatasetTrajectoryQuantity) def test_env(self): @@ -204,13 +212,13 @@ def test_env(self): env.reset(seed=0) assert np.all(zmp_0 == env.quantities["zmp"]) - def test_stack(self): + def test_stack_auto_refresh(self): """ TODO: Write documentation """ env = gym.make("gym_jiminy.envs:atlas") env.reset(seed=0) - quantity_cls = AverageFrameSpatialVelocity + quantity_cls = FrameSpatialAverageVelocity quantity_kwargs = dict( frame_name=env.robot.pinocchio_model.frames[1].name) env.quantities["v_avg"] = (quantity_cls, quantity_kwargs) @@ -227,6 +235,47 @@ def test_stack(self): env.step(env.action_space.sample()) assert np.all(v_avg != env.quantities["v_avg"]) + def test_stack_api(self): + """ TODO: Write documentation + """ + env = gym.make("gym_jiminy.envs:atlas") + + for max_stack, as_array, mode in ( + (None, False, "slice"), + (3, False, "slice"), + (3, True, "slice"), + (3, False, "zeros"), + (3, True, "zeros")): + quantity_creator = (StackedQuantity, dict( + quantity=(MultiFootRelativeXYZQuat, {}), + max_stack=max_stack or sys.maxsize, + as_array=as_array, + mode=mode)) + env.quantities["xyzquat_stack"] = quantity_creator + env.reset(seed=0) + + value = env.quantities["xyzquat_stack"] + if as_array: + assert isinstance(value, np.ndarray) + else: + assert isinstance(value, list) + for i in range(1, (max_stack or 5) + 2): + num_stack = max_stack or i + 
if mode == "slice": + num_stack = min(i, num_stack) + value = env.quantities["xyzquat_stack"] + if as_array: + assert value.shape[-1] == num_stack + if mode == "zeros": + np.testing.assert_allclose(value[..., :-i], 0.0) + else: + assert len(value) == num_stack + if mode == "zeros": + np.testing.assert_allclose(value[:-i], 0.0) + env.step(env.action) + + del env.quantities["xyzquat_stack"] + def test_masked(self): """ TODO: Write documentation """ @@ -240,8 +289,9 @@ def test_masked(self): keys=(0, 1, 5))) quantity = env.quantities.registry["v_masked"] assert not quantity._slices + value = quantity.quantity.get() np.testing.assert_allclose( - env.quantities["v_masked"], quantity.data[[0, 1, 5]]) + env.quantities["v_masked"], value[[0, 1, 5]]) del env.quantities["v_masked"] # 2. From boolean mask @@ -249,8 +299,9 @@ def test_masked(self): quantity=(FrameXYZQuat, dict(frame_name="root_joint")), keys=(True, True, False, False, False, True))) quantity = env.quantities.registry["v_masked"] + value = quantity.quantity.get() np.testing.assert_allclose( - env.quantities["v_masked"], quantity.data[[0, 1, 5]]) + env.quantities["v_masked"], value[[0, 1, 5]]) del env.quantities["v_masked"] # 3. From slice-able indices @@ -258,9 +309,11 @@ def test_masked(self): quantity=(FrameXYZQuat, dict(frame_name="root_joint")), keys=(0, 2, 4))) quantity = env.quantities.registry["v_masked"] - assert len(quantity._slices) == 1 and quantity._slices[0] == slice(0, 5, 2) + assert len(quantity._slices) == 1 and ( + quantity._slices[0] == slice(0, 5, 2)) + value = quantity.quantity.get() np.testing.assert_allclose( - env.quantities["v_masked"], quantity.data[[0, 2, 4]]) + env.quantities["v_masked"], value[[0, 2, 4]]) def test_true_vs_reference(self): env = gym.make("gym_jiminy.envs:atlas", debug=False) @@ -278,17 +331,23 @@ def test_true_vs_reference(self): lambda mode: (FrameXYZQuat, dict( frame_name=frame_names[2], mode=mode)), - lambda mode: (MultiFramesMeanXYZQuat, dict( + lambda mode: (MultiFrameMeanXYZQuat, dict( frame_names=tuple(frame_names[i] for i in (1, 3, -2)), mode=mode)), lambda mode: (MultiFootMeanOdometryPose, dict( mode=mode)), - lambda mode: (AverageFrameSpatialVelocity, dict( + lambda mode: (FrameSpatialAverageVelocity, dict( frame_name=frame_names[1], mode=mode)), - lambda mode: (AverageBaseOdometryVelocity, dict( + lambda mode: (BaseOdometryAverageVelocity, dict( + mode=mode)), + lambda mode: (MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.POSITION, + is_motor_side=False, mode=mode)), - lambda mode: (ActuatedJointsPosition, dict( + lambda mode: (MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.VELOCITY, + is_motor_side=True, mode=mode)), lambda mode: (CenterOfMass, dict( kinematic_level=pin.KinematicLevel.ACCELERATION, @@ -337,7 +396,7 @@ def test_average_odometry_velocity(self): env = gym.make("gym_jiminy.envs:atlas") env.quantities["odometry_velocity"] = ( - AverageBaseOdometryVelocity, dict( + BaseOdometryAverageVelocity, dict( mode=QuantityEvalMode.TRUE)) quantity = env.quantities.registry["odometry_velocity"] @@ -356,7 +415,7 @@ def test_average_odometry_velocity(self): rot_mat @ base_velocity_mean_local[3:])) np.testing.assert_allclose( - quantity.requirements['data'].data, base_velocity_mean_world) + quantity.data.quantity.get(), base_velocity_mean_world) base_odom_velocity = base_velocity_mean_world[[0, 1, 5]] np.testing.assert_allclose( env.quantities["odometry_velocity"], base_odom_velocity) @@ -383,25 +442,51 @@ def 
test_average_momentum(self): np.testing.assert_allclose( env.quantities["base_momentum"], angular_momentum) - def test_motor_positions(self): + def test_actuated_joints_kinematic(self): """ TODO: Write documentation """ - env = gym.make("gym_jiminy.envs:atlas") - - env.quantities["actuated_joint_positions"] = ( - ActuatedJointsPosition, dict(mode=QuantityEvalMode.TRUE)) + env = gym.make("gym_jiminy.envs:cassie") + + for level in ( + pin.KinematicLevel.POSITION, + pin.KinematicLevel.VELOCITY, + pin.KinematicLevel.ACCELERATION): + env.quantities[f"joint_{level}"] = ( + MultiActuatedJointKinematic, dict( + kinematic_level=level, + is_motor_side=False, + mode=QuantityEvalMode.TRUE)) + if level < 2: + env.quantities[f"motor_{level}"] = ( + MultiActuatedJointKinematic, dict( + kinematic_level=level, + is_motor_side=True, + mode=QuantityEvalMode.TRUE)) - env.reset(seed=0) - env.step(env.action_space.sample()) - - position_indices = [] - for motor in env.robot.motors: - joint = env.robot.pinocchio_model.joints[motor.joint_index] - position_indices += range(joint.idx_q, joint.idx_q + joint.nq) - - np.testing.assert_allclose( - env.quantities["actuated_joint_positions"], - env.robot_state.q[position_indices]) + env.reset(seed=0) + env.step(env.action_space.sample()) + + kinematic_indices = [] + for motor in env.robot.motors: + joint = env.robot.pinocchio_model.joints[motor.joint_index] + if level == pin.KinematicLevel.POSITION: + kin_first, kin_last = joint.idx_q, joint.idx_q + joint.nq + else: + kin_first, kin_last = joint.idx_v, joint.idx_v + joint.nv + kinematic_indices += range(kin_first, kin_last) + if level == pin.KinematicLevel.POSITION: + joint_value = env.robot_state.q[kinematic_indices] + elif level == pin.KinematicLevel.VELOCITY: + joint_value = env.robot_state.v[kinematic_indices] + else: + joint_value = env.robot_state.a[kinematic_indices] + encoder_data = env.robot.sensor_measurements["EncoderSensor"] + + np.testing.assert_allclose( + env.quantities[f"joint_{level}"], joint_value) + if level < 2: + np.testing.assert_allclose( + env.quantities[f"motor_{level}"], encoder_data[level]) def test_capture_point(self): """ TODO: Write documentation @@ -432,9 +517,9 @@ def test_capture_point(self): env.step(env.action_space.sample()) com_position = env.robot.pinocchio_data.com[0] - np.testing.assert_allclose(quantity.com_position, com_position) + np.testing.assert_allclose(quantity.com_position.get(), com_position) com_velocity = env.robot.pinocchio_data.vcom[0] - np.testing.assert_allclose(quantity.com_velocity, com_velocity) + np.testing.assert_allclose(quantity.com_velocity.get(), com_velocity) np.testing.assert_allclose( env.quantities["dcm"], com_position[:2] + com_velocity[:2] / omega) @@ -448,7 +533,7 @@ def test_mean_pose(self): frame.name for frame in env.robot.pinocchio_model.frames] env.quantities["mean_pose"] = ( - MultiFramesMeanXYZQuat, dict( + MultiFrameMeanXYZQuat, dict( frame_names=frame_names[:5], mode=QuantityEvalMode.TRUE)) @@ -509,18 +594,18 @@ def test_foot_relative_pose(self): for frame_name in ("l_foot", "r_foot"): frame_index = env.robot.pinocchio_model.getFrameId(frame_name) foot_poses.append(env.robot.pinocchio_data.oMf[frame_index]) - pos_feet = np.stack(tuple( - foot_pose.translation for foot_pose in foot_poses), axis=-1) - quat_feet = np.stack(tuple( + pos_feet = np.stack([ + foot_pose.translation for foot_pose in foot_poses], axis=-1) + quat_feet = np.stack([ matrix_to_quat(foot_pose.rotation) - for foot_pose in foot_poses), axis=-1) + for foot_pose in 
foot_poses], axis=-1) pos_mean = np.mean(pos_feet, axis=-1, keepdims=True) rot_mean = quat_to_matrix(quat_average(quat_feet)) pos_rel = rot_mean.T @ (pos_feet - pos_mean) - quat_rel = np.stack(tuple( + quat_rel = np.stack([ matrix_to_quat(rot_mean.T @ foot_pose.rotation) - for foot_pose in foot_poses), axis=-1) + for foot_pose in foot_poses], axis=-1) quat_rel[-4:] *= np.sign(quat_rel[-1]) value = env.quantities["foot_rel_poses"].copy() @@ -529,13 +614,13 @@ def test_foot_relative_pose(self): np.testing.assert_allclose(value[:3], pos_rel[:, :-1]) np.testing.assert_allclose(value[-4:], quat_rel[:, :-1]) - def test_tangential_forces(self): + def test_contact_spatial_forces(self): """ TODO: Write documentation """ env = gym.make("gym_jiminy.envs:atlas") - env.quantities["force_tangential_rel"] = ( - MultiContactRelativeForceTangential, {}) + env.quantities["force_spatial_rel"] = ( + MultiContactNormalizedSpatialForce, {}) env.reset(seed=0) for _ in range(10): @@ -543,20 +628,20 @@ def test_tangential_forces(self): gravity = abs(env.robot.pinocchio_model.gravity.linear[2]) robot_weight = env.robot.pinocchio_data.mass[0] * gravity - force_tangential_rel = np.stack(tuple( - constraint.lambda_c[:2] - for constraint in env.robot.constraints.contact_frames.values()), + force_spatial_rel = np.stack([np.concatenate( + (constraint.lambda_c[:3], np.zeros((2,)), constraint.lambda_c[[3]]) + ) for constraint in env.robot.constraints.contact_frames.values()], axis=-1) / robot_weight np.testing.assert_allclose( - force_tangential_rel, env.quantities["force_tangential_rel"]) + force_spatial_rel, env.quantities["force_spatial_rel"]) - def test_vertical_forces(self): + def test_foot_vertical_forces(self): """ TODO: Write documentation """ env = gym.make("gym_jiminy.envs:atlas-pid") env.quantities["force_vertical_rel"] = ( - MultiFootRelativeForceVertical, {}) + MultiFootNormalizedForceVertical, {}) env.reset(seed=0) for _ in range(10): @@ -575,3 +660,88 @@ def test_vertical_forces(self): np.testing.assert_allclose( force_vertical_rel, env.quantities["force_vertical_rel"]) np.testing.assert_allclose(np.sum(force_vertical_rel), 1.0, atol=1e-3) + + def test_base_height(self): + env = gym.make("gym_jiminy.envs:atlas-pid") + + env.quantities["base_height"] = (BaseRelativeHeight, {}) + + env.reset(seed=0) + action = env.action_space.sample() + for _ in range(10): + env.step(action) + + value = env.quantities["base_height"] + base_z = env.robot.pinocchio_data.oMf[1].translation[[2]] + contacts_z = [] + for constraint in env.robot.constraints.contact_frames.values(): + frame_index = constraint.frame_index + frame_pos = env.robot.pinocchio_data.oMf[frame_index] + contacts_z.append(frame_pos.translation[[2]]) + np.testing.assert_allclose(base_z - np.min(contacts_z), value) + + def test_frames_collision(self): + env = gym.make("gym_jiminy.envs:atlas-pid", step_dt=0.01) + + env.quantities["frames_collision"] = ( + MultiFrameCollisionDetection, dict( + frame_names=("l_foot", "r_foot"), + security_margin=0.0)) + + motor_names = [motor.name for motor in env.robot.motors] + left_motor_index = motor_names.index('l_leg_hpx') + right_motor_index = motor_names.index('r_leg_hpx') + action = np.zeros((len(motor_names),)) + action[[left_motor_index, right_motor_index]] = -0.5, 0.5 + + env.robot.remove_contact_points([]) + env.eval() + env.reset(seed=0) + assert not env.quantities["frames_collision"] + for _ in range(20): + env.step(action) + if env.quantities["frames_collision"]: + break + else: + raise AssertionError("No collision 
detected.") + + def test_power_consumption(self): + env = gym.make("gym_jiminy.envs:cassie") + + for mode in ( + EnergyGenerationMode.CHARGE, + EnergyGenerationMode.LOST_EACH, + EnergyGenerationMode.LOST_GLOBAL, + EnergyGenerationMode.PENALIZE): + env.quantities["mean_power_consumption"] = ( + AverageMechanicalPowerConsumption, dict( + horizon=0.2, + generator_mode=mode)) + quantity = env.quantities.registry["mean_power_consumption"] + env.reset(seed=0) + + total_power_stack = [0.0,] + encoder_data = env.robot.sensor_measurements["EncoderSensor"] + _, motor_velocities = encoder_data + for _ in range(8): + motor_efforts = 0.1 * env.action_space.sample() + env.step(motor_efforts) + + motor_powers = motor_efforts * motor_velocities + if mode == EnergyGenerationMode.CHARGE: + total_power = np.sum(motor_powers) + elif mode == EnergyGenerationMode.LOST_EACH: + total_power = np.sum(np.maximum(motor_powers, 0.0)) + elif mode == EnergyGenerationMode.LOST_GLOBAL: + total_power = max(np.sum(motor_powers), 0.0) + else: + total_power = np.sum(np.abs(motor_powers)) + total_power_stack.append(total_power) + mean_total_power = np.mean( + total_power_stack[-quantity.max_stack:]) + + value = quantity.total_power_stack.get() + np.testing.assert_allclose(total_power, value[-1]) + np.testing.assert_allclose(mean_total_power, quantity.get()) + + del env.quantities["mean_power_consumption"] diff --git a/python/gym_jiminy/unit_py/test_rewards.py b/python/gym_jiminy/unit_py/test_rewards.py index 03e07bc8b..b3e043a87 100644 --- a/python/gym_jiminy/unit_py/test_rewards.py +++ b/python/gym_jiminy/unit_py/test_rewards.py @@ -22,7 +22,7 @@ AdditiveMixtureReward) from gym_jiminy.toolbox.compositions import ( tanh_normalization, - MaximizeStability) + MaximizeRobusntess) class Rewards(unittest.TestCase): @@ -44,6 +44,8 @@ def setUp(self): self.env.quantities.select_trajectory("reference") def test_deletion(self): + """ TODO: Write documentation + """ assert len(self.env.quantities.registry) == 0 reward_survive = TrackingActuatedJointPositionsReward( self.env, cutoff=1.0) @@ -53,6 +55,8 @@ def test_deletion(self): assert len(self.env.quantities.registry) == 0 def test_tracking(self): + """ TODO: Write documentation + """ for reward_class, cutoff in ( (TrackingBaseOdometryVelocityReward, 20.0), (TrackingActuatedJointPositionsReward, 10.0), @@ -62,8 +66,7 @@ def test_tracking(self): (TrackingFootOrientationsReward, 2.0), (TrackingFootForceDistributionReward, 2.0)): reward = reward_class(self.env, cutoff=cutoff) - quantity_true = reward.quantity.requirements['value_left'] - quantity_ref = reward.quantity.requirements['value_right'] + quantity = reward.data self.env.reset(seed=0) action = 0.5 * self.env.action_space.sample() @@ -71,17 +74,18 @@ def test_tracking(self): self.env.step(action) _, _, terminated, _, _ = self.env.step(self.env.action) + value_left = quantity.quantity_left.get() + value_right = quantity.quantity_right.get() with np.testing.assert_raises(AssertionError): - np.testing.assert_allclose( - quantity_true.get(), quantity_ref.get()) + np.testing.assert_allclose(value_left, value_right) if isinstance(reward, TrackingBaseHeightReward): np.testing.assert_allclose( - quantity_true.get(), self.env.robot_state.q[2]) + value_left, self.env.robot_state.q[2]) gamma = - np.log(CUTOFF_ESP) / cutoff ** 2 - value = np.exp(- gamma * np.sum((reward.quantity.op( - quantity_true.get(), quantity_ref.get())) ** 2)) + value = np.exp(- gamma * np.sum( + (quantity.op(value_left, value_right)) ** 2)) assert value > 0.01 
             np.testing.assert_allclose(reward(terminated, {}), value)
@@ -113,16 +117,19 @@ def test_mixture(self):
                 0.2 * reward_survive(terminated, {}))
 
     def test_stability(self):
+        """ TODO: Write documentation
+        """
         CUTOFF_INNER, CUTOFF_OUTER = 0.1, 0.5
-        reward_stability = MaximizeStability(
-            self.env, cutoff_inner=0.1, cutoff_outer=0.5)
-        quantity = reward_stability.quantity
+        reward_stability = MaximizeRobustness(
+            self.env, cutoff=CUTOFF_INNER, cutoff_outer=CUTOFF_OUTER)
+        quantity = reward_stability.data
 
         self.env.reset(seed=0)
         action = self.env.action_space.sample()
         _, _, terminated, _, _ = self.env.step(action)
 
-        dist = quantity.support_polygon.get_distance_to_point(quantity.zmp)
+        support_polygon = quantity.support_polygon.get()
+        dist = support_polygon.get_distance_to_point(quantity.zmp.get())
         value = tanh_normalization(dist.item(), -CUTOFF_INNER, CUTOFF_OUTER)
         np.testing.assert_allclose(tanh_normalization(
             -CUTOFF_INNER, -CUTOFF_INNER, CUTOFF_OUTER), 1.0 - CUTOFF_ESP)
@@ -131,9 +138,11 @@ def test_stability(self):
         np.testing.assert_allclose(reward_stability(terminated, {}), value)
 
     def test_friction(self):
+        """ TODO: Write documentation
+        """
         CUTOFF = 0.5
         reward_friction = MinimizeFrictionReward(self.env, cutoff=CUTOFF)
-        quantity = reward_friction.quantity
+        quantity = reward_friction.data
 
         self.env.reset(seed=0)
         _, _, terminated, _, _ = self.env.step(self.env.action)
diff --git a/python/gym_jiminy/unit_py/test_terminations.py b/python/gym_jiminy/unit_py/test_terminations.py
new file mode 100644
index 000000000..2a075faf9
--- /dev/null
+++ b/python/gym_jiminy/unit_py/test_terminations.py
@@ -0,0 +1,379 @@
+""" TODO: Write documentation
+"""
+from operator import sub
+import unittest
+
+import numpy as np
+
+import gymnasium as gym
+from jiminy_py.log import extract_trajectory_from_log
+
+from gym_jiminy.common.utils import (
+    quat_difference, matrix_to_quat, matrix_to_rpy)
+from gym_jiminy.common.bases import EpisodeState, ComposedJiminyEnv
+from gym_jiminy.common.quantities import (
+    OrientationType, FramePosition, FrameOrientation)
+from gym_jiminy.common.compositions import (
+    DriftTrackingQuantityTermination,
+    ShiftTrackingQuantityTermination,
+    BaseRollPitchTermination,
+    FallingTermination,
+    FootCollisionTermination,
+    MechanicalSafetyTermination,
+    FlyingTermination,
+    ImpactForceTermination,
+    MechanicalPowerConsumptionTermination,
+    DriftTrackingBaseOdometryPositionTermination,
+    DriftTrackingBaseOdometryOrientationTermination,
+    ShiftTrackingMotorPositionsTermination,
+    ShiftTrackingFootOdometryPositionsTermination,
+    ShiftTrackingFootOdometryOrientationsTermination)
+
+
+class TerminationConditions(unittest.TestCase):
+    """ TODO: Write documentation
+    """
+    def setUp(self):
+        self.env = gym.make("gym_jiminy.envs:atlas-pid", debug=False)
+
+        self.env.eval()
+        self.env.reset(seed=1)
+        action = 0.5 * self.env.action_space.sample()
+        for _ in range(25):
+            self.env.step(action)
+        self.env.stop()
+        trajectory = extract_trajectory_from_log(self.env.log_data)
+        self.env.train()
+
+        self.env.quantities.add_trajectory("reference", trajectory)
+        self.env.quantities.select_trajectory("reference")
+
+    def test_composition(self):
+        """ TODO: Write documentation
+        """
+        ROLL_MIN, ROLL_MAX = -0.2, 0.2
+        PITCH_MIN, PITCH_MAX = -0.05, 0.3
+        termination = BaseRollPitchTermination(
+            self.env,
+            np.array([ROLL_MIN, PITCH_MIN]),
+            np.array([ROLL_MAX, PITCH_MAX]))
+        self.env.reset(seed=0)
+        env = ComposedJiminyEnv(self.env, terminations=(termination,))
+
+        env.reset(seed=0)
+        action = 
self.env.action_space.sample() + for _ in range(20): + _, _, terminated_env, _, _ = env.step(action) + terminated_cond, _ = termination({}) + assert not (terminated_env ^ terminated_cond) + if terminated_env: + terminated_unwrapped, _ = env.unwrapped.has_terminated({}) + assert not terminated_unwrapped + break + + def test_drift_tracking(self): + """ TODO: Write documentation + """ + termination_pos_config = ("pos", (FramePosition, {}), -0.2, 0.3, sub) + termination_rot_config = ( + "rot", + (FrameOrientation, dict(type=OrientationType.QUATERNION)), + np.array([-0.5, -0.6, -0.7]), + np.array([0.7, 0.5, 0.6]), + quat_difference) + + for i, (is_truncation, is_training_only) in enumerate(( + (False, False), (True, False), (False, True))): + termination_pos, termination_rot = ( + DriftTrackingQuantityTermination( + self.env, + f"drift_tracking_{name}_{i}", + lambda mode: (quantity_cls, dict( + **quantity_kwargs, + frame_name="root_joint", + mode=mode)), + low=low, + high=high, + horizon=0.3, + grace_period=0.2, + op=op, + is_truncation=is_truncation, + is_training_only=is_training_only + ) for name, (quantity_cls, quantity_kwargs), low, high, op in ( + termination_pos_config, termination_rot_config)) + + self.env.reset(seed=0) + self.env.eval() + action = self.env.action_space.sample() + oMf = self.env.robot.pinocchio_data.oMf[1] + position, rotation = oMf.translation, oMf.rotation + + positions, rotations = [], [] + for _ in range(25): + info_pos, info_rot = {}, {} + flags_pos = termination_pos(info_pos) + flags_rot = termination_rot(info_rot) + + positions.append(position.copy()) + rotations.append(matrix_to_quat(rotation)) + + for termination, (terminated, truncated), values, info in ( + (termination_pos, flags_pos, positions, info_pos), + (termination_rot, flags_rot, rotations, info_rot)): + values = values[-termination.max_stack:] + drift = termination.op(values[-1], values[0]) + value = termination.data.quantity_left.get() + np.testing.assert_allclose(drift, value) + + time = self.env.stepper_state.t + is_active = ( + time >= termination.grace_period and + not termination.is_training_only) + assert info == { + termination.name: EpisodeState.TERMINATED + if terminated else EpisodeState.TRUNCATED + if truncated else EpisodeState.CONTINUED} + if terminated or truncated: + assert is_active + assert terminated ^ termination.is_truncation + elif is_active: + value = termination.data.get() + assert np.all(value >= termination.low) + assert np.all(value <= termination.high) + _, _, terminated, truncated, _ = self.env.step(action) + if terminated or truncated: + break + + def test_shift_tracking(self): + """ TODO: Write documentation + """ + termination_pos_config = ("pos", (FramePosition, {}), 0.1, sub) + termination_rot_config = ( + "rot", + (FrameOrientation, dict(type=OrientationType.QUATERNION)), + 0.3, + quat_difference) + + for i, (is_truncation, is_training_only) in enumerate(( + (False, False), (True, False), (False, True))): + termination_pos, termination_rot = ( + ShiftTrackingQuantityTermination( + self.env, + f"shift_tracking_{name}_{i}", + lambda mode: (quantity_cls, dict( + **quantity_kwargs, + frame_name="root_joint", + mode=mode)), + thr=thr, + horizon=0.3, + grace_period=0.2, + op=op, + is_truncation=is_truncation, + is_training_only=is_training_only + ) for name, (quantity_cls, quantity_kwargs), thr, op in ( + termination_pos_config, termination_rot_config)) + + self.env.reset(seed=0) + self.env.eval() + action = self.env.action_space.sample() + oMf = 
self.env.robot.pinocchio_data.oMf[1]
+            position, rotation = oMf.translation, oMf.rotation
+
+            positions, rotations = [], []
+            for _ in range(25):
+                info_pos, info_rot = {}, {}
+                flags_pos = termination_pos(info_pos)
+                flags_rot = termination_rot(info_rot)
+
+                positions.append(position.copy())
+                rotations.append(matrix_to_quat(rotation))
+
+                for termination, (terminated, truncated), values, info in (
+                        (termination_pos, flags_pos, positions, info_pos),
+                        (termination_rot, flags_rot, rotations, info_rot)):
+                    values = values[-termination.max_stack:]
+                    stack = np.stack(values, axis=-1)
+                    left = termination.data.quantity_left.get()
+                    np.testing.assert_allclose(stack, left)
+                    right = termination.data.quantity_right.get()
+                    diff = termination.op(left, right)
+                    shift = np.min(np.linalg.norm(
+                        diff.reshape((-1, len(values))), axis=0))
+                    value = termination.data.get()
+                    np.testing.assert_allclose(shift, value)
+
+                    time = self.env.stepper_state.t
+                    is_active = (
+                        time >= termination.grace_period and
+                        not termination.is_training_only)
+                    assert info == {
+                        termination.name: EpisodeState.TERMINATED
+                        if terminated else EpisodeState.TRUNCATED
+                        if truncated else EpisodeState.CONTINUED}
+                    if terminated or truncated:
+                        assert is_active
+                        assert terminated ^ termination.is_truncation
+                    elif is_active:
+                        assert np.all(value <= termination.high)
+                _, _, terminated, truncated, _ = self.env.step(action)
+                if terminated or truncated:
+                    break
+
+    def test_base_roll_pitch(self):
+        """ TODO: Write documentation
+        """
+        ROLL_MIN, ROLL_MAX = -0.2, 0.2
+        PITCH_MIN, PITCH_MAX = -0.05, 0.3
+        roll_pitch_termination = BaseRollPitchTermination(
+            self.env,
+            np.array([ROLL_MIN, PITCH_MIN]),
+            np.array([ROLL_MAX, PITCH_MAX]))
+
+        self.env.reset(seed=0)
+        rotation = self.env.robot.pinocchio_data.oMf[1].rotation
+        action = self.env.action_space.sample()
+        for _ in range(20):
+            _, _, terminated, _, _ = self.env.step(action)
+            if terminated:
+                break
+            terminated, truncated = roll_pitch_termination({})
+            assert not truncated
+            roll, pitch, _ = matrix_to_rpy(rotation)
+            is_valid = (
+                ROLL_MIN < roll < ROLL_MAX and PITCH_MIN < pitch < PITCH_MAX)
+            assert terminated ^ is_valid
+
+    def test_foot_collision(self):
+        """ TODO: Write documentation
+        """
+        termination = FootCollisionTermination(self.env, security_margin=0.0)
+
+        motor_names = [motor.name for motor in self.env.robot.motors]
+        left_motor_index = motor_names.index('l_leg_hpx')
+        right_motor_index = motor_names.index('r_leg_hpx')
+        action = np.zeros((len(motor_names),))
+        action[[left_motor_index, right_motor_index]] = -0.5, 0.5
+
+        self.env.robot.remove_contact_points([])
+        self.env.eval()
+        self.env.reset(seed=0)
+        for _ in range(10):
+            self.env.step(action)
+            terminated, truncated = termination({})
+            assert not truncated
+            if terminated:
+                break
+        else:
+            raise AssertionError("No collision detected.")
+
+    def test_safety_limits(self):
+        """ TODO: Write documentation
+        """
+        POSITION_MARGIN, VELOCITY_MAX = 0.05, 1.0
+        termination = MechanicalSafetyTermination(
+            self.env, POSITION_MARGIN, VELOCITY_MAX)
+
+        self.env.reset(seed=0)
+
+        position_indices, velocity_indices = [], []
+        pinocchio_model = self.env.robot.pinocchio_model
+        for motor in self.env.robot.motors:
+            joint = pinocchio_model.joints[motor.joint_index]
+            position_indices.append(joint.idx_q)
+            velocity_indices.append(joint.idx_v)
+        position_lower = pinocchio_model.lowerPositionLimit[position_indices]
+        position_lower += POSITION_MARGIN
+        position_upper = pinocchio_model.upperPositionLimit[position_indices]
+        position_upper -= POSITION_MARGIN
+
+        action = self.env.action_space.sample()
+        for _ in range(20):
+            _, _, terminated, _, _ = self.env.step(action)
+            if terminated:
+                break
+            terminated, truncated = termination({})
+            position = self.env.robot_state.q[position_indices]
+            velocity = self.env.robot_state.v[velocity_indices]
+            is_valid = np.all(
+                (position_lower <= position) | (velocity >= - VELOCITY_MAX))
+            is_valid = is_valid and np.all(
+                (position_upper >= position) | (velocity <= VELOCITY_MAX))
+            assert terminated ^ is_valid
+
+    def test_flying(self):
+        """ TODO: Write documentation
+        """
+        MAX_HEIGHT = 0.02
+        termination = FlyingTermination(self.env, max_height=MAX_HEIGHT)
+
+        self.env.reset(seed=0)
+
+        engine_options = self.env.unwrapped.engine.get_options()
+        heightmap = engine_options["world"]["groundProfile"]
+
+        action = self.env.action_space.sample()
+        for _ in range(20):
+            _, _, terminated, _, _ = self.env.step(action)
+            if terminated:
+                break
+            terminated, truncated = termination({})
+            is_valid = False
+            for frame_index in self.env.robot.contact_frame_indices:
+                transform = self.env.robot.pinocchio_data.oMf[frame_index]
+                position = transform.translation
+                height, normal = heightmap(position[:2])
+                depth = (position[2] - height) * normal[2]
+                if depth <= MAX_HEIGHT:
+                    is_valid = True
+                    break
+            assert terminated ^ is_valid
+
+    def test_drift_tracking_base_odom(self):
+        """ TODO: Write documentation
+        """
+        MAX_POS_ERROR, MAX_ROT_ERROR = 0.1, 0.2
+        termination_pos = DriftTrackingBaseOdometryPositionTermination(
+            self.env, MAX_POS_ERROR, 1.0)
+        quantity_pos = termination_pos.data
+        termination_rot = DriftTrackingBaseOdometryOrientationTermination(
+            self.env, MAX_ROT_ERROR, 1.0)
+        quantity_rot = termination_rot.data
+
+        self.env.reset(seed=0)
+        action = self.env.action_space.sample()
+        for _ in range(20):
+            _, _, terminated, _, _ = self.env.step(action)
+            if terminated:
+                break
+            terminated, truncated = termination_pos({})
+            value_left = quantity_pos.quantity_left.get()
+            value_right = quantity_pos.quantity_right.get()
+            diff = value_left - value_right
+            is_valid = np.linalg.norm(diff) <= MAX_POS_ERROR
+            assert terminated ^ is_valid
+            value_left = quantity_rot.quantity_left.get()
+            value_right = quantity_rot.quantity_right.get()
+            diff = value_left - value_right
+            terminated, truncated = termination_rot({})
+            is_valid = np.abs(diff) <= MAX_ROT_ERROR
+            assert terminated ^ is_valid
+
+    def test_misc(self):
+        """ TODO: Write documentation
+        """
+        for termination in (
+                FallingTermination(self.env, 0.6),
+                ImpactForceTermination(self.env, 1.0),
+                MechanicalPowerConsumptionTermination(self.env, 400.0, 1.0),
+                ShiftTrackingMotorPositionsTermination(self.env, 0.4, 0.5),
+                ShiftTrackingFootOdometryPositionsTermination(
+                    self.env, 0.2, 0.5),
+                ShiftTrackingFootOdometryOrientationsTermination(
+                    self.env, 0.1, 0.5)):
+            self.env.reset(seed=0)
+            self.env.eval()
+            action = self.env.action_space.sample()
+            for _ in range(20):
+                _, _, terminated, _, _ = self.env.step(action)
+                terminated, truncated = termination({})
+                assert not truncated
diff --git a/python/jiminy_py/examples/collision_detection.py b/python/jiminy_py/examples/collision_detection.py
index 423fb44d6..50e6061b7 100644
--- a/python/jiminy_py/examples/collision_detection.py
+++ b/python/jiminy_py/examples/collision_detection.py
@@ -25,9 +25,9 @@ def __init__(self,
             geom_model.getGeometryId, (geom_name_1, geom_name_2))
         self.oMg1, self.oMg2 = (
             geom_data.oMg[i] for i in 
(geom_index_1, geom_index_2)) - self.collide_functor = hppfcl.ComputeCollision(*( + self.collide_functor = hppfcl.ComputeCollision(*[ geom_model.geometryObjects[i].geometry - for i in (geom_index_1, geom_index_2))) + for i in (geom_index_1, geom_index_2)]) self.req = hppfcl.CollisionRequest() self.req.enable_cached_gjk_guess = True self.req.distance_upper_bound = 1e-6 diff --git a/python/jiminy_py/examples/extra_cameras.py b/python/jiminy_py/examples/extra_cameras.py new file mode 100644 index 000000000..4c246d82a --- /dev/null +++ b/python/jiminy_py/examples/extra_cameras.py @@ -0,0 +1,74 @@ +import numpy as np +import matplotlib.pyplot as plt + +import gymnasium as gym + +from panda3d.core import VBase4, Point3, Vec3 +from jiminy_py.viewer import Viewer +import pinocchio as pin + +Viewer.close() +#Viewer.connect_backend("panda3d-sync") +env = gym.make("gym_jiminy.envs:atlas-pid", viewer_kwargs={"backend": "panda3d-sync"}) +env.reset(seed=0) +env.step(env.action) +#env.render() +env.simulator.render(return_rgb_array=True) + +env.viewer.add_marker("sphere", + shape="sphere", + pose=(np.array((1.7, 0.0, 1.5)), None), + color="red", + radius=0.1, + always_foreground=False) + +Viewer.add_camera("rgb", height=200, width=200, is_depthmap=False) +Viewer.add_camera("depth", height=128, width=128, is_depthmap=True) +Viewer.set_camera_transform( + position=[2.5, -1.4, 1.6], # [3.0, 0.0, 0.0], + rotation=[1.35, 0.0, 0.8], # [np.pi/2, 0.0, np.pi/2] + camera_name="depth") + +frame_index = env.robot.pinocchio_model.getFrameId("head") +frame_pose = env.robot.pinocchio_data.oMf[frame_index] +# Viewer._backend_obj.gui.set_camera_transform( +# pos=frame_pose.translation + np.array([0.0, 0.0, 0.0]), +# quat=pin.Quaternion(frame_pose.rotation @ pin.rpy.rpyToMatrix(0.0, 0.0, -np.pi/2)).coeffs(), +# camera_name="rgb") +Viewer.set_camera_transform( + position=frame_pose.translation + np.array([0.0, 0.0, 0.0]), + rotation=pin.rpy.matrixToRpy(frame_pose.rotation @ pin.rpy.rpyToMatrix(np.pi/2, 0.0, -np.pi/2)), + camera_name="rgb") + +lens = Viewer._backend_obj.render.find("user_camera_depth").node().get_lens() +# proj = lens.get_projection_mat_inv() +# buffer = Viewer._backend_obj._user_buffers["depth"] +# buffer.trigger_copy() +# Viewer._backend_obj.graphics_engine.render_frame() +# texture = buffer.get_texture() +# tex_peeker = texture.peek() +# pixel = VBase4() +# tex_peeker.lookup(pixel, 0.5, 0.5) # (y, x normalized coordinates, from top-left to bottom-right) +# depth_rel = 2.0 * pixel[0] - 1.0 # map range [0.0 (near), 1.0 (far)] to [-1.0, 1.0] +# point = Point3() +# #lens.extrude_depth(Point3(0.0, 0.0, depth_rel), point) +# # proj.xform_point_general(Point3(0.0, 0.0, pixel[0])) +# # depth = point[1] +# depth = 1.0 / (proj[2][3] * depth_rel + proj[3][3]) +# print(depth) + +rgb_array = Viewer.capture_frame(camera_name="rgb") +depth_array = Viewer.capture_frame(camera_name="depth") +# depth_normalized_array = lens.near / (lens.far - (lens.far - lens.near) * depth_array) +depth_true_array = lens.near / (1.0 - (1.0 - lens.near / lens.far) * depth_array) +fig = plt.figure(figsize=(10, 5)) +ax1 = fig.add_subplot(121) +ax1.imshow(rgb_array) +ax2 = fig.add_subplot(122) +ax2.imshow(depth_true_array, cmap=plt.cm.binary) +for ax in (ax1, ax2): + ax.axis('off') + ax.xaxis.set_visible(False) + ax.yaxis.set_visible(False) +fig.tight_layout(pad=1.0) +plt.show(block=False) diff --git a/python/jiminy_py/setup.py b/python/jiminy_py/setup.py index 2746ecbcd..336612d81 100644 --- a/python/jiminy_py/setup.py +++ 
b/python/jiminy_py/setup.py @@ -136,7 +136,7 @@ def finalize_options(self) -> None: "meshcat>=0.3.2", # Used to detect running Meshcat servers and avoid orphan child # processes. - "psutil", + "psutil>=6.0", # Low-level backend for Ipython powering Jupyter notebooks "ipykernel>=5.0,<7.0", # Used internally by Viewer to read/write Meshcat snapshots diff --git a/python/jiminy_py/src/jiminy_py/dynamics.py b/python/jiminy_py/src/jiminy_py/dynamics.py index 5b18f6969..9a6c8a9cd 100644 --- a/python/jiminy_py/src/jiminy_py/dynamics.py +++ b/python/jiminy_py/src/jiminy_py/dynamics.py @@ -287,15 +287,15 @@ def get(self, that are available. :param t: Time of the state to extract from the trajectory. - :param mode: Specifies how to deal with query time of are out of the - time interval 'time_interval' of the trajectory. Specify - 'raise' to raise an exception if the query time is - out-of-bound wrt to underlying state sequence of the - selected trajectory. Specify 'clip' to force clipping of - the query time before interpolation of the state sequence. - Specify 'wrap' to wrap around the query time wrt the time - span of the trajectory. This is useful to store periodic - trajectories as finite state sequences. + :param mode: Fallback strategy when the query time is not in the time + interval 'time_interval' of the trajectory. 'raise' raises + an exception if the query time is out-of-bound wrt the + underlying state sequence of the selected trajectory. + 'clip' forces clipping of the query time before + interpolation of the state sequence. 'wrap' wraps around + the query time wrt the time span of the trajectory. This + is useful to store periodic trajectories as finite state + sequences. """ # Raise exception if state sequence is empty if not self.has_data: @@ -311,7 +311,10 @@ def get(self, if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL: raise RuntimeError("Time is out-of-range.") elif mode == "wrap": - t = ((t - t_start) % (t_end - t_start)) + t_start + if t_end > t_start: + t = ((t - t_start) % (t_end - t_start)) + t_start + else: + t = t_start else: t = max(t, t_start) # Clipping right is sufficient diff --git a/python/jiminy_py/src/jiminy_py/simulator.py b/python/jiminy_py/src/jiminy_py/simulator.py index 06d3a2b43..ac3e79fca 100644 --- a/python/jiminy_py/src/jiminy_py/simulator.py +++ b/python/jiminy_py/src/jiminy_py/simulator.py @@ -202,7 +202,7 @@ def build(cls, config_path: Optional[str] = None, avoid_instable_collisions: bool = True, debug: bool = False, - *, robot_name: str = "", + *, name: str = "", **kwargs: Any) -> 'Simulator': r"""Create a new single-robot simulator instance from scratch based on configuration files only. @@ -233,8 +233,8 @@ def build(cls, its vertices. :param debug: Whether the debug mode must be activated. Doing so enables automatic deletion of temporary files. - :param robot_name: Desired name of the robot. - Optional: Empty string by default. + :param name: Desired name of the robot. + Optional: Empty string by default. :param kwargs: Keyword arguments to forward to class constructor.
""" # Handling of default argument(s) @@ -246,7 +246,7 @@ def build(cls, # Instantiate and initialize the robot robot = _build_robot_from_urdf( - robot_name, urdf_path, hardware_path, mesh_path_dir, has_freeflyer, + name, urdf_path, hardware_path, mesh_path_dir, has_freeflyer, avoid_instable_collisions, debug) # Instantiate and initialize the engine diff --git a/python/jiminy_py/src/jiminy_py/tree.py b/python/jiminy_py/src/jiminy_py/tree.py index 9b33940e8..c3dd75f59 100644 --- a/python/jiminy_py/src/jiminy_py/tree.py +++ b/python/jiminy_py/src/jiminy_py/tree.py @@ -173,14 +173,22 @@ def _unflatten_as(data: StructNested[Any], """ data_type = type(data) if issubclass_mapping(data_type): # type: ignore[arg-type] - return data_type({ # type: ignore[call-arg] - key: _unflatten_as(value, data_leaf_it) - for key, value in data.items() # type: ignore[union-attr] - }) + flat_items = [ + (key, _unflatten_as(value, data_leaf_it)) + for key, value in data.items()] # type: ignore[union-attr] + try: + # Initialisation from dict cannot be the default path as + # `gym.spaces.Dict` would sort keys in this specific scenario, + # which must be avoided. + return data_type(flat_items) # type: ignore[call-arg] + except (ValueError, RuntimeError): + # Fallback to initialisation from dict in the rare event of + # a container type not supporting initialisation from a + # sequence of key-value pairs. + return data_type(dict(flat_items)) # type: ignore[call-arg] if issubclass_sequence(data_type): # type: ignore[arg-type] - return data_type(tuple( # type: ignore[call-arg] - _unflatten_as(value, data_leaf_it) for value in data - )) + return data_type([ # type: ignore[call-arg] + _unflatten_as(value, data_leaf_it) for value in data]) return next(data_leaf_it) diff --git a/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py b/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py index e1014e845..ccca76ad5 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py +++ b/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py @@ -20,7 +20,7 @@ PLAYWRIGHT_DOWNLOAD_TIMEOUT = 180.0 # 3min to download browser (~130Mo) -PLAYWRIGHT_START_TIMEOUT = 40000.0 # 40s +PLAYWRIGHT_START_TIMEOUT = 60000.0 # 60s WINDOW_SIZE_DEFAULT = (600, 600) @@ -350,7 +350,7 @@ def start_video_recording(self, """ TODO: Write documentation. 
""" self._send_request( - "start_record", message=f"{fps}|{width}|{height}", timeout=10.0) + "start_record", message=f"{fps}|{width}|{height}", timeout=15.0) self.is_recording = True def add_video_frame(self) -> None: diff --git a/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py b/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py index 28e0362d6..2d1548acd 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py +++ b/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_visualizer.py @@ -34,8 +34,8 @@ from direct.gui.OnscreenImage import OnscreenImage from direct.gui.OnscreenText import OnscreenText from panda3d.core import ( # pylint: disable=no-name-in-module - NodePath, Point3, Vec3, Vec4, Mat4, Quat, LQuaternion, Geom, GeomEnums, - GeomNode, GeomTriangles, GeomVertexData, GeomVertexArrayFormat, + NodePath, PandaNode, Point3, Vec3, Vec4, Mat4, Quat, Geom, GeomEnums, + GeomNode, GeomTriangles, GeomVertexData, GeomVertexArrayFormat, BitMask32, GeomVertexFormat, GeomVertexWriter, PNMImage, PNMImageHeader, TextNode, OmniBoundingVolume, CompassEffect, BillboardEffect, InternalName, Filename, Material, Texture, TextureStage, TransparencyAttrib, PGTop, Camera, Lens, @@ -58,7 +58,7 @@ WINDOW_SIZE_DEFAULT = (600, 600) -CAMERA_POS_DEFAULT = [(4.0, -4.0, 1.5), (0, 0, 0.5)] +CAMERA_POSE_DEFAULT = [(4.0, -4.0, 1.5), (0, 0, 0.5)] SKY_TOP_COLOR = (0.53, 0.8, 0.98, 1.0) SKY_BOTTOM_COLOR = (0.1, 0.1, 0.43, 1.0) @@ -366,9 +366,39 @@ def make_torus(minor_radius: float = 0.2, num_segments: int = 16) -> Geom: return geom +def enable_pbr_shader(node: NodePath) -> None: + """Create physics-based shader. + + .. note:: + Lighting must be adapted accordingly to give the desired effect. + + .. warning:: + It slows down the rendering by about 30% on discrete NVIDIA GPU. + + :param node: Root node on which to apply shader, usually the camera itself. + """ + tempnode = NodePath(PandaNode("temp node")) + shader_options = {'ENABLE_SHADOWS': True} + pbr_shader = simplepbr.shaderutils.make_shader( + 'pbr', 'simplepbr.vert', 'simplepbr.frag', shader_options) + tempnode.set_attrib(ShaderAttrib.make(pbr_shader)) + env_map = simplepbr.EnvMap.create_empty() + tempnode.set_shader_input( + 'filtered_env_map', env_map.filtered_env_map) + tempnode.set_shader_input( + 'max_reflection_lod', + env_map.filtered_env_map.num_loadable_ram_mipmap_images) + tempnode.set_shader_input('sh_coeffs', env_map.sh_coefficients) + node.set_initial_state(tempnode.get_state()) + + class Panda3dApp(panda3d_viewer.viewer_app.ViewerApp): """A Panda3D based application. 
""" + UserRGBCameraMask = BitMask32(1 << 1) # 0x2 + UserDepthCameraMask = BitMask32(1 << 2) # 0x4 + UserCameraMask = UserRGBCameraMask | UserDepthCameraMask + def __init__(self, # pylint: disable=super-init-not-called config: Optional[ViewerConfig] = None) -> None: # Enforce viewer configuration @@ -376,7 +406,8 @@ def __init__(self, # pylint: disable=super-init-not-called config = ViewerConfig() config.set_window_size(*WINDOW_SIZE_DEFAULT) config.set_window_fixed(False) - config.enable_antialiasing(True, multisamples=4) + config.enable_antialiasing(False, multisamples=0) + # config.set_value('want-pstats', True) config.set_value('framebuffer-software', False) config.set_value('framebuffer-hardware', False) config.set_value('load-display', 'pandagl') @@ -388,13 +419,14 @@ def __init__(self, # pylint: disable=super-init-not-called config.set_value('sync-video', False) config.set_value('default-near', 0.1) config.set_value('gl-version', '3 1') + # config.set_value('gl-check-errors', '#t') config.set_value('notify-level', 'fatal') config.set_value('notify-level-x11display', 'fatal') config.set_value('notify-level-device', 'fatal') config.set_value('default-directnotify-level', 'error') loadPrcFileData('', str(config)) - # Define offscreen buffer + # Offscreen buffer self.buff: Optional[GraphicsOutput] = None # Initialize base implementation. @@ -429,28 +461,16 @@ def keyboardInterruptHandler( self._spotlight = self.config.GetBool('enable-spotlight', False) self._lights_mask = [True, True] - # Create physics-based shader and adapt lighting accordingly. - # It slows down the rendering by about 30% on discrete NVIDIA GPU. - shader_options = {'ENABLE_SHADOWS': True} - pbr_shader = simplepbr.shaderutils.make_shader( - 'pbr', 'simplepbr.vert', 'simplepbr.frag', shader_options) - self.render.set_attrib(ShaderAttrib.make(pbr_shader)) - env_map = simplepbr.EnvMap.create_empty() - self.render.set_shader_input( - 'filtered_env_map', env_map.filtered_env_map) - self.render.set_shader_input( - 'max_reflection_lod', - env_map.filtered_env_map.num_loadable_ram_mipmap_images) - self.render.set_shader_input('sh_coeffs', env_map.sh_coefficients) + # Adapt lighting to accomodate physics-based rendering. self._lights = [ self._make_light_ambient((0.5, 0.5, 0.5)), self._make_light_direct(1, (1.0, 1.0, 1.0), pos=(8.0, -8.0, 10.0))] - # Define default camera pos - self._camera_defaults = CAMERA_POS_DEFAULT + # Current camera pose + self._camera_defaults = CAMERA_POSE_DEFAULT self.reset_camera(*self._camera_defaults) - # Define clock. It will be used later to limit framerate + # Custom clock. 
It will be used later to limit framerate self.clock = ClockObject.get_global_clock() self.framerate: Optional[float] = None @@ -469,9 +489,9 @@ keyboardInterruptHandler( # Create gradient for skybox self.skybox = make_gradient_skybox( SKY_TOP_COLOR, SKY_BOTTOM_COLOR, 0.35, 0.17) - self.skybox.set_shader_auto(True) + self.skybox.set_shader_auto() self.skybox.set_light_off() - self.skybox.hide(self.LightMask) + self.skybox.hide(self.LightMask | self.UserDepthCameraMask) # The background needs to be parented to an intermediary node to which # a compass effect is applied to keep it at the same position as the @@ -522,7 +542,7 @@ keyboardInterruptHandler( self.offA2dBottomCenter.set_pos(0, 0, self.a2dBottom) self.offA2dBottomRight.set_pos(self.a2dRight, 0, self.a2dBottom) - # Define widget overlay + # Widget overlay self.offscreen_graphics_lens: Optional[Lens] = None self.offscreen_display_region: Optional[DisplayRegion] = None self._help_label = None self._watermark: Optional[OnscreenImage] = None self._legend: Optional[OnscreenImage] = None self._clock: Optional[OnscreenText] = None - # Define input control + # Custom user-specified cameras + self._user_buffers: Dict[str, NodePath] = {} + self._user_cameras: Dict[str, NodePath] = {} + + # Input control self.key_map = {"mouse1": 0, "mouse2": 0, "mouse3": 0} - # Define camera control + # Camera control self.zoom_rate = 1.03 self.camera_lookat = np.zeros(3) self.longitude_deg = 0.0 self.latitude_deg = 0.0 self.last_mouse_x = 0.0 self.last_mouse_y = 0.0 - # Define object/highlighting selector + # Object/highlighting selector self.picker_ray: Optional[CollisionRay] = None self.picker_node: Optional[CollisionNode] = None self.picked_object: Optional[Tuple[str, str]] = None self.click_mouse_x = 0.0 self.click_mouse_y = 0.0 + # Make the original window inactive without deleting it. + # It must be kept in order to keep the same graphics context alive. + self.win.set_active(False) + # Create resizeable offscreen buffer. # Note that a resizable buffer is systematically created, no matter # if the main window is an offscreen non-resizable window or an @@ -580,16 +608,23 @@ def open_window(self) -> None: if self.has_gui(): raise RuntimeError("Only one graphical window can be opened.") + # Force enabling multi-sampling for onscreen graphical window + fbprops = FrameBufferProperties(FrameBufferProperties.getDefault()) + fbprops.set_multisamples(4) + # Replace the original offscreen window by an onscreen one if possible is_success = True size = self.win.get_size() try: self.windowType = 'onscreen' - self.open_main_window(size=size) + self.open_main_window(size=size, fbprops=fbprops) except Exception: # pylint: disable=broad-except is_success = False self.windowType = 'offscreen' - self.open_main_window(size=size) + self.open_main_window(size=size, fbprops=fbprops) + + # Enable Physics-based rendering + enable_pbr_shader(self.cam.node()) if is_success: # Setup mouse and keyboard controls for onscreen display @@ -648,12 +683,15 @@ def _open_offscreen_window(self, # Set offscreen buffer frame properties. # Note that accumulator bits and back buffers are not supported by # resizeable buffers. + # Beware MSAA is very picky on MacOS regarding the image format, with + only a few variants being supported.
# See https://github.com/panda3d/panda3d/issues/1121 + https://github.com/panda3d/panda3d/issues/756 fbprops = FrameBufferProperties() - fbprops.set_rgba_bits(8, 8, 8, 0) + fbprops.set_rgba_bits(8, 8, 8, 8) fbprops.set_float_color(False) fbprops.set_depth_bits(16) - fbprops.set_float_depth(True) + fbprops.set_float_depth(False) fbprops.set_multisamples(4) # Set offscreen buffer windows properties @@ -671,23 +709,30 @@ self.pipe, "offscreen_buffer", 0, fbprops, winprops, flags, self.win.get_gsg(), self.win) if win is None: - raise RuntimeError("Faulty graphics pipeline of this machine.") + raise RuntimeError("Faulty graphics pipeline on this machine.") self.buff = win + # Disable automatic rendering of the buffer for efficiency + win.set_one_shot(True) + # Append buffer to the list of windows managed by the ShowBase self.winList.append(win) # Attach a texture as screenshot requires copying GPU data to RAM - self.buff.add_render_texture( - Texture(), GraphicsOutput.RTM_triggered_copy_ram) + tex = Texture() + tex.set_format(Texture.F_rgb) + self.buff.add_render_texture(tex, GraphicsOutput.RTM_copy_ram) # Create 3D camera region for the scene. # Set near distance of camera lens to allow seeing model from close. self.offscreen_graphics_lens = PerspectiveLens() self.offscreen_graphics_lens.set_near(0.1) - self.make_camera( + cam = self.make_camera( win, camName='offscreen_camera', lens=self.offscreen_graphics_lens) + # Enable Physics-based rendering + enable_pbr_shader(cam.node()) + # Create 2D display region for widgets self.offscreen_display_region = win.makeMonoDisplayRegion() self.offscreen_display_region.set_sort(5) @@ -739,6 +784,122 @@ def _adjust_offscreen_window_aspect_ratio(self) -> None: self.offA2dBottomLeft.set_pos(a2dLeft, 0, a2dBottom) self.offA2dBottomRight.set_pos(a2dRight, 0, a2dBottom) + def add_camera(self, + name: str, + is_depthmap: bool, + size: Tuple[int, int]) -> None: + """Add an RGB or depth camera to the scene. + + The user is responsible for managing it, i.e. setting its pose in the + world, getting screenshots from it, and removing it when no longer + relevant. Manually added cameras are mainly useful for simulating + exteroceptive sensors. + + :param name: Name of the camera to be added. + :param is_depthmap: Whether the camera output gathers three 8-bit + integer RGB channels or one 32-bit float depth channel. + :param size: Resolution (height and width) in pixels of the image being + captured by the camera. + """ + # TODO: Expose optional parameters to set lens type and properties. + + # Make sure that no camera with the same name already exists + if name in self._user_cameras: + raise ValueError( + "A camera with the same name already exists.
Please delete " + "it by calling `remove_camera` before adding a new one.") + + # Create new offscreen buffer + fbprops = FrameBufferProperties() + if is_depthmap: + fbprops.set_depth_bits(32) + fbprops.set_float_depth(True) + fbprops.set_multisamples(0) + else: + fbprops.set_rgba_bits(8, 8, 8, 8) + fbprops.set_float_color(False) + fbprops.set_depth_bits(16) + fbprops.set_float_depth(False) + fbprops.set_multisamples(4) + winprops = WindowProperties() + winprops.set_size(*size) + flags = GraphicsPipe.BF_refuse_window + buffer = self.graphicsEngine.make_output( + self.pipe, f"user_buffer_{name}", 0, fbprops, winprops, flags, + self.win.get_gsg(), self.win) + if buffer is None: + raise RuntimeError("Faulty graphics pipeline on this machine.") + self._user_buffers[name] = buffer + self.winList.append(buffer) + + # Disable automatic rendering of the buffer for efficiency + buffer.set_one_shot(True) + + # Disable color buffer and enable depth buffer + if is_depthmap: + buffer.set_clear_color_active(False) + buffer.set_clear_depth_active(True) + + # Attach a texture as screenshot requires copying GPU data to RAM + tex = Texture(f"user_texture_{name}") + if is_depthmap: + tex.set_format(Texture.F_depth_component) + else: + tex.set_format(Texture.F_rgb) + buffer.add_render_texture(tex, GraphicsOutput.RTM_copy_ram) + + # Create 3D camera region for the scene. + # See official documentation about field of view parameterization: + # https://docs.panda3d.org/1.10/python/programming/camera-control/perspective-lenses # noqa: E501 # pylint: disable=line-too-long + lens = PerspectiveLens() + if is_depthmap: + lens.set_fov(50.0) # field of view angle [0, 180], 40° by default + lens.set_near(0.02) # near distance (objects closer not rendered) + lens.set_far(6.0) # far distance (objects farther not rendered) + # lens.set_film_size(24, 36) + # lens.set_focal_length(50) # Setting this will overwrite fov + else: + lens.set_near(0.1) + lens.set_aspect_ratio(self.get_aspect_ratio(buffer)) + if is_depthmap: + mask = self.UserDepthCameraMask + else: + mask = self.UserRGBCameraMask + cam = self.make_camera( + buffer, camName=f"user_camera_{name}", lens=lens, mask=mask) + cam.reparent_to(self.render) + self._user_cameras[name] = cam + + # Disable shader for depth map since it is irrelevant + if is_depthmap: + tempnode = NodePath(PandaNode("temp node")) + tempnode.set_material_off(2) + tempnode.set_texture_off(2) + tempnode.set_light_off(2) + tempnode.set_shader_off(2) + cam.node().set_initial_state(tempnode.get_state()) + else: + # Enable Physics-based rendering + enable_pbr_shader(cam.node()) + + # Force rendering the scene to finalize initialization of the GSG + self.graphics_engine.render_frame() + + # Flip the buffer upside-down + buffer.inverted = True + + def remove_camera(self, name: str) -> None: + """Remove one of the cameras being managed by the user, which has been + added manually via `add_camera`. + + :param name: Name of the camera to remove. + """ + # Make sure that the camera exists before trying to delete it + if name not in self._user_cameras: + raise ValueError(f"No camera with name '{name}' was found.") + self.close_window(self._user_buffers[name], keepCamera=False) + del self._user_cameras[name] + del self._user_buffers[name] + def getSize(self, win: Optional[Any] = None) -> Tuple[int, int]: """Patched to return the size of the window used for capturing frames by default, instead of the main window.
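Note on depth output: the float depth buffer captured by a user depth camera stores normalized depth in [0.0 (near), 1.0 (far)], so metric depth has to be recovered from the lens near/far planes, as done in the extra_cameras.py example earlier in this patch. A minimal sketch of that conversion (the helper name is hypothetical; the near/far values match the defaults set by `add_camera` for depth cameras):

import numpy as np

def linearize_depth(depth_buffer: np.ndarray, near: float, far: float) -> np.ndarray:
    # Invert the projective mapping of a Panda3d perspective lens:
    # a buffer value of 0.0 maps back to the near plane, 1.0 to the far plane.
    return near / (1.0 - (1.0 - near / far) * depth_buffer)

# e.g. for a user camera added beforehand via `Viewer.add_camera`:
# depth = linearize_depth(
#     Viewer.capture_frame(camera_name="depth"), near=0.02, far=6.0)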
@@ -893,9 +1054,9 @@ def move_orbital_camera_task(self, def _make_light_ambient(self, color: Tuple3FType) -> NodePath: """Patched to fix wrong color alpha. """ - node = super()._make_light_ambient(color) - node.get_node(0).set_color((*color, 1.0)) - return node + light = super()._make_light_ambient(color) + light.node().set_color((*color, 1.0)) + return light def _make_light_direct(self, index: int, @@ -905,9 +1066,9 @@ def _make_light_direct(self, ) -> NodePath: """Patched to fix wrong color alpha. """ - light_path = super()._make_light_direct(index, color, pos, target) - light_path.get_node(0).set_color((*color, 1.0)) - return light_path + light = super()._make_light_direct(index, color, pos, target) + light.node().set_color((*color, 1.0)) + return light def _make_axes(self) -> NodePath: model = GeomNode('axes') @@ -917,9 +1078,10 @@ def _make_axes(self) -> NodePath: if self.win.gsg.driver_vendor.startswith('NVIDIA'): node.set_render_mode_thickness(4) node.set_antialias(AntialiasAttrib.MLine) - node.set_shader_auto(True) + node.set_shader_auto() node.set_light_off() - node.hide(self.LightMask) + node.hide(self.LightMask | self.UserCameraMask) + node.set_tag("is_virtual", "1") node.set_scale(0.3) return node @@ -959,7 +1121,7 @@ def _make_floor(self, # Set material to render shadows if supported material = Material() material.set_base_color((1.35, 1.35, 1.35, 1.0)) - node.set_material(material, True) + node.set_material(material) # Disable light casting node.hide(self.LightMask) @@ -967,7 +1129,7 @@ def _make_floor(self, # Adjust frustum of the lights to project shadow over the whole scene for light_path in self._lights[1:]: bmin, bmax = node.get_tight_bounds(light_path) - lens = light_path.get_node(0).get_lens() + lens = light_path.node().get_lens() lens.set_film_offset((bmin.xz + bmax.xz) * 0.5) lens.set_film_size(bmax.xz - bmin.xz) lens.set_near_far(bmin.y, bmax.y) @@ -1068,9 +1230,10 @@ def append_frame(self, if self.win.gsg.driver_vendor.startswith('NVIDIA'): node.set_render_mode_thickness(4) node.set_antialias(AntialiasAttrib.MLine) - node.set_shader_auto(True) + node.set_shader_auto() node.set_light_off() - node.hide(self.LightMask) + node.hide(self.LightMask | self.UserCameraMask) + node.set_tag("is_virtual", "1") self.append_node(root_path, name, node, frame) def append_cone(self, @@ -1153,6 +1316,10 @@ def append_arrow(self, body_node.set_scale(1.0, 1.0, length) body_node.set_pos(0.0, 0.0, (-0.5 if anchor_top else 0.5) * length) arrow_node.set_scale(radius, radius, 1.0) + + arrow_node.hide(self.LightMask | self.UserCameraMask) + arrow_node.set_tag("is_virtual", "1") + self.append_node(root_path, name, arrow_node, frame) def append_mesh(self, @@ -1495,7 +1662,7 @@ def set_material(self, material.set_diffuse(Vec4(*color)) material.set_specular(Vec3(1, 1, 1)) material.set_roughness(0.4) - node.set_material(material, True) + node.set_material(material) if color[3] < 1.0: node.set_transparency(TransparencyAttrib.M_alpha) @@ -1570,10 +1737,14 @@ def show_node(self, node.set_tag("status", "hidden") node.hide() if always_foreground is not None: + # FIXME: Properly restore original mask if any if always_foreground: node.set_bin("fixed", 0) + node.hide(self.UserCameraMask) else: node.clear_bin() + if node.get_tag("is_virtual") != "1": + node.show(self.UserCameraMask) node.set_depth_test(not always_foreground) node.set_depth_write(not always_foreground) @@ -1584,27 +1755,40 @@ def get_camera_transform(self) -> Tuple[np.ndarray, np.ndarray]: representation of the orientation (X, Y, Z, W) as 
a pair of `np.ndarray`. """ - return (np.array(self.camera.get_pos()), - np.array(self.camera.get_quat())) + return (np.array(self.camera.get_pos(), dtype=np.float64), + np.array(self.camera.get_quat(), dtype=np.float64)) def set_camera_transform(self, pos: Tuple3FType, quat: np.ndarray, - lookat: Tuple3FType = (0.0, 0.0, 0.0)) -> None: + camera_name: Optional[str] = None) -> None: """Set the current absolute pose of the camera. :param pos: Desired position of the camera. :param quat: Desired orientation of the camera as a quaternion (X, Y, Z, W). - :param lookat: Point at which the camera is looking at. It is partially - redundant with the desired orientation and will take - precedence in case of inconsistency. It is also involved - in zoom control. + :param camera_name: Name of the camera to consider. Either one of the + cameras that were manually added by the user via + `add_camera`, or None to specify the one associated + with the main window. If the main window is an + onscreen graphical window, then the camera of its + accompanying offscreen buffer for screenshots will + be jointly moved since they are attached together. + Optional: None by default. """ - self.camera.set_pos(*pos) - self.camera.set_quat(LQuaternion(quat[-1], *quat[:-1])) - self.camera_lookat = np.array(lookat) - self.move_orbital_camera_task() + # Pick the right camera + if camera_name is None: + camera = self.camera + else: + camera = self._user_cameras[camera_name] + + # Move the camera + camera.set_pos_quat(Vec3(*pos), Quat(quat[-1], *quat[:-1])) + + # Reset orbital camera control + if camera_name is None: + self.camera_lookat = np.array([0.0, 0.0, 0.0]) + self.move_orbital_camera_task() def get_camera_lookat(self) -> np.ndarray: """Get the location of the point toward which the camera is looking. @@ -1670,7 +1854,7 @@ def save_screenshot(self, filename: Optional[str] = None) -> bool: # Refresh the scene to make sure it is perfectly up-to-date. # It will take into account the updated position of the camera. assert self.buff is not None - self.buff.trigger_copy() + self.buff.set_one_shot(True) self.graphics_engine.render_frame() # Capture frame as image @@ -1692,9 +1876,9 @@ def save_screenshot(self, filename: Optional[str] = None) -> bool: return True - def get_screenshot(self, - requested_format: str = 'RGB', - raw: bool = False) -> Union[np.ndarray, bytes]: + def get_screenshot(self, # pylint: disable=arguments-renamed + camera_name: Optional[str] = None + ) -> Union[np.ndarray, bytes]: """Patched to take screenshot of the last window available instead of the main one, and to add raw data return mode for efficient multiprocessing. @@ -1705,32 +1889,54 @@ def get_screenshot(self, scheduler. The framerate limit must be disabled manually to avoid such a limitation. - .. note:: - Internally, Panda3d uses BGRA, so using it is slightly faster than - RGBA, but not RGB since there is one channel missing. - - :param requested_format: Desired export format (e.g. 'RGB' or 'BGRA') - :param raw: whether to return a raw memory view of bytes, of a - structured `np.ndarray` of uint8 with dimensions [H, W, D]. + :param camera_name: Name of the camera to consider. Either one of the + cameras that were manually added by the user via + `add_camera`, or None to specify the one associated + with the main window (i.e. either the camera attached + to the main window directly if the latter is an + offscreen buffer, otherwise the camera of its + accompanying offscreen buffer). + Optional: None by default.
""" + # Get desired buffer + if camera_name is None: + assert self.buff is not None + buffer = self.buff + else: + buffer = self._user_buffers[camera_name] + + # Get frame as raw texture + texture = buffer.get_texture() + is_depth_map = texture.format == Texture.F_depth_component32 + + # Disable shadow casting for depth map computation since it is useless + shadow_buffers = [] + if is_depth_map: + for light in self._lights: + if not light.node().is_ambient_light(): + shadow_buffer = light.node().getShadowBuffer(self.win.gsg) + if shadow_buffer is not None: + shadow_buffer.active = False + shadow_buffers.append(shadow_buffer) + # Refresh the scene - assert self.buff is not None - self.buff.trigger_copy() + buffer.set_one_shot(True) self.graphics_engine.render_frame() - # Get frame as raw texture - assert self.buff is not None - texture = self.buff.get_texture() + # Restore shadow casting + for shadow_buffer in shadow_buffers: + shadow_buffer.active = True # Extract raw array buffer from texture - image = texture.get_ram_image_as(requested_format) - - # Return raw buffer if requested - if raw: - return image.get_data() + if is_depth_map: + image = texture.get_ram_image() + else: + image = texture.get_ram_image_as('RGB') # Convert raw texture to numpy array if requested xsize, ysize = texture.get_x_size(), texture.get_y_size() + if is_depth_map: + return np.frombuffer(image, np.float32).reshape((ysize, xsize)) return np.frombuffer(image, np.uint8).reshape((ysize, xsize, -1)) def enable_shadow(self, enable: bool) -> None: diff --git a/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_widget.py b/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_widget.py index 149495da1..1a94e20fb 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_widget.py +++ b/python/jiminy_py/src/jiminy_py/viewer/panda3d/panda3d_widget.py @@ -71,7 +71,7 @@ def paintEvent(self, # Note that `QImage` does not manage the lifetime of the input data # buffer, so it is necessary to keep it is local scope until the end of # its drawning. - data = self.get_screenshot(requested_format='RGB', raw=True) + data = self.get_screenshot() img = QtGui.QImage( data, *self.buff.getSize(), QtGui.QImage.Format_RGB888) diff --git a/python/jiminy_py/src/jiminy_py/viewer/replay.py b/python/jiminy_py/src/jiminy_py/viewer/replay.py index fc7e9add3..4c8cd74f2 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/replay.py +++ b/python/jiminy_py/src/jiminy_py/viewer/replay.py @@ -595,6 +595,7 @@ def play_trajectories( # Create frame storage frame = av.VideoFrame(*record_video_size, 'rgb24') + frame_bytes = memoryview(frame.planes[0]) # Add frames to video sequentially update_hook_t = None @@ -636,10 +637,8 @@ def play_trajectories( # Update frame. # Note that `capture_frame` is by far the main bottleneck # of the whole recording process (~75% on discrete gpu). 
- buffer = Viewer.capture_frame( - *record_video_size, raw_data=True) - memoryview( - frame.planes[0])[:] = buffer # type: ignore[arg-type] + buffer = Viewer.capture_frame(*record_video_size) + frame_bytes[:] = buffer.reshape(-1).data # Write frame for packet in stream.encode(frame): diff --git a/python/jiminy_py/src/jiminy_py/viewer/viewer.py b/python/jiminy_py/src/jiminy_py/viewer/viewer.py index 4785ed12f..9a1fef98c 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/viewer.py +++ b/python/jiminy_py/src/jiminy_py/viewer/viewer.py @@ -21,7 +21,6 @@ import subprocess import webbrowser import multiprocessing -from copy import deepcopy from urllib.request import urlopen from functools import wraps, partial from threading import RLock @@ -1248,7 +1247,7 @@ def connect_backend(backend: Optional[str] = None) -> None: for pid in psutil.pids(): try: proc_info = Process(pid) - for conn in proc_info.connections("tcp4"): + for conn in proc_info.net_connections("tcp4"): if conn.status != 'LISTEN' or \ conn.laddr.ip != '127.0.0.1': continue @@ -1428,7 +1427,7 @@ def set_legend(labels: Optional[Sequence[str]] = None) -> None: for text, (robot_name, color) in zip( labels, Viewer._backend_robot_colors.items()): if color is not None: - rgba = (*(int(e * 255) for e in color[:3]), color[3]) + rgba = (*[int(e * 255) for e in color[:3]], color[3]) color_text = f"rgba({','.join(map(str, rgba))})" else: color_text = "black" @@ -1460,7 +1459,8 @@ def set_clock(t: Optional[float] = None) -> None: @staticmethod @_with_lock @_must_be_open - def get_camera_transform() -> Tuple[Tuple3FType, Tuple3FType]: + def get_camera_transform(camera_name: Optional[str] = None + ) -> Tuple[Tuple3FType, Tuple3FType]: """Get transform of the camera pose. .. warning:: @@ -1471,6 +1471,15 @@ def get_camera_transform() -> Tuple[Tuple3FType, Tuple3FType]: since it is impossible to get access to this information. Thus this method is valid as long as the user does not move the camera manually using mouse camera control. + + .. warning:: + Specifying a camera name is only supported by the Panda3d rendering + backend. + + :param camera_name: Name of the camera to consider. None to specify + the "default" world-facing camera that is used + when the GUI (onscreen window) is enabled. + Optional: None by default. """ # Assert(s) for type checker assert Viewer.backend is not None @@ -1481,9 +1490,15 @@ def get_camera_transform() -> Tuple[Tuple3FType, Tuple3FType]: quat /= np.linalg.norm(quat) rot = pin.Quaternion(*quat).matrix() rpy = matrixToRpy(rot @ CAMERA_INV_TRANSFORM_PANDA3D.T) - else: - xyz, rpy = deepcopy(Viewer._camera_xyzrpy) - return xyz, rpy + return xyz, rpy + + # Make sure that no camera name has been specified for meshcat + if camera_name is not None: + raise ValueError( + "Specifying a camera name is only supported by Panda3d.") + + xyz, rpy = map(tuple, Viewer._camera_xyzrpy) + return xyz, rpy @_with_lock @_must_be_open @@ -1491,6 +1506,7 @@ def set_camera_transform(self: Optional["Viewer"] = None, position: Optional[Tuple3FType] = None, rotation: Optional[Tuple3FType] = None, relative: Optional[Union[str, int]] = None, + camera_name: Optional[str] = None, wait: bool = False) -> None: """Set transform of the camera pose. @@ -1500,6 +1516,10 @@ def set_camera_transform(self: Optional["Viewer"] = None, [0.0, 0.0, 0.0] moves the camera at the center of scene, looking downward. + .. warning:: + Specifying a camera name is only supported by the Panda3d rendering + backend. + :param position: Position [X, Y, Z] as a list or 1D array.
If `None`, it will be kept as is. Optional: None by default. @@ -1516,12 +1536,22 @@ def set_camera_transform(self: Optional["Viewer"] = None, - **other:** relative to a robot frame, not accounting for the rotation of the frame during travelling. It supports both frame name and index in model. + :param camera_name: Name of the camera to consider. None to specify + the "default" world-facing camera that is used + when the GUI (onscreen window) is enabled. + Optional: None by default. :param wait: Whether to wait for rendering to finish. """ # pylint: disable=invalid-name, possibly-used-before-assignment # Assert(s) for type checker assert Viewer.backend is not None assert Viewer._backend_obj is not None + assert self is None or isinstance(self, Viewer) + + # Make sure that no camera name has been specified for meshcat + if Viewer.backend == 'meshcat' and camera_name is not None: + raise ValueError( + "Specifying a camera name is only supported by Panda3d.") if self is None and relative is not None and relative != 'camera': raise ValueError( @@ -1530,7 +1560,8 @@ def set_camera_transform(self: Optional["Viewer"] = None, # Handling of position and rotation arguments if position is None or rotation is None or relative == 'camera': - position_camera, rotation_camera = Viewer.get_camera_transform() + position_camera, rotation_camera = Viewer.get_camera_transform( + camera_name=camera_name) if position is None: if relative is not None: position = (0.0, 0.0, 0.0) @@ -1568,7 +1599,8 @@ def set_camera_transform(self: Optional["Viewer"] = None, H_abs = H_orig * SE3(rotation_mat, position) position = H_abs.translation rotation = matrixToRpy(H_abs.rotation) - Viewer.set_camera_transform(None, position, rotation) + Viewer.set_camera_transform( + None, position, rotation, camera_name=camera_name) return # Perform the desired transformation @@ -1576,7 +1608,7 @@ def set_camera_transform(self: Optional["Viewer"] = None, rotation_panda3d = pin.Quaternion( rotation_mat @ CAMERA_INV_TRANSFORM_PANDA3D).coeffs() Viewer._backend_obj.gui.set_camera_transform( - position, rotation_panda3d) + position, rotation_panda3d, camera_name) elif Viewer.backend == 'meshcat': # pylint: disable=import-outside-toplevel # Meshcat camera is rotated by -pi/2 along Roll axis wrt the @@ -1758,7 +1790,7 @@ def set_color(self, """Override the color of the visual and collision geometries of the robot on-the-fly. - .. note:: + .. warning:: This method is only supported by Panda3d for now. :param color: Color of the robot. It will override the original color @@ -1864,20 +1896,56 @@ def update_floor(ground_profile: Optional[jiminy.HeightmapFunction] = None, update_floor as meshcat_update_floor) meshcat_update_floor(Viewer._backend_obj.gui, geom) + @staticmethod + @_with_lock + @_must_be_open + def add_camera(camera_name: str, + width: int, + height: int, + is_depthmap: bool) -> None: + """TODO: Write documentation. + + .. warning:: + This method is only supported by Panda3d for now.
+ + """ + # Assert(s) for type checker + assert Viewer.backend is not None + assert Viewer._backend_obj is not None + + # Make sure the backend supports this method + if not Viewer.backend.startswith('panda3d'): + raise NotImplementedError( + "This method is only supported by Panda3d.") + + # Add camera + Viewer._backend_obj.gui.add_camera( + camera_name, (width, height), is_depthmap) + @staticmethod @_with_lock @_must_be_open def capture_frame(width: Optional[int] = None, height: Optional[int] = None, - raw_data: bool = False) -> Union[np.ndarray, bytes]: + camera_name: Optional[str] = None, + raw_data: bool = False + ) -> Union[np.ndarray, bytes]: """Take a snapshot and return associated data. + .. warning:: + Specifying a camera name is only supported by Panda3d rendering + backend, while raw data mode is only supported by Meshcat. + :param width: Width for the image in pixels. None to keep unchanged. Optional: Kept unchanged by default. :param height: Height for the image in pixels. None to keep unchanged. Optional: Kept unchanged by default. + :param camera_name: Name of the camera to consider. None to specify + the "default" world facing camera that is used + when GUI (onscreen window) in enabled. + Optional: None by default. :param raw_data: Whether to return a 2D numpy array, or the raw output - from the backend (the actual type may vary). + from the backend as bytes array. """ # Assert(s) for type checker assert Viewer.backend is not None @@ -1885,31 +1953,33 @@ def capture_frame(width: Optional[int] = None, # Check user arguments if Viewer.backend.startswith('panda3d'): - # Resize window if size has changed - _width, _height = Viewer._backend_obj.gui.getSize() - if width is None: - width = _width - if height is None: - height = _height - if _width != width or _height != height: - Viewer._backend_obj.gui.set_window_size(width, height) + if camera_name is None: + # Resize window if size has changed + _width, _height = Viewer._backend_obj.gui.getSize() + if width is None: + width = _width + if height is None: + height = _height + if _width != width or _height != height: + Viewer._backend_obj.gui.set_window_size(width, height) + elif width is not None or height is not None: + raise ValueError( + "Specifying both camera name and image width and/or " + "height is not supported.") - # Get raw buffer image instead of numpy array for efficiency - buffer = Viewer._backend_obj.gui.get_screenshot( - requested_format='RGB', raw=True) - if buffer is None: + # Get screenshot + image = Viewer._backend_obj.gui.get_screenshot(camera_name) + if image is None: raise RuntimeError( "Impossible to capture frame. 
There is something wrong " "with the graphics stack on this machine.") + return image - # Return raw data if requested - if raw_data: - return buffer - - # Extract and return numpy array RGB - return np.frombuffer(buffer, np.uint8).reshape((height, width, 3)) + # Make sure that no camera name has been specified for meshcat + if camera_name is not None: + raise ValueError( + "Specifying a camera name is only supported by Panda3d.") - # if Viewer.backend == 'meshcat': # Send capture frame request to the background recorder process img_html = Viewer._backend_obj.capture_frame(width, height) diff --git a/python/jiminy_pywrap/include/jiminy/python/functors.h b/python/jiminy_pywrap/include/jiminy/python/functors.h index 759545e22..e393a1576 100644 --- a/python/jiminy_pywrap/include/jiminy/python/functors.h +++ b/python/jiminy_pywrap/include/jiminy/python/functors.h @@ -220,18 +220,25 @@ namespace jiminy::python { } - void operator()( - const Eigen::Vector2d & posFrame, double & height, Eigen::Vector3d & normal) + void operator()(const Eigen::Vector2d & posFrame, + double & height, + std::optional> normal) { switch (heightmapType_) { case HeightmapType::CONSTANT: height = bp::extract(handlePyPtr_); - normal = Eigen::Vector3d::UnitZ(); + if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } break; case HeightmapType::STAIRS: handlePyPtr_(posFrame, convertToPython(height, false)); - normal = Eigen::Vector3d::UnitZ(); + if (normal.has_value()) + { + normal.value() = Eigen::Vector3d::UnitZ(); + } break; case HeightmapType::GENERIC: default: diff --git a/python/jiminy_pywrap/include/jiminy/python/utilities.h b/python/jiminy_pywrap/include/jiminy/python/utilities.h index de6b4f812..e669a9c2f 100644 --- a/python/jiminy_pywrap/include/jiminy/python/utilities.h +++ b/python/jiminy_pywrap/include/jiminy/python/utilities.h @@ -812,8 +812,7 @@ namespace jiminy::python struct result_converter { template || - is_eigen_ref_v>> + typename = std::enable_if_t || is_eigen_ref_v>> struct apply { struct type diff --git a/python/jiminy_pywrap/src/functors.cc b/python/jiminy_pywrap/src/functors.cc index dbc52cf50..8f5dcebdd 100644 --- a/python/jiminy_pywrap/src/functors.cc +++ b/python/jiminy_pywrap/src/functors.cc @@ -25,20 +25,45 @@ namespace jiminy::python // *********************************** HeightmapFunction *********************************** // void queryHeightMap(HeightmapFunction & heightmap, - np::ndarray positionsPy, - np::ndarray heightsPy) + const Eigen::Matrix2Xd & positions, + Eigen::Ref heights) { - auto const positions = convertFromPython< - Eigen::Map> - >(positionsPy); - auto heights = convertFromPython< - Eigen::Map> - >(heightsPy).col(0); - - for (Eigen::Index i = 0; i < positions.cols() ; ++i) + // Make sure that the number of query points is consistent between all arguments + if (heights.size() != positions.cols()) { - Eigen::Vector3d normal; - heightmap(positions.col(i), heights[i], normal); + JIMINY_THROW( + std::invalid_argument, + "'positions' and/or 'heights' are inconsistent with each other. 
'positions' must " + "be a 2D array whose first dimension gathers the 2 position coordinates in world " + "plane (X, Y) while the second dimension corresponds to individual query points."); + } + + // Loop over all query points sequentially + for (Eigen::Index i = 0; i < positions.cols(); ++i) + { + heightmap(positions.col(i), heights[i], std::nullopt); + } + } + + void queryHeightMapWithNormals(HeightmapFunction & heightmap, + const Eigen::Matrix2Xd & positions, + Eigen::Ref heights, + Eigen::Ref normals) + { + // Make sure that the number of query points is consistent between all arguments + if (heights.size() != positions.cols() || normals.cols() != positions.cols()) + { + JIMINY_THROW(std::invalid_argument, + "'positions', 'heights' and/or 'normals' are inconsistent with each " + "other. 'normals' must be a 2D array whose first dimension gathers the 3 " + "components of the normal vector (X, Y, Z) while the second dimension " + "corresponds to individual query points."); + } + + // Loop over all query points sequentially + for (Eigen::Index i = 0; i < positions.cols(); ++i) + { + heightmap(positions.col(i), heights[i], normals.col(i)); } } @@ -83,8 +108,10 @@ &internal::heightmap_function::getPyFun, bp::return_value_policy()); + bp::def( + "query_heightmap", &queryHeightMap, (bp::args("heightmap"), "positions", "heights")); bp::def("query_heightmap", - &queryHeightMap, - (bp::args("heightmap"), "positions", "heights")); + &queryHeightMapWithNormals, + (bp::args("heightmap"), "positions", "heights", "normals")); } } diff --git a/python/jiminy_pywrap/src/generators.cc b/python/jiminy_pywrap/src/generators.cc index 725846a55..9f7f4a39d 100644 --- a/python/jiminy_pywrap/src/generators.cc +++ b/python/jiminy_pywrap/src/generators.cc @@ -35,7 +35,6 @@ namespace jiminy::python return generator.seed(std::seed_seq(seedSeq.cbegin(), seedSeq.cend())); } - #define GENERIC_DISTRIBUTION_WRAPPER(dist, arg1, arg2) \ Eigen::MatrixXd dist##FromStackedArgs( \ const uniform_random_bit_generator_ref & generator, \ @@ -93,17 +92,95 @@ #undef GENERIC_DISTRIBUTION_WRAPPER + template + std::enable_if_t>...>, double> + evaluatePerlinProcessUnpacked(DerivedPerlinProcess & fun, Args... args) + { + return fun(Eigen::Matrix(sizeof...(Args)), 1>{args...}); + } + + template + std::enable_if_t>...>, + typename DerivedPerlinProcess::template VectorN> + gradPerlinProcessUnpacked(DerivedPerlinProcess & fun, Args...
args) + { + return fun.grad(Eigen::Matrix(sizeof...(Args)), 1>{args...}); + } + + template + using type_t = T; + + template + auto evaluatePerlinProcessUnpackedSignature( + std::index_sequence) -> double (*)(DerivedPerlinProcess &, type_t...); + + template + auto gradPerlinProcessUnpackedSignature(std::index_sequence) -> + typename DerivedPerlinProcess::template VectorN (*)(DerivedPerlinProcess &, + type_t...); + + template + struct PyPerlinProcessVisitor : public bp::def_visitor> + { + public: + template + static void visit(PyClass & cl) + { + using DerivedPerlinProcess = typename PyClass::wrapped_type; + + // clang-format off + cl + .def("__call__", + static_cast( + std::make_index_sequence{}))>(evaluatePerlinProcessUnpacked)) + .def("__call__", &DerivedPerlinProcess::operator(), (bp::arg("self"), "vec")) + .def("grad", + static_cast( + std::make_index_sequence{}))>(gradPerlinProcessUnpacked)) + .def("grad", &DerivedPerlinProcess::grad, (bp::arg("self"), "vec")) + .def( + "reset", + makeFunction(ConvertGeneratorFromPythonAndInvoke< + void(const uniform_random_bit_generator_ref &), DerivedPerlinProcess + >(&DerivedPerlinProcess::reset), + bp::default_call_policies(), + (bp::arg("self"), "generator"))) + .ADD_PROPERTY_GET("wavelength", &DerivedPerlinProcess::getWavelength) + .ADD_PROPERTY_GET("num_octaves", &DerivedPerlinProcess::getNumOctaves); + // clang-format on + } + + static void expose() + { + bp::class_, + // bp::bases>, + std::shared_ptr>, + boost::noncopyable>( + toString("RandomPerlinProcess", N, "D").c_str(), + bp::init( + (bp::arg("self"), "wavelength", bp::arg("num_octaves") = 6U))) + .def(PyPerlinProcessVisitor()); + + bp::class_, + // bp::bases>, + std::shared_ptr>, + boost::noncopyable>( + toString("PeriodicPerlinProcess", N, "D").c_str(), + bp::init( + (bp::arg("self"), "wavelength", "period", bp::arg("num_octaves") = 6U))) + .ADD_PROPERTY_GET("period", &PeriodicPerlinProcess::getPeriod) + .def(PyPerlinProcessVisitor()); + } + }; + void exposeGenerators() { - // clang-format off - bp::class_, - boost::noncopyable>("PCG32", - bp::init((bp::arg("self"), "state"))) + bp::class_, boost::noncopyable>( + "PCG32", bp::init((bp::arg("self"), "state"))) .def(bp::init<>((bp::arg("self")))) - .def("__init__", bp::make_constructor(&makePCG32FromSeedSed, - bp::default_call_policies(), - (bp::arg("seed_seq")))) + .def("__init__", + bp::make_constructor( + &makePCG32FromSeedSed, bp::default_call_policies(), (bp::arg("seed_seq")))) .def("__call__", &PCG32::operator(), (bp::arg("self"))) .def("seed", &seedPCG32FromSeedSed, (bp::arg("self"), "seed_seq")) .add_static_property( @@ -113,93 +190,96 @@ namespace jiminy::python bp::implicitly_convertible>(); -#define BIND_GENERIC_DISTRIBUTION(dist, arg1, arg2) \ - bp::def(#dist, makeFunction( \ - ConvertGeneratorFromPythonAndInvoke(&dist##FromStackedArgs), \ - bp::default_call_policies(), \ - (bp::arg("generator"), #arg1, #arg2))); \ - bp::def(#dist, makeFunction( \ - ConvertGeneratorFromPythonAndInvoke(&dist##FromSize), \ - bp::default_call_policies(), \ - (bp::arg("generator"), bp::arg(#arg1) = 0.0F, bp::arg(#arg2) = 1.0F, \ - bp::arg("size") = bp::object()))); +#define BIND_GENERIC_DISTRIBUTION(dist, arg1, arg2) \ + bp::def(#dist, \ + makeFunction(ConvertGeneratorFromPythonAndInvoke(&dist##FromStackedArgs), \ + bp::default_call_policies(), \ + (bp::arg("generator"), #arg1, #arg2))); \ + bp::def(#dist, \ + makeFunction(ConvertGeneratorFromPythonAndInvoke(&dist##FromSize), \ + bp::default_call_policies(), \ + (bp::arg("generator"), \ + 
bp::arg(#arg1) = 0.0F, \ + bp::arg(#arg2) = 1.0F, \ + bp::arg("size") = bp::object()))); - BIND_GENERIC_DISTRIBUTION(uniform, lo, hi) - BIND_GENERIC_DISTRIBUTION(normal, mean, stddev) + BIND_GENERIC_DISTRIBUTION(uniform, lo, hi) + BIND_GENERIC_DISTRIBUTION(normal, mean, stddev) #undef BIND_GENERIC_DISTRIBUTION // Must be declared last to take precedence over generic declaration with default values - bp::def("uniform", makeFunction(ConvertGeneratorFromPythonAndInvoke( - static_cast< - float (*)(const uniform_random_bit_generator_ref &) - >(&uniform)), - bp::default_call_policies(), - (bp::arg("generator")))); + bp::def("uniform", + makeFunction( + ConvertGeneratorFromPythonAndInvoke( + static_cast &)>( + &uniform)), + bp::default_call_policies(), + (bp::arg("generator")))); + + bp::class_, + boost::noncopyable>("PeriodicTabularProcess", bp::no_init) + .def("__call__", &PeriodicTabularProcess::operator(), (bp::arg("self"), "time")) + .def("grad", &PeriodicTabularProcess::grad, (bp::arg("self"), "time")) + .def("reset", + makeFunction(ConvertGeneratorFromPythonAndInvoke(&PeriodicTabularProcess::reset), + bp::default_call_policies(), + (bp::arg("self"), "generator"))) + .ADD_PROPERTY_GET("wavelength", &PeriodicTabularProcess::getWavelength) + .ADD_PROPERTY_GET("period", &PeriodicTabularProcess::getPeriod); bp::class_, std::shared_ptr, - boost::noncopyable>("PeriodicGaussianProcess", - bp::init( - (bp::arg("self"), "wavelength", "period"))) - .def("__call__", &PeriodicGaussianProcess::operator(), - (bp::arg("self"), bp::arg("time"))) - .def("reset", makeFunction( - ConvertGeneratorFromPythonAndInvoke(&PeriodicGaussianProcess::reset), - bp::default_call_policies(), - (bp::arg("self"), "generator"))) - .ADD_PROPERTY_GET("wavelength", &PeriodicGaussianProcess::getWavelength) - .ADD_PROPERTY_GET("period", &PeriodicGaussianProcess::getPeriod); + boost::noncopyable>( + "PeriodicGaussianProcess", + bp::init((bp::arg("self"), "wavelength", "period"))); bp::class_, std::shared_ptr, - boost::noncopyable>("PeriodicFourierProcess", - bp::init( - (bp::arg("self"), "wavelength", "period"))) - .def("__call__", &PeriodicFourierProcess::operator(), - (bp::arg("self"), bp::arg("time"))) - .def("reset", makeFunction( - ConvertGeneratorFromPythonAndInvoke(&PeriodicFourierProcess::reset), - bp::default_call_policies(), - (bp::arg("self"), "generator"))) - .ADD_PROPERTY_GET("wavelength", &PeriodicFourierProcess::getWavelength) - .ADD_PROPERTY_GET("period", &PeriodicFourierProcess::getPeriod); - - bp::class_, - boost::noncopyable>("AbstractPerlinProcess", bp::no_init) - .def("__call__", &AbstractPerlinProcess::operator(), - (bp::arg("self"), "time")) - .def("reset", makeFunction( - ConvertGeneratorFromPythonAndInvoke(&AbstractPerlinProcess::reset), - bp::default_call_policies(), - (bp::arg("self"), "generator"))) - .ADD_PROPERTY_GET("wavelength", &AbstractPerlinProcess::getWavelength) - .ADD_PROPERTY_GET("num_octaves", &AbstractPerlinProcess::getNumOctaves); - - bp::class_, - std::shared_ptr, - boost::noncopyable>("RandomPerlinProcess", - bp::init( - (bp::arg("self"), "wavelength", bp::arg("num_octaves") = 6U))); - - bp::class_, - std::shared_ptr, - boost::noncopyable>("PeriodicPerlinProcess", - bp::init( - (bp::arg("self"), "wavelength", "period", bp::arg("num_octaves") = 6U))) - .ADD_PROPERTY_GET("period", &PeriodicPerlinProcess::getPeriod); - - bp::def("random_tile_ground", &tiles, - (bp::arg("size"), "height_max", "interp_delta", - "sparsity", "orientation", "seed")); - bp::def("stairs_ground", &stairs, 
(bp::arg("step_width"), "step_height", "step_number", "orientation")); + boost::noncopyable>( + "PeriodicFourierProcess", + bp::init((bp::arg("self"), "wavelength", "period"))); + + /* FIXME: Use template lambda and compile-time for-loop when moving to c++20. + For reference: https://stackoverflow.com/a/76272348/4820605 */ + PyPerlinProcessVisitor<1>::expose(); + PyPerlinProcessVisitor<2>::expose(); + PyPerlinProcessVisitor<3>::expose(); + + bp::def( + "random_tile_ground", + &tiles, + (bp::arg("size"), "height_max", "interp_delta", "sparsity", "orientation", "seed")); + bp::def("periodic_stairs_ground", + &periodicStairs, + (bp::arg("step_width"), "step_height", "step_number", "orientation")); + bp::def("unidirectional_random_perlin_ground", + &unidirectionalRandomPerlinGround, + (bp::arg("wavelength"), "num_octaves", "orientation", "seed")); + bp::def("unidirectional_periodic_perlin_ground", + &unidirectionalPeriodicPerlinGround, + (bp::arg("wavelength"), "period", "num_octaves", "orientation", "seed")); + bp::def("random_perlin_ground", + &randomPerlinGround, + (bp::arg("wavelength"), "num_octaves", "seed")); + bp::def("periodic_perlin_ground", + &periodicPerlinGround, + (bp::arg("wavelength"), "period", "num_octaves", "seed")); bp::def("sum_heightmaps", &sumHeightmaps, (bp::arg("heightmaps"))); bp::def("merge_heightmaps", &mergeHeightmaps, (bp::arg("heightmaps"))); - bp::def("discretize_heightmap", &discretizeHeightmap, - (bp::arg("heightmap"), "x_min", "x_max", "x_unit", "y_min", - "y_max", "y_unit", bp::arg("must_simplify") = false)); - // clang-format on + bp::def("discretize_heightmap", + &discretizeHeightmap, + (bp::arg("heightmap"), + "x_min", + "x_max", + "x_unit", + "y_min", + "y_max", + "y_unit", + bp::arg("must_simplify") = false)); } } diff --git a/python/jiminy_pywrap/src/module.cc b/python/jiminy_pywrap/src/module.cc index c304ce0f1..1b8aa70f5 100755 --- a/python/jiminy_pywrap/src/module.cc +++ b/python/jiminy_pywrap/src/module.cc @@ -6,6 +6,8 @@ #include "jiminy/core/utilities/random.h" +#include + /* Eigenpy must be imported first, since it sets pre-processor definitions used by Boost Python to configure Python C API. */ #include "pinocchio/bindings/python/fwd.hpp" @@ -53,6 +55,14 @@ namespace jiminy::python // Initialized EigenPy, enabling PyArrays<->Eigen automatic converters eigenpy::enableEigenPy(); + if (!eigenpy::register_symbolic_link_to_registered_type()) + { + bp::enum_("GJKInitialGuess") + .value("DefaultGuess", hpp::fcl::GJKInitialGuess::DefaultGuess) + .value("CachedGuess", hpp::fcl::GJKInitialGuess::CachedGuess) + .value("BoundingVolumeGuess", hpp::fcl::GJKInitialGuess::BoundingVolumeGuess); + } + // Expose the version bp::scope().attr("__version__") = bp::str(JIMINY_VERSION); bp::scope().attr("__raw_version__") = bp::str(JIMINY_VERSION);