Skip to content

Commit

Permalink
Complete usage of concepts
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Dec 13, 2023
1 parent db282df commit 5d8cc2e
Show file tree
Hide file tree
Showing 12 changed files with 32 additions and 78 deletions.
3 changes: 3 additions & 0 deletions include/etl/concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ concept etl_3d = etl_expr<T> && decay_traits<T>::dimensions() == 3;
template <typename T>
concept etl_4d = etl_expr<T> && decay_traits<T>::dimensions() == 4;

template <typename T>
concept etl_2d_or_4d = etl_expr<T> && (decay_traits<T>::dimensions() == 2 || decay_traits<T>::dimensions() == 4);

template <typename T>
concept etl_4d_and_plus = etl_expr<T> && decay_traits<T>::dimensions() >= 4;

Expand Down
8 changes: 2 additions & 6 deletions include/etl/expr/batch_embedding_gradients_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B, typename C>
template <etl_2d A, etl_3d B, typename C>
struct batch_embedding_gradients_expr : base_temporary_expr_tern<batch_embedding_gradients_expr<A, B, C>, A, B, C> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = batch_embedding_gradients_expr<A, B, C>; ///< The type of this expression
Expand Down Expand Up @@ -43,12 +43,8 @@ struct batch_embedding_gradients_expr : base_temporary_expr_tern<batch_embedding
* \param a The input matrix
* \þaram lhs The output matrix
*/
template <typename L>
template <etl_2d L>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] const L& lhs) {
static_assert(etl::dimensions<A>() == 2, "The input of batch_embedding_gradients is a 1d matrix");
static_assert(etl::dimensions<B>() == 3, "The vocabulary input of batch_embedding_gradients is a 2d matrix");
static_assert(etl::dimensions<L>() == 2, "The output of batch_embedding_gradients is 2d matrix");

if constexpr (all_fast<A, B, C, L>) {
static_assert(etl::dim<0, A>() == etl::dim<0, B>(), "Invalid dimensions for batch_embedding_gradients");
static_assert(etl::dim<1, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_embedding_gradients");
Expand Down
14 changes: 2 additions & 12 deletions include/etl/expr/batch_k_minus_scale_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace etl {

template <etl_expr A, etl_expr B, etl_expr C>
template <etl_1d A, etl_2d_or_4d B, etl_1d C>
struct batch_k_minus_scale_expr : base_temporary_expr_tern<batch_k_minus_scale_expr<A, B, C>, A, B, C> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = batch_k_minus_scale_expr<A, B, C>; ///< The type of this expression
Expand Down Expand Up @@ -43,14 +43,9 @@ struct batch_k_minus_scale_expr : base_temporary_expr_tern<batch_k_minus_scale_e
* \param a The input matrix
* \þaram c The output matrix
*/
template <typename L>
template <same_dimensions<B> L>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] L& lhs) {
if constexpr (D4) {
static_assert(etl::dimensions<L>() == 4, "The output of batch_k_minus_scale is a 4D matrix");
static_assert(etl::dimensions<A>() == 1, "The lhs of batch_k_minus_scale is a 1D matrix");
static_assert(etl::dimensions<B>() == 4, "The rhs of batch_k_minus_scale is a 4D matrix");
static_assert(etl::dimensions<C>() == 1, "The beta of batch_k_minus_scale is a 1D matrix");

if constexpr (all_fast<A, B, C, L>) {
static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_minus_scale");
static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_minus_scale");
Expand All @@ -69,11 +64,6 @@ struct batch_k_minus_scale_expr : base_temporary_expr_tern<batch_k_minus_scale_e
cpp_assert(etl::dim<0>(a) == etl::dim<0>(c), "Invalid dimensions for batch_k_minus_scale");
}
} else {
static_assert(etl::dimensions<L>() == 2, "The output of batch_k_minus_scale is a 2D matrix");
static_assert(etl::dimensions<A>() == 1, "The lhs of batch_k_minus_scale is a 1D matrix");
static_assert(etl::dimensions<B>() == 2, "The rhs of batch_k_minus_scale is a 2D matrix");
static_assert(etl::dimensions<C>() == 1, "The beta of batch_k_minus_scale is a 1D matrix");

if constexpr (all_fast<A, B, C, L>) {
static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_minus_scale");
static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_minus_scale");
Expand Down
7 changes: 2 additions & 5 deletions include/etl/expr/bias_batch_mean_2d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <etl_expr A, bool Mean>
template <etl_2d A, bool Mean>
struct bias_batch_mean_2d_expr : base_temporary_expr_un<bias_batch_mean_2d_expr<A, Mean>, A> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = bias_batch_mean_2d_expr<A, Mean>; ///< The type of this expression
Expand Down Expand Up @@ -48,11 +48,8 @@ struct bias_batch_mean_2d_expr : base_temporary_expr_un<bias_batch_mean_2d_expr<
* \param a The input matrix
* \þaram c The output matrix
*/
template <typename C>
template <etl_1d C>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const C& c) {
static_assert(etl::dimensions<C>() == 1, "The output of bias_batch_mean_2d is a vector");
static_assert(etl::dimensions<A>() == 2, "The input of bias_batch_mean_2d is a 2d matrix");

if constexpr (all_fast<A, C>) {
static_assert(etl::dim<1, A>() == etl::dim<0, C>(), "Invalid dimensions for bias_batch_mean_2d");
} else {
Expand Down
2 changes: 1 addition & 1 deletion include/etl/expr/conv_2d_same_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct conv_2d_same_expr : base_temporary_expr_bin<conv_2d_same_expr<A, B, Flipp
/*!
* \brief Assert that the convolution is done on correct dimensions
*/
template <typename I, typename K, typename C>
template <etl_2d I, etl_2d K, etl_2d C>
static void check([[maybe_unused]] const I& input, [[maybe_unused]] const K& kernel, [[maybe_unused]] const C& conv) {
static_assert(etl::dimensions<I>() == 2, "Invalid number of dimensions for input of conv2_same");
static_assert(etl::dimensions<K>() == 2, "Invalid number of dimensions for kernel of conv2_same");
Expand Down
6 changes: 1 addition & 5 deletions include/etl/expr/conv_2d_same_multi_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,8 @@ struct conv_2d_same_multi_expr : base_temporary_expr_bin<conv_2d_same_multi_expr
/*!
* \brief Assert that the convolution is done on correct dimensions
*/
template <typename I, typename K, typename C>
template <etl_2d I, etl_3d K, etl_3d C>
static void check([[maybe_unused]] const I& input, [[maybe_unused]] const K& kernel, [[maybe_unused]] const C& conv) {
static_assert(etl::dimensions<I>() == 2, "Invalid number of dimensions for input of conv2_same_multi");
static_assert(etl::dimensions<K>() == 3, "Invalid number of dimensions for kernel of conv2_same_multi");
static_assert(etl::dimensions<C>() == 3, "Invalid number of dimensions for conv of conv2_same_multi");

if constexpr (all_fast<A, B, C>) {
static_assert(etl::dim<0, C>() == etl::dim<0, K>(), "Invalid dimensions for conv2_same_multi");
static_assert(etl::dim<1, C>() == etl::dim<0, I>(), "Invalid dimensions for conv2_same_multi");
Expand Down
6 changes: 1 addition & 5 deletions include/etl/expr/conv_2d_valid_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,8 @@ struct conv_2d_valid_expr : base_temporary_expr_bin<conv_2d_valid_expr<A, B, S1,
/*!
* \brief Assert that the convolution is done on correct dimensions
*/
template <typename I, typename K, typename C>
template <etl_2d I, etl_2d K, etl_2d C>
static void check([[maybe_unused]] const I& input, [[maybe_unused]] const K& kernel, [[maybe_unused]] const C& conv) {
static_assert(etl::dimensions<I>() == 2, "Invalid number of dimensions for input of conv2_valid");
static_assert(etl::dimensions<K>() == 2, "Invalid number of dimensions for kernel of conv2_valid");
static_assert(etl::dimensions<C>() == 2, "Invalid number of dimensions for conv of conv2_valid");

if constexpr (all_fast<A, B, C>) {
static_assert(etl::dim<0, C>() == (etl::dim<0, I>() - etl::dim<0, K>() + 2 * P1) / S1 + 1, "Invalid dimensions for conv2_valid");
static_assert(etl::dim<1, C>() == (etl::dim<1, I>() - etl::dim<1, K>() + 2 * P2) / S2 + 1, "Invalid dimensions for conv2_valid");
Expand Down
4 changes: 1 addition & 3 deletions include/etl/expr/convmtx_2d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ struct convmtx_2d_expr : base_temporary_expr_un<convmtx_2d_expr<A, K1, K2>, A> {
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <etl_expr C>
template <same_dimensions<A> C>
void assign_to(C&& c) const {
static_assert(etl::dimensions<A>() == etl::dimensions<C>(), "max_pool_2d must be applied on matrices of same dimensionality");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down
8 changes: 2 additions & 6 deletions include/etl/expr/embedding_gradients_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B, typename C>
template <etl_1d A, etl_2d B, typename C>
struct embedding_gradients_expr : base_temporary_expr_tern<embedding_gradients_expr<A, B, C>, A, B, C> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = embedding_gradients_expr<A, B, C>; ///< The type of this expression
Expand Down Expand Up @@ -43,12 +43,8 @@ struct embedding_gradients_expr : base_temporary_expr_tern<embedding_gradients_e
* \param a The input matrix
* \þaram lhs The output matrix
*/
template <typename L>
template <etl_2d L>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] const L& lhs) {
static_assert(etl::dimensions<A>() == 1, "The input of embedding_gradients is a 1d matrix");
static_assert(etl::dimensions<B>() == 2, "The vocabulary input of embedding_gradients is a 2d matrix");
static_assert(etl::dimensions<L>() == 2, "The output of embedding_gradients is 2d matrix");

if constexpr (all_fast<A, B, C, L>) {
static_assert(etl::dim<0, A>() == etl::dim<0, B>(), "Invalid dimensions for embedding_gradients");
static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for embedding_gradients");
Expand Down
42 changes: 14 additions & 28 deletions include/etl/expr/fft_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ struct fft_expr : base_temporary_expr_un<fft_expr<A, T, Impl>, A> {
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <etl_expr C>
template <same_dimensions<A> C>
void assign_to(C&& c) const {
static_assert(etl::dimensions<A>() == etl::dimensions<C>(), "max_pool_2d must be applied on matrices of same dimensionality");

inc_counter("temp:assign");

Impl::apply(this->a(), c);
Expand All @@ -57,7 +55,7 @@ struct fft_expr : base_temporary_expr_un<fft_expr<A, T, Impl>, A> {
* \brief Add to the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <same_dimensions<A> L>
void assign_add_to(L&& lhs) const {
std_add_evaluate(*this, lhs);
}
Expand All @@ -66,7 +64,7 @@ struct fft_expr : base_temporary_expr_un<fft_expr<A, T, Impl>, A> {
* \brief Sub from the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <same_dimensions<A> L>
void assign_sub_to(L&& lhs) const {
std_sub_evaluate(*this, lhs);
}
Expand All @@ -75,7 +73,7 @@ struct fft_expr : base_temporary_expr_un<fft_expr<A, T, Impl>, A> {
* \brief Multiply the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <same_dimensions<A> L>
void assign_mul_to(L&& lhs) const {
std_mul_evaluate(*this, lhs);
}
Expand All @@ -84,7 +82,7 @@ struct fft_expr : base_temporary_expr_un<fft_expr<A, T, Impl>, A> {
* \brief Divide the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <same_dimensions<A> L>
void assign_div_to(L&& lhs) const {
std_div_evaluate(*this, lhs);
}
Expand All @@ -93,7 +91,7 @@ struct fft_expr : base_temporary_expr_un<fft_expr<A, T, Impl>, A> {
* \brief Modulo the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <same_dimensions<A> L>
void assign_mod_to(L&& lhs) const {
std_mod_evaluate(*this, lhs);
}
Expand Down Expand Up @@ -373,10 +371,8 @@ auto ifft_2d_real(A&& a, C&& c) {
* \param a The input expression
* \return an expression representing several 1D FFT of a
*/
template <etl_expr A>
template <matrix A>
fft_expr<detail::build_type<A>, detail::fft_value_type<A>, detail::fft1_many_impl> fft_1d_many(A&& a) {
static_assert(decay_traits<A>::dimensions() >= 2, "fft_many requires at least 2D matrices");

return fft_expr<detail::build_type<A>, detail::fft_value_type<A>, detail::fft1_many_impl>{a};
}

Expand All @@ -389,9 +385,8 @@ fft_expr<detail::build_type<A>, detail::fft_value_type<A>, detail::fft1_many_imp
* \param c The result
* \return an expression representing several 1D FFT of a
*/
template <etl_expr A, etl_expr C>
template <matrix A, matrix C>
auto fft_1d_many(A&& a, C&& c) {
static_assert(decay_traits<A>::dimensions() >= 2 && decay_traits<C>::dimensions() >= 2, "fft_many requires at least 2D matrices");
validate_assign(c, a);

c = fft_1d_many(a);
Expand All @@ -406,10 +401,8 @@ auto fft_1d_many(A&& a, C&& c) {
* \param a The input expression
* \return an expression representing several 1D FFT of a
*/
template <etl_expr A>
template <matrix A>
fft_expr<detail::build_type<A>, detail::ifft_value_type<A>, detail::ifft1_many_impl> ifft_1d_many(A&& a) {
static_assert(decay_traits<A>::dimensions() >= 2, "ifft_many requires at least 2D matrices");

return fft_expr<detail::build_type<A>, detail::ifft_value_type<A>, detail::ifft1_many_impl>{a};
}

Expand All @@ -422,9 +415,8 @@ fft_expr<detail::build_type<A>, detail::ifft_value_type<A>, detail::ifft1_many_i
* \param c The result
* \return an expression representing several 1D FFT of a
*/
template <etl_expr A, etl_expr C>
template <matrix A, matrix C>
auto ifft_1d_many(A&& a, C&& c) {
static_assert(decay_traits<A>::dimensions() >= 2 && decay_traits<C>::dimensions() >= 2, "ifft_many requires at least 2D matrices");
validate_assign(c, a);

c = ifft_1d_many(a);
Expand All @@ -439,10 +431,8 @@ auto ifft_1d_many(A&& a, C&& c) {
* \param a The input expression
* \return an expression representing several 2D FFT of a
*/
template <etl_expr A>
template <deep_mat A>
fft_expr<detail::build_type<A>, detail::fft_value_type<A>, detail::fft2_many_impl> fft_2d_many(A&& a) {
static_assert(decay_traits<A>::dimensions() >= 3, "fft_many requires at least 3D matrices");

return fft_expr<detail::build_type<A>, detail::fft_value_type<A>, detail::fft2_many_impl>{a};
}

Expand All @@ -455,9 +445,8 @@ fft_expr<detail::build_type<A>, detail::fft_value_type<A>, detail::fft2_many_imp
* \param c The result
* \return an expression representing several 2D FFT of a
*/
template <etl_expr A, etl_expr C>
template <deep_mat A, deep_mat C>
auto fft_2d_many(A&& a, C&& c) {
static_assert(decay_traits<A>::dimensions() >= 3 && decay_traits<C>::dimensions() >= 3, "fft_many requires at least 3D matrices");
validate_assign(c, a);

c = fft_2d_many(a);
Expand All @@ -472,10 +461,8 @@ auto fft_2d_many(A&& a, C&& c) {
* \param a The input expression
* \return an expression representing several 2D FFT of a
*/
template <etl_expr A>
template <deep_mat A>
fft_expr<detail::build_type<A>, detail::ifft_value_type<A>, detail::ifft2_many_impl> ifft_2d_many(A&& a) {
static_assert(decay_traits<A>::dimensions() >= 3, "ifft_many requires at least 3D matrices");

return fft_expr<detail::build_type<A>, detail::ifft_value_type<A>, detail::ifft2_many_impl>{a};
}

Expand All @@ -488,9 +475,8 @@ fft_expr<detail::build_type<A>, detail::ifft_value_type<A>, detail::ifft2_many_i
* \param c The result
* \return an expression representing several 2D FFT of a
*/
template <etl_expr A, etl_expr C>
template <deep_mat A, deep_mat C>
auto ifft_2d_many(A&& a, C&& c) {
static_assert(decay_traits<A>::dimensions() >= 3 && decay_traits<C>::dimensions() >= 3, "ifft_many requires at least 3D matrices");
validate_assign(c, a);

c = ifft_2d_many(a);
Expand Down
2 changes: 1 addition & 1 deletion include/etl/expr/pool_upsample_2d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace etl {
* \tparam B The output type
* \tparam C The errors type
*/
template <etl_expr A, etl_expr B, etl_expr C, size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, bool Max>
template <etl_expr A, same_dimensions<A> B, same_dimensions<A> C, size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, bool Max>
struct pool_upsample_2d_expr : base_temporary_expr_tern<pool_upsample_2d_expr<A, B, C, C1, C2, S1, S2, P1, P2, Max>, A, B, C> {
using value_type = value_t<A>; ///< The type of value of the expression
using sub_traits = etl::decay_traits<A>; ///< The traits of the first sub type
Expand Down
8 changes: 2 additions & 6 deletions include/etl/expr/pool_upsample_3d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace etl {
* \tparam B The output type
* \tparam C The errors type
*/
template <etl_expr A, etl_expr B, etl_expr C, size_t C1, size_t C2, size_t C3, bool Max>
template <etl_expr A, same_dimensions<A> B, same_dimensions<A> C, size_t C1, size_t C2, size_t C3, bool Max>
struct pool_upsample_3d_expr : base_temporary_expr_tern<pool_upsample_3d_expr<A, B, C, C1, C2, C3, Max>, A, B, C> {
using value_type = value_t<A>; ///< The type of value of the expression
using sub_traits = etl::decay_traits<A>; ///< The traits of the first sub type
Expand Down Expand Up @@ -52,14 +52,10 @@ struct pool_upsample_3d_expr : base_temporary_expr_tern<pool_upsample_3d_expr<A,
* \param a The input matrix
* \þaram c The output matrix
*/
template <etl_expr R>
template <same_dimensions<A> R>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] const R& result) {
static constexpr size_t D = etl::decay_traits<A>::dimensions();

static_assert(etl::decay_traits<B>::dimensions() == D, "Invalid dimensions in max_pool_upsampl_3d");
static_assert(etl::decay_traits<C>::dimensions() == D, "Invalid dimensions in max_pool_upsampl_3d");
static_assert(etl::decay_traits<R>::dimensions() == D, "Invalid dimensions in max_pool_upsampl_3d");

if constexpr (all_fast<A, B, C, R>) {
static_assert(etl::decay_traits<R>::size() == etl::decay_traits<A>::size(), "max_pool_upsample_3d:A and R must have the same size");
static_assert(etl::decay_traits<B>::size() == etl::decay_traits<C>::size(), "max_pool_upsample_3d:B and C must have the same size");
Expand Down

0 comments on commit 5d8cc2e

Please sign in to comment.