Skip to content

Commit

Permalink
Merge pull request #3124 from stan-dev/fix/kronecker_scalar
Browse files Browse the repository at this point in the history
use eigen internal traits for getting the scalar type of a eigen type
  • Loading branch information
WardBrian authored Nov 14, 2024
2 parents 7ada875 + 283f7bf commit 0a0831d
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 10 deletions.
59 changes: 57 additions & 2 deletions stan/math/prim/meta/is_eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ template <typename T>
struct is_eigen
: bool_constant<is_base_pointer_convertible<Eigen::EigenBase, T>::value> {};

namespace internal {
// primary template handles types that have no nested ::type member:
template <class, class = void>
struct has_internal_trait : std::false_type {};

// specialization recognizes types that do have a nested ::type member:
template <class T>
struct has_internal_trait<T,
std::void_t<Eigen::internal::traits<std::decay_t<T>>>>
: std::true_type {};

// primary template handles types that have no nested ::type member:
template <class, class = void>
struct has_scalar_trait : std::false_type {};

// specialization recognizes types that do have a nested ::type member:
template <class T>
struct has_scalar_trait<T, std::void_t<typename std::decay_t<T>::Scalar>>
: std::true_type {};

} // namespace internal

/**
* Template metaprogram defining the base scalar type of
* values stored in an Eigen matrix.
Expand All @@ -28,7 +50,9 @@ struct is_eigen
* @ingroup type_trait
*/
template <typename T>
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
struct scalar_type<T,
std::enable_if_t<is_eigen<T>::value
&& internal::has_scalar_trait<T>::value>> {
using type = scalar_type_t<typename std::decay_t<T>::Scalar>;
};

Expand All @@ -40,10 +64,41 @@ struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
* @ingroup type_trait
*/
template <typename T>
struct value_type<T, std::enable_if_t<is_eigen<T>::value>> {
struct value_type<T,
std::enable_if_t<is_eigen<T>::value
&& internal::has_scalar_trait<T>::value>> {
using type = typename std::decay_t<T>::Scalar;
};

/**
* Template metaprogram defining the base scalar type of
* values stored in an Eigen matrix.
*
* @tparam T type to check.
* @ingroup type_trait
*/
template <typename T>
struct scalar_type<T,
std::enable_if_t<is_eigen<T>::value
&& !internal::has_scalar_trait<T>::value>> {
using type = scalar_type_t<
typename Eigen::internal::traits<std::decay_t<T>>::Scalar>;
};

/**
* Template metaprogram defining the type of values stored in an
* Eigen matrix, vector, or row vector.
*
* @tparam T type to check
* @ingroup type_trait
*/
template <typename T>
struct value_type<T,
std::enable_if_t<is_eigen<T>::value
&& !internal::has_scalar_trait<T>::value>> {
using type = typename Eigen::internal::traits<std::decay_t<T>>::Scalar;
};

/*! \ingroup require_eigens_types */
/*! \defgroup eigen_types eigen */
/*! \addtogroup eigen_types */
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/core/arena_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ namespace internal {
template <typename T>
struct traits<stan::math::arena_matrix<T>> {
using base = traits<Eigen::Map<T>>;
using Scalar = typename base::Scalar;
using XprKind = typename Eigen::internal::traits<std::decay_t<T>>::XprKind;
enum {
PlainObjectTypeInnerSize = base::PlainObjectTypeInnerSize,
Expand Down
26 changes: 18 additions & 8 deletions test/unit/math/prim/meta/value_type_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <stan/math/prim/meta.hpp>
#include <test/unit/util.hpp>
#include <stan/math/prim/meta.hpp>
#include <unsupported/Eigen/KroneckerProduct>
#include <gtest/gtest.h>
#include <vector>

Expand All @@ -8,16 +9,16 @@ TEST(MathMetaPrim, value_type_vector) {
using std::vector;

EXPECT_SAME_TYPE(vector<double>::value_type,
value_type<vector<double> >::type);
value_type<vector<double>>::type);

EXPECT_SAME_TYPE(vector<double>::value_type,
value_type<const vector<double> >::type);
value_type<const vector<double>>::type);

EXPECT_SAME_TYPE(vector<vector<int> >::value_type,
value_type<vector<vector<int> > >::type);
EXPECT_SAME_TYPE(vector<vector<int>>::value_type,
value_type<vector<vector<int>>>::type);

EXPECT_SAME_TYPE(vector<vector<int> >::value_type,
value_type<const vector<vector<int> > >::type);
EXPECT_SAME_TYPE(vector<vector<int>>::value_type,
value_type<const vector<vector<int>>>::type);
}

TEST(MathMetaPrim, value_type_matrix) {
Expand All @@ -33,5 +34,14 @@ TEST(MathMetaPrim, value_type_matrix) {
value_type<Eigen::RowVectorXd>::type);

EXPECT_SAME_TYPE(Eigen::RowVectorXd,
value_type<std::vector<Eigen::RowVectorXd> >::type);
value_type<std::vector<Eigen::RowVectorXd>>::type);
}

TEST(MathMetaPrim, value_type_kronecker) {
Eigen::Matrix<double, 2, 2> A;
const auto B
= Eigen::kroneckerProduct(A, Eigen::Matrix<double, 2, 2>::Identity());
Eigen::Matrix<double, 4, 1> C = Eigen::Matrix<double, 4, 1>::Random(4, 1);
EXPECT_TRUE((std::is_same<double, stan::value_type_t<decltype(B)>>::value));
Eigen::MatrixXd D = B * C;
}

0 comments on commit 0a0831d

Please sign in to comment.