diff --git a/include/etl/concepts.hpp b/include/etl/concepts.hpp index cfb2b5ca..8fc5ca06 100644 --- a/include/etl/concepts.hpp +++ b/include/etl/concepts.hpp @@ -96,6 +96,9 @@ concept etl_3d = etl_expr && decay_traits::dimensions() == 3; template concept etl_4d = etl_expr && decay_traits::dimensions() == 4; +template +concept etl_2d_or_4d = etl_expr && (decay_traits::dimensions() == 2 || decay_traits::dimensions() == 4); + template concept etl_4d_and_plus = etl_expr && decay_traits::dimensions() >= 4; diff --git a/include/etl/expr/batch_embedding_gradients_expr.hpp b/include/etl/expr/batch_embedding_gradients_expr.hpp index 0a0005fd..4851f666 100644 --- a/include/etl/expr/batch_embedding_gradients_expr.hpp +++ b/include/etl/expr/batch_embedding_gradients_expr.hpp @@ -15,7 +15,7 @@ namespace etl { * \brief A transposition expression. * \tparam A The transposed type */ -template +template struct batch_embedding_gradients_expr : base_temporary_expr_tern, A, B, C> { using value_type = value_t; ///< The type of value of the expression using this_type = batch_embedding_gradients_expr; ///< The type of this expression @@ -43,12 +43,8 @@ struct batch_embedding_gradients_expr : base_temporary_expr_tern + template static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] const L& lhs) { - static_assert(etl::dimensions() == 2, "The input of batch_embedding_gradients is a 1d matrix"); - static_assert(etl::dimensions() == 3, "The vocabulary input of batch_embedding_gradients is a 2d matrix"); - static_assert(etl::dimensions() == 2, "The output of batch_embedding_gradients is 2d matrix"); - if constexpr (all_fast) { static_assert(etl::dim<0, A>() == etl::dim<0, B>(), "Invalid dimensions for batch_embedding_gradients"); static_assert(etl::dim<1, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_embedding_gradients"); diff --git a/include/etl/expr/batch_k_minus_scale_expr.hpp b/include/etl/expr/batch_k_minus_scale_expr.hpp index cc03f665..3e1f7162 100644 --- a/include/etl/expr/batch_k_minus_scale_expr.hpp +++ b/include/etl/expr/batch_k_minus_scale_expr.hpp @@ -11,7 +11,7 @@ namespace etl { -template +template struct batch_k_minus_scale_expr : base_temporary_expr_tern, A, B, C> { using value_type = value_t; ///< The type of value of the expression using this_type = batch_k_minus_scale_expr; ///< The type of this expression @@ -43,14 +43,9 @@ struct batch_k_minus_scale_expr : base_temporary_expr_tern + template L> static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] L& lhs) { if constexpr (D4) { - static_assert(etl::dimensions() == 4, "The output of batch_k_minus_scale is a 4D matrix"); - static_assert(etl::dimensions() == 1, "The lhs of batch_k_minus_scale is a 1D matrix"); - static_assert(etl::dimensions() == 4, "The rhs of batch_k_minus_scale is a 4D matrix"); - static_assert(etl::dimensions() == 1, "The beta of batch_k_minus_scale is a 1D matrix"); - if constexpr (all_fast) { static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_minus_scale"); static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_minus_scale"); @@ -69,11 +64,6 @@ struct batch_k_minus_scale_expr : base_temporary_expr_tern(a) == etl::dim<0>(c), "Invalid dimensions for batch_k_minus_scale"); } } else { - static_assert(etl::dimensions() == 2, "The output of batch_k_minus_scale is a 2D matrix"); - static_assert(etl::dimensions() == 1, "The lhs of batch_k_minus_scale is a 1D matrix"); - static_assert(etl::dimensions() == 2, "The rhs of batch_k_minus_scale is a 2D matrix"); - static_assert(etl::dimensions() == 1, "The beta of batch_k_minus_scale is a 1D matrix"); - if constexpr (all_fast) { static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_minus_scale"); static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_minus_scale"); diff --git a/include/etl/expr/bias_batch_mean_2d_expr.hpp b/include/etl/expr/bias_batch_mean_2d_expr.hpp index 1d6207df..29d61f73 100644 --- a/include/etl/expr/bias_batch_mean_2d_expr.hpp +++ b/include/etl/expr/bias_batch_mean_2d_expr.hpp @@ -18,7 +18,7 @@ namespace etl { * \brief A transposition expression. * \tparam A The transposed type */ -template +template struct bias_batch_mean_2d_expr : base_temporary_expr_un, A> { using value_type = value_t; ///< The type of value of the expression using this_type = bias_batch_mean_2d_expr; ///< The type of this expression @@ -48,11 +48,8 @@ struct bias_batch_mean_2d_expr : base_temporary_expr_un + template static void check([[maybe_unused]] const A& a, [[maybe_unused]] const C& c) { - static_assert(etl::dimensions() == 1, "The output of bias_batch_mean_2d is a vector"); - static_assert(etl::dimensions() == 2, "The input of bias_batch_mean_2d is a 2d matrix"); - if constexpr (all_fast) { static_assert(etl::dim<1, A>() == etl::dim<0, C>(), "Invalid dimensions for bias_batch_mean_2d"); } else { diff --git a/include/etl/expr/conv_2d_same_expr.hpp b/include/etl/expr/conv_2d_same_expr.hpp index b5b17129..4c02930f 100644 --- a/include/etl/expr/conv_2d_same_expr.hpp +++ b/include/etl/expr/conv_2d_same_expr.hpp @@ -46,7 +46,7 @@ struct conv_2d_same_expr : base_temporary_expr_bin + template static void check([[maybe_unused]] const I& input, [[maybe_unused]] const K& kernel, [[maybe_unused]] const C& conv) { static_assert(etl::dimensions() == 2, "Invalid number of dimensions for input of conv2_same"); static_assert(etl::dimensions() == 2, "Invalid number of dimensions for kernel of conv2_same"); diff --git a/include/etl/expr/conv_2d_same_multi_expr.hpp b/include/etl/expr/conv_2d_same_multi_expr.hpp index 330ec668..bd76ce06 100644 --- a/include/etl/expr/conv_2d_same_multi_expr.hpp +++ b/include/etl/expr/conv_2d_same_multi_expr.hpp @@ -46,12 +46,8 @@ struct conv_2d_same_multi_expr : base_temporary_expr_bin + template static void check([[maybe_unused]] const I& input, [[maybe_unused]] const K& kernel, [[maybe_unused]] const C& conv) { - static_assert(etl::dimensions() == 2, "Invalid number of dimensions for input of conv2_same_multi"); - static_assert(etl::dimensions() == 3, "Invalid number of dimensions for kernel of conv2_same_multi"); - static_assert(etl::dimensions() == 3, "Invalid number of dimensions for conv of conv2_same_multi"); - if constexpr (all_fast) { static_assert(etl::dim<0, C>() == etl::dim<0, K>(), "Invalid dimensions for conv2_same_multi"); static_assert(etl::dim<1, C>() == etl::dim<0, I>(), "Invalid dimensions for conv2_same_multi"); diff --git a/include/etl/expr/conv_2d_valid_expr.hpp b/include/etl/expr/conv_2d_valid_expr.hpp index dfc1f105..ed3400df 100644 --- a/include/etl/expr/conv_2d_valid_expr.hpp +++ b/include/etl/expr/conv_2d_valid_expr.hpp @@ -54,12 +54,8 @@ struct conv_2d_valid_expr : base_temporary_expr_bin + template static void check([[maybe_unused]] const I& input, [[maybe_unused]] const K& kernel, [[maybe_unused]] const C& conv) { - static_assert(etl::dimensions() == 2, "Invalid number of dimensions for input of conv2_valid"); - static_assert(etl::dimensions() == 2, "Invalid number of dimensions for kernel of conv2_valid"); - static_assert(etl::dimensions() == 2, "Invalid number of dimensions for conv of conv2_valid"); - if constexpr (all_fast) { static_assert(etl::dim<0, C>() == (etl::dim<0, I>() - etl::dim<0, K>() + 2 * P1) / S1 + 1, "Invalid dimensions for conv2_valid"); static_assert(etl::dim<1, C>() == (etl::dim<1, I>() - etl::dim<1, K>() + 2 * P2) / S2 + 1, "Invalid dimensions for conv2_valid"); diff --git a/include/etl/expr/convmtx_2d_expr.hpp b/include/etl/expr/convmtx_2d_expr.hpp index 1887730a..125d27df 100644 --- a/include/etl/expr/convmtx_2d_expr.hpp +++ b/include/etl/expr/convmtx_2d_expr.hpp @@ -47,10 +47,8 @@ struct convmtx_2d_expr : base_temporary_expr_un, A> { * \brief Assign to a matrix of the same storage order * \param c The expression to which assign */ - template + template C> void assign_to(C&& c) const { - static_assert(etl::dimensions() == etl::dimensions(), "max_pool_2d must be applied on matrices of same dimensionality"); - inc_counter("temp:assign"); auto& a = this->a(); diff --git a/include/etl/expr/embedding_gradients_expr.hpp b/include/etl/expr/embedding_gradients_expr.hpp index 5896b596..6769b641 100644 --- a/include/etl/expr/embedding_gradients_expr.hpp +++ b/include/etl/expr/embedding_gradients_expr.hpp @@ -15,7 +15,7 @@ namespace etl { * \brief A transposition expression. * \tparam A The transposed type */ -template +template struct embedding_gradients_expr : base_temporary_expr_tern, A, B, C> { using value_type = value_t; ///< The type of value of the expression using this_type = embedding_gradients_expr; ///< The type of this expression @@ -43,12 +43,8 @@ struct embedding_gradients_expr : base_temporary_expr_tern + template static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] const L& lhs) { - static_assert(etl::dimensions() == 1, "The input of embedding_gradients is a 1d matrix"); - static_assert(etl::dimensions() == 2, "The vocabulary input of embedding_gradients is a 2d matrix"); - static_assert(etl::dimensions() == 2, "The output of embedding_gradients is 2d matrix"); - if constexpr (all_fast) { static_assert(etl::dim<0, A>() == etl::dim<0, B>(), "Invalid dimensions for embedding_gradients"); static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for embedding_gradients"); diff --git a/include/etl/expr/fft_expr.hpp b/include/etl/expr/fft_expr.hpp index e16e17c4..52e87ca1 100644 --- a/include/etl/expr/fft_expr.hpp +++ b/include/etl/expr/fft_expr.hpp @@ -44,10 +44,8 @@ struct fft_expr : base_temporary_expr_un, A> { * \brief Assign to a matrix of the same storage order * \param c The expression to which assign */ - template + template C> void assign_to(C&& c) const { - static_assert(etl::dimensions() == etl::dimensions(), "max_pool_2d must be applied on matrices of same dimensionality"); - inc_counter("temp:assign"); Impl::apply(this->a(), c); @@ -57,7 +55,7 @@ struct fft_expr : base_temporary_expr_un, A> { * \brief Add to the given left-hand-side expression * \param lhs The expression to which assign */ - template + template L> void assign_add_to(L&& lhs) const { std_add_evaluate(*this, lhs); } @@ -66,7 +64,7 @@ struct fft_expr : base_temporary_expr_un, A> { * \brief Sub from the given left-hand-side expression * \param lhs The expression to which assign */ - template + template L> void assign_sub_to(L&& lhs) const { std_sub_evaluate(*this, lhs); } @@ -75,7 +73,7 @@ struct fft_expr : base_temporary_expr_un, A> { * \brief Multiply the given left-hand-side expression * \param lhs The expression to which assign */ - template + template L> void assign_mul_to(L&& lhs) const { std_mul_evaluate(*this, lhs); } @@ -84,7 +82,7 @@ struct fft_expr : base_temporary_expr_un, A> { * \brief Divide the given left-hand-side expression * \param lhs The expression to which assign */ - template + template L> void assign_div_to(L&& lhs) const { std_div_evaluate(*this, lhs); } @@ -93,7 +91,7 @@ struct fft_expr : base_temporary_expr_un, A> { * \brief Modulo the given left-hand-side expression * \param lhs The expression to which assign */ - template + template L> void assign_mod_to(L&& lhs) const { std_mod_evaluate(*this, lhs); } @@ -373,10 +371,8 @@ auto ifft_2d_real(A&& a, C&& c) { * \param a The input expression * \return an expression representing several 1D FFT of a */ -template +template fft_expr, detail::fft_value_type, detail::fft1_many_impl> fft_1d_many(A&& a) { - static_assert(decay_traits::dimensions() >= 2, "fft_many requires at least 2D matrices"); - return fft_expr, detail::fft_value_type, detail::fft1_many_impl>{a}; } @@ -389,9 +385,8 @@ fft_expr, detail::fft_value_type, detail::fft1_many_imp * \param c The result * \return an expression representing several 1D FFT of a */ -template +template auto fft_1d_many(A&& a, C&& c) { - static_assert(decay_traits::dimensions() >= 2 && decay_traits::dimensions() >= 2, "fft_many requires at least 2D matrices"); validate_assign(c, a); c = fft_1d_many(a); @@ -406,10 +401,8 @@ auto fft_1d_many(A&& a, C&& c) { * \param a The input expression * \return an expression representing several 1D FFT of a */ -template +template fft_expr, detail::ifft_value_type, detail::ifft1_many_impl> ifft_1d_many(A&& a) { - static_assert(decay_traits::dimensions() >= 2, "ifft_many requires at least 2D matrices"); - return fft_expr, detail::ifft_value_type, detail::ifft1_many_impl>{a}; } @@ -422,9 +415,8 @@ fft_expr, detail::ifft_value_type, detail::ifft1_many_i * \param c The result * \return an expression representing several 1D FFT of a */ -template +template auto ifft_1d_many(A&& a, C&& c) { - static_assert(decay_traits::dimensions() >= 2 && decay_traits::dimensions() >= 2, "ifft_many requires at least 2D matrices"); validate_assign(c, a); c = ifft_1d_many(a); @@ -439,10 +431,8 @@ auto ifft_1d_many(A&& a, C&& c) { * \param a The input expression * \return an expression representing several 2D FFT of a */ -template +template fft_expr, detail::fft_value_type, detail::fft2_many_impl> fft_2d_many(A&& a) { - static_assert(decay_traits::dimensions() >= 3, "fft_many requires at least 3D matrices"); - return fft_expr, detail::fft_value_type, detail::fft2_many_impl>{a}; } @@ -455,9 +445,8 @@ fft_expr, detail::fft_value_type, detail::fft2_many_imp * \param c The result * \return an expression representing several 2D FFT of a */ -template +template auto fft_2d_many(A&& a, C&& c) { - static_assert(decay_traits::dimensions() >= 3 && decay_traits::dimensions() >= 3, "fft_many requires at least 3D matrices"); validate_assign(c, a); c = fft_2d_many(a); @@ -472,10 +461,8 @@ auto fft_2d_many(A&& a, C&& c) { * \param a The input expression * \return an expression representing several 2D FFT of a */ -template +template fft_expr, detail::ifft_value_type, detail::ifft2_many_impl> ifft_2d_many(A&& a) { - static_assert(decay_traits::dimensions() >= 3, "ifft_many requires at least 3D matrices"); - return fft_expr, detail::ifft_value_type, detail::ifft2_many_impl>{a}; } @@ -488,9 +475,8 @@ fft_expr, detail::ifft_value_type, detail::ifft2_many_i * \param c The result * \return an expression representing several 2D FFT of a */ -template +template auto ifft_2d_many(A&& a, C&& c) { - static_assert(decay_traits::dimensions() >= 3 && decay_traits::dimensions() >= 3, "ifft_many requires at least 3D matrices"); validate_assign(c, a); c = ifft_2d_many(a); diff --git a/include/etl/expr/pool_upsample_2d_expr.hpp b/include/etl/expr/pool_upsample_2d_expr.hpp index 28f2aee7..96f8a962 100644 --- a/include/etl/expr/pool_upsample_2d_expr.hpp +++ b/include/etl/expr/pool_upsample_2d_expr.hpp @@ -22,7 +22,7 @@ namespace etl { * \tparam B The output type * \tparam C The errors type */ -template +template B, same_dimensions C, size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, bool Max> struct pool_upsample_2d_expr : base_temporary_expr_tern, A, B, C> { using value_type = value_t; ///< The type of value of the expression using sub_traits = etl::decay_traits; ///< The traits of the first sub type diff --git a/include/etl/expr/pool_upsample_3d_expr.hpp b/include/etl/expr/pool_upsample_3d_expr.hpp index 93a59a36..a735b306 100644 --- a/include/etl/expr/pool_upsample_3d_expr.hpp +++ b/include/etl/expr/pool_upsample_3d_expr.hpp @@ -22,7 +22,7 @@ namespace etl { * \tparam B The output type * \tparam C The errors type */ -template +template B, same_dimensions C, size_t C1, size_t C2, size_t C3, bool Max> struct pool_upsample_3d_expr : base_temporary_expr_tern, A, B, C> { using value_type = value_t; ///< The type of value of the expression using sub_traits = etl::decay_traits; ///< The traits of the first sub type @@ -52,14 +52,10 @@ struct pool_upsample_3d_expr : base_temporary_expr_tern + template R> static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] const R& result) { static constexpr size_t D = etl::decay_traits::dimensions(); - static_assert(etl::decay_traits::dimensions() == D, "Invalid dimensions in max_pool_upsampl_3d"); - static_assert(etl::decay_traits::dimensions() == D, "Invalid dimensions in max_pool_upsampl_3d"); - static_assert(etl::decay_traits::dimensions() == D, "Invalid dimensions in max_pool_upsampl_3d"); - if constexpr (all_fast) { static_assert(etl::decay_traits::size() == etl::decay_traits::size(), "max_pool_upsample_3d:A and R must have the same size"); static_assert(etl::decay_traits::size() == etl::decay_traits::size(), "max_pool_upsample_3d:B and C must have the same size");