Skip to content

Commit

Permalink
Better use of concepts
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Dec 5, 2023
1 parent 12ffde0 commit ec7e84b
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 72 deletions.
31 changes: 9 additions & 22 deletions include/etl/expr/bias_batch_var_4d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A, B>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = bias_batch_var_4d_expr<A, B>; ///< The type of this expression
Expand Down Expand Up @@ -47,7 +47,7 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \param a The input matrix
* \þaram c The output matrix
*/
template <typename C>
template <etl_expr C>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c) {
static_assert(etl::dimensions<C>() == 1, "The output of bias_batch_var_4d is a vector");
static_assert(etl::dimensions<A>() == 4, "The input of bias_batch_var_4d is a 2d matrix");
Expand All @@ -68,10 +68,8 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \brief Assign to a matrix of the same storage order
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_batch_var_4d only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down Expand Up @@ -142,10 +140,8 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \brief Add to the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_add_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_batch_var_4d only supported for ETL expressions");

auto& a = this->a();
auto& b = this->b();

Expand Down Expand Up @@ -188,10 +184,8 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \brief Sub from the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_sub_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_batch_var_4d only supported for ETL expressions");

auto& a = this->a();
auto& b = this->b();

Expand Down Expand Up @@ -234,7 +228,7 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \brief Multiply the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_mul_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_batch_var_4d only supported for ETL expressions");

Expand Down Expand Up @@ -280,7 +274,7 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \brief Divide the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_div_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_batch_var_4d only supported for ETL expressions");

Expand Down Expand Up @@ -326,10 +320,8 @@ struct bias_batch_var_4d_expr : base_temporary_expr_bin<bias_batch_var_4d_expr<A
* \brief Modulo the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_mod_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_batch_var_4d only supported for ETL expressions");

auto& a = this->a();
auto& b = this->b();

Expand Down Expand Up @@ -474,13 +466,8 @@ struct etl_traits<etl::bias_batch_var_4d_expr<A, B>> {
* \param value The expression
* \return The transpose of the given expression.
*/
template <typename A, typename B>
template <etl_4d A, etl_1d B>
bias_batch_var_4d_expr<detail::build_type<A>, detail::build_type<B>> bias_batch_var_4d(const A& a, const B& b) {
static_assert(is_etl_expr<A>, "etl::bias_batch_var_4d can only be used on ETL expressions");
static_assert(is_etl_expr<B>, "etl::bias_batch_var_4d can only be used on ETL expressions");
static_assert(is_4d<A>, "etl::bias_batch_var_4d is only defined for 4d input");
static_assert(is_1d<B>, "etl::bias_batch_var_4d is only defined for 1d mean");

return bias_batch_var_4d_expr<detail::build_type<A>, detail::build_type<B>>{a, b};
}

Expand Down
22 changes: 6 additions & 16 deletions include/etl/expr/conv_4d_backward_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace etl {
* \tparam P2 The padding of the second dimension
* \tparam Flipped Indicates if Flipped already or not or not
*/
template <typename A, typename B, size_t S1, size_t S2, size_t P1, size_t P2, bool Flipped>
template <etl_expr A, etl_expr B, size_t S1, size_t S2, size_t P1, size_t P2, bool Flipped>
struct conv_4d_backward_expr : base_temporary_expr_bin<conv_4d_backward_expr<A, B, S1, S2, P1, P2, Flipped>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = conv_4d_backward_expr<A, B, S1, S2, P1, P2, Flipped>; ///< The type of this expression
Expand Down Expand Up @@ -95,10 +95,8 @@ struct conv_4d_backward_expr : base_temporary_expr_bin<conv_4d_backward_expr<A,
* \brief Assign to a matrix
* \param conv The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& conv) const {
static_assert(all_etl_expr<A, B, C>, "conv4_backward only supported for ETL expressions");

inc_counter("temp:assign");

auto& input = this->a();
Expand Down Expand Up @@ -345,10 +343,8 @@ struct etl_traits<etl::conv_4d_backward_expr<A, B, S1, S2, P1, P2, Flipped>> {
*
* \return an expression representing the transposed convolution convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B>
conv_4d_backward_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, false> conv_4d_backward(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_4d_backward_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, false>{a, b};
}

Expand All @@ -365,10 +361,8 @@ conv_4d_backward_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1,
*
* \return an expression representing the transposed 2D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B, typename C>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B, etl_expr C>
auto conv_4d_backward(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_4d_backward<S1, S2, P1, P2>(a, b);

return c;
Expand All @@ -385,10 +379,8 @@ auto conv_4d_backward(A&& a, B&& b, C&& c) {
*
* \return an expression representing the transposed 2D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B>
conv_4d_backward_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, true> conv_4d_backward_flipped(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_4d_backward_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, true>{a, b};
}

Expand All @@ -405,10 +397,8 @@ conv_4d_backward_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1,
*
* \return an expression representing the transposed 2D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B, typename C>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B, etl_expr C>
auto conv_4d_backward_flipped(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_4d_backward_flipped<S1, S2, P1, P2>(a, b);

return c;
Expand Down
26 changes: 12 additions & 14 deletions include/etl/expr/outer_product_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = outer_product_expr<A, B>; ///< The type of this expression
Expand Down Expand Up @@ -54,7 +54,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \tparam C The type of c expression
* \return The implementation to use
*/
template <typename C>
template <etl_expr C>
static constexpr etl::outer_impl select_default_outer_impl() {
if (cblas_enabled) {
return etl::outer_impl::BLAS;
Expand All @@ -70,7 +70,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \tparam C The type of c expression
* \return The implementation to use
*/
template <typename C>
template <etl_expr C>
static etl::outer_impl select_outer_impl() {
if (local_context().outer_selector.forced) {
auto forced = local_context().outer_selector.impl;
Expand Down Expand Up @@ -102,7 +102,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \tparam C The type of c expression
* \return The implementation to use
*/
template <typename C>
template <etl_expr C>
static constexpr etl::outer_impl select_outer_impl() {
return select_default_outer_impl<C>();
}
Expand All @@ -113,10 +113,8 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& c) const {
static_assert(all_etl_expr<A, B, C>, "batch_outer_product only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand All @@ -139,7 +137,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \brief Add to the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_add_to(L&& lhs) const {
std_add_evaluate(*this, lhs);
}
Expand All @@ -148,7 +146,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \brief Sub from the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_sub_to(L&& lhs) const {
std_sub_evaluate(*this, lhs);
}
Expand All @@ -157,7 +155,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \brief Multiply the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_mul_to(L&& lhs) const {
std_mul_evaluate(*this, lhs);
}
Expand All @@ -166,7 +164,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \brief Divide the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_div_to(L&& lhs) const {
std_div_evaluate(*this, lhs);
}
Expand All @@ -175,7 +173,7 @@ struct outer_product_expr : base_temporary_expr_bin<outer_product_expr<A, B>, A,
* \brief Modulo the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_mod_to(L&& lhs) const {
std_mod_evaluate(*this, lhs);
}
Expand Down Expand Up @@ -291,7 +289,7 @@ struct etl_traits<etl::outer_product_expr<A, B>> {
* \param b The right hand side matrix
* \return An expression representing the matrix-matrix multiplication of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
outer_product_expr<detail::build_type<A>, detail::build_type<B>> outer(A&& a, B&& b) {
return outer_product_expr<detail::build_type<A>, detail::build_type<B>>{a, b};
}
Expand All @@ -303,7 +301,7 @@ outer_product_expr<detail::build_type<A>, detail::build_type<B>> outer(A&& a, B&
* \param c The expression used to store the result
* \return An expression representing the matrix-matrix multiplication of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto outer(A&& a, B&& b, C&& c) {
c = outer(a, b);
return c;
Expand Down
18 changes: 8 additions & 10 deletions include/etl/expr/transpose_front_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression for the first layers.
* \tparam A The transposed type
*/
template <typename A>
template <etl_expr A>
struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = transpose_front_expr<A>; ///< The type of this expression
Expand Down Expand Up @@ -46,7 +46,7 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \param a The input matrix
* \þaram c The output matrix
*/
template <typename C>
template <etl_expr C>
static void check([[maybe_unused]] const A& a, [[maybe_unused]] const C& c) {
if constexpr (all_fast<A, C>) {
static_assert(etl::dim<0, A>() == etl::dim<1, C>(), "Invalid dimensions for front transposition");
Expand All @@ -65,10 +65,8 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& lhs) const {
static_assert(all_etl_expr<A, C>, "Front Transpose only supported for ETL expressions");

auto& a = this->a();

check(a, lhs);
Expand Down Expand Up @@ -110,7 +108,7 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \brief Add to the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_add_to(L&& lhs) const {
std_add_evaluate(*this, lhs);
}
Expand All @@ -119,7 +117,7 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \brief Sub from the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_sub_to(L&& lhs) const {
std_sub_evaluate(*this, lhs);
}
Expand All @@ -128,7 +126,7 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \brief Multiply the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_mul_to(L&& lhs) const {
std_mul_evaluate(*this, lhs);
}
Expand All @@ -137,7 +135,7 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \brief Divide the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_div_to(L&& lhs) const {
std_div_evaluate(*this, lhs);
}
Expand All @@ -146,7 +144,7 @@ struct transpose_front_expr : base_temporary_expr_un<transpose_front_expr<A>, A>
* \brief Modulo the given left-hand-side expression
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_mod_to(L&& lhs) const {
std_mod_evaluate(*this, lhs);
}
Expand Down
Loading

0 comments on commit ec7e84b

Please sign in to comment.