Skip to content

Commit

Permalink
Better usage of concepts
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Dec 5, 2023
1 parent 577eec7 commit ed873da
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 106 deletions.
5 changes: 1 addition & 4 deletions include/etl/builder/mul_expression_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ namespace etl {
* \param b The right hand side matrix
* \return An expression representing the matrix-matrix multiplication of a and b
*/
template <typename A, typename B>
template <etl_2d A, etl_2d B>
auto lazy_mul(A&& a, B&& b) -> detail::stable_transform_binary_helper<A, B, mm_mul_transformer> {
static_assert(all_etl_expr<A, B>, "Matrix multiplication only supported for ETL expressions");
static_assert(all_2d<A, B>, "Matrix multiplication only works in 2D");

return detail::stable_transform_binary_helper<A, B, mm_mul_transformer>{mm_mul_transformer<detail::build_type<A>, detail::build_type<B>>(a, b)};
}

Expand Down
6 changes: 2 additions & 4 deletions include/etl/expr/batch_outer_product_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
struct batch_outer_product_expr : base_temporary_expr_bin<batch_outer_product_expr<A, B>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = batch_outer_product_expr<A, B>; ///< The type of this expression
Expand Down Expand Up @@ -139,10 +139,8 @@ struct batch_outer_product_expr : base_temporary_expr_bin<batch_outer_product_ex
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& c) const {
static_assert(all_etl_expr<A, B, C>, "batch_outer_product only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down
12 changes: 3 additions & 9 deletions include/etl/expr/bias_add_2d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
struct bias_add_2d_expr : base_temporary_expr_bin<bias_add_2d_expr<A, B>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = bias_add_2d_expr<A, B>; ///< The type of this expression
Expand Down Expand Up @@ -74,10 +74,8 @@ struct bias_add_2d_expr : base_temporary_expr_bin<bias_add_2d_expr<A, B>, A, B>
* \brief Assign to a matrix of the same storage order
* \param lhs The expression to which assign
*/
template <typename L>
template <etl_expr L>
void assign_to(L&& lhs) const {
static_assert(all_etl_expr<A, L>, "bias_add_2d only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down Expand Up @@ -380,12 +378,8 @@ struct etl_traits<etl::bias_add_2d_expr<A, B>> {
* \param biases The vector of biases
* \return The transpose of the given expression.
*/
template <typename E, typename B>
template <etl_2d E, etl_1d B>
bias_add_2d_expr<detail::build_type<E>, detail::build_type<B>> bias_add_2d(const E& x, const B& biases) {
static_assert(all_etl_expr<E, B>, "etl::bias_add_2d can only be used on ETL expressions");
static_assert(is_2d<E>, "etl::bias_add_2d is only defined for 2D input");
static_assert(is_1d<B>, "etl::bias_add_2d is only defined for 1D bias vector");

return bias_add_2d_expr<detail::build_type<E>, detail::build_type<B>>{x, biases};
}

Expand Down
22 changes: 6 additions & 16 deletions include/etl/expr/conv_2d_same_deep_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B, bool Flipped>
template <etl_expr A, etl_expr B, bool Flipped>
struct conv_2d_same_deep_expr : base_temporary_expr_bin<conv_2d_same_deep_expr<A, B, Flipped>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = conv_2d_same_deep_expr<A, B, Flipped>; ///< The type of this expression
Expand Down Expand Up @@ -66,10 +66,8 @@ struct conv_2d_same_deep_expr : base_temporary_expr_bin<conv_2d_same_deep_expr<A
* \brief Assign to a matrix of the full storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& c) const {
static_assert(all_etl_expr<A, B, C>, "conv2_same_deep only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down Expand Up @@ -241,10 +239,8 @@ struct etl_traits<etl::conv_2d_same_deep_expr<A, B, Flipped>> {
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
conv_2d_same_deep_expr<detail::build_type<A>, detail::build_type<B>, false> conv_2d_same_deep(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_same_deep_expr<detail::build_type<A>, detail::build_type<B>, false>{a, b};
}

Expand All @@ -257,10 +253,8 @@ conv_2d_same_deep_expr<detail::build_type<A>, detail::build_type<B>, false> conv
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_same_deep(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_same_deep(a, b);

return c;
Expand All @@ -274,10 +268,8 @@ auto conv_2d_same_deep(A&& a, B&& b, C&& c) {
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
conv_2d_same_deep_expr<detail::build_type<A>, detail::build_type<B>, true> conv_2d_same_deep_flipped(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_same_deep_expr<detail::build_type<A>, detail::build_type<B>, true>{a, b};
}

Expand All @@ -290,10 +282,8 @@ conv_2d_same_deep_expr<detail::build_type<A>, detail::build_type<B>, true> conv_
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_same_deep_flipped(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_same_deep_flipped(a, b);

return c;
Expand Down
22 changes: 6 additions & 16 deletions include/etl/expr/conv_2d_same_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B, bool Flipped>
template <etl_expr A, etl_expr B, bool Flipped>
struct conv_2d_same_expr : base_temporary_expr_bin<conv_2d_same_expr<A, B, Flipped>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = conv_2d_same_expr<A, B, Flipped>; ///< The type of this expression
Expand Down Expand Up @@ -65,10 +65,8 @@ struct conv_2d_same_expr : base_temporary_expr_bin<conv_2d_same_expr<A, B, Flipp
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& c) const {
static_assert(all_etl_expr<A, B, C>, "conv2_same only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down Expand Up @@ -240,10 +238,8 @@ struct etl_traits<etl::conv_2d_same_expr<A, B, Flipped>> {
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
conv_2d_same_expr<detail::build_type<A>, detail::build_type<B>, false> conv_2d_same(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_same_expr<detail::build_type<A>, detail::build_type<B>, false>{a, b};
}

Expand All @@ -259,10 +255,8 @@ conv_2d_same_expr<detail::build_type<A>, detail::build_type<B>, false> conv_2d_s
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_same(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_same(a, b);

return c;
Expand All @@ -279,10 +273,8 @@ auto conv_2d_same(A&& a, B&& b, C&& c) {
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
conv_2d_same_expr<detail::build_type<A>, detail::build_type<B>, true> conv_2d_same_flipped(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_same_expr<detail::build_type<A>, detail::build_type<B>, true>{a, b};
}

Expand All @@ -298,10 +290,8 @@ conv_2d_same_expr<detail::build_type<A>, detail::build_type<B>, true> conv_2d_sa
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_same_flipped(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_same_flipped(a, b);

return c;
Expand Down
22 changes: 6 additions & 16 deletions include/etl/expr/conv_2d_same_multi_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B, bool Flipped>
template <etl_expr A, etl_expr B, bool Flipped>
struct conv_2d_same_multi_expr : base_temporary_expr_bin<conv_2d_same_multi_expr<A, B, Flipped>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = conv_2d_same_multi_expr<A, B, Flipped>; ///< The type of this expression
Expand Down Expand Up @@ -141,10 +141,8 @@ struct conv_2d_same_multi_expr : base_temporary_expr_bin<conv_2d_same_multi_expr
* \brief Assign to a matrix of the same storage order
* \param conv The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& conv) const {
static_assert(all_etl_expr<A, B, C>, "conv2_same_multi only supported for ETL expressions");

inc_counter("temp:assign");

auto& input = this->a();
Expand Down Expand Up @@ -348,10 +346,8 @@ struct etl_traits<etl::conv_2d_same_multi_expr<A, B, Flipped>> {
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
conv_2d_same_multi_expr<detail::build_type<A>, detail::build_type<B>, false> conv_2d_same_multi(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_same_multi_expr<detail::build_type<A>, detail::build_type<B>, false>{a, b};
}

Expand All @@ -367,10 +363,8 @@ conv_2d_same_multi_expr<detail::build_type<A>, detail::build_type<B>, false> con
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_same_multi(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_same_multi(a, b);

return c;
Expand All @@ -387,10 +381,8 @@ auto conv_2d_same_multi(A&& a, B&& b, C&& c) {
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B>
template <etl_expr A, etl_expr B>
conv_2d_same_multi_expr<detail::build_type<A>, detail::build_type<B>, true> conv_2d_same_multi_flipped(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_same_multi_expr<detail::build_type<A>, detail::build_type<B>, true>{a, b};
}

Expand All @@ -406,10 +398,8 @@ conv_2d_same_multi_expr<detail::build_type<A>, detail::build_type<B>, true> conv
*
* \return an expression representing the 'same' 1D convolution of a and b
*/
template <typename A, typename B, typename C>
template <etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_same_multi_flipped(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_same_multi_flipped(a, b);

return c;
Expand Down
22 changes: 6 additions & 16 deletions include/etl/expr/conv_2d_valid_deep_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, typename B, size_t S1, size_t S2, size_t P1, size_t P2, bool Flipped>
template <etl_expr A, etl_expr B, size_t S1, size_t S2, size_t P1, size_t P2, bool Flipped>
struct conv_2d_valid_deep_expr : base_temporary_expr_bin<conv_2d_valid_deep_expr<A, B, S1, S2, P1, P2, Flipped>, A, B> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = conv_2d_valid_deep_expr<A, B, S1, S2, P1, P2, Flipped>; ///< The type of this expression
Expand Down Expand Up @@ -68,10 +68,8 @@ struct conv_2d_valid_deep_expr : base_temporary_expr_bin<conv_2d_valid_deep_expr
* \brief Assign to a matrix of the valid storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& c) const {
static_assert(all_etl_expr<A, B, C>, "conv2_valid_deep only supported for ETL expressions");

inc_counter("temp:assign");

auto& a = this->a();
Expand Down Expand Up @@ -266,10 +264,8 @@ struct etl_traits<etl::conv_2d_valid_deep_expr<A, B, S1, S2, P1, P2, Flipped>> {
*
* \return an expression representing the 'valid' 1D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B>
conv_2d_valid_deep_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, false> conv_2d_valid_deep(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_valid_deep_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, false>{a, b};
}

Expand All @@ -285,10 +281,8 @@ conv_2d_valid_deep_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1
*
* \return an expression representing the 'valid' 1D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B, typename C>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_valid_deep(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_valid_deep<S1, S2, P1, P2>(a, b);

return c;
Expand All @@ -305,10 +299,8 @@ auto conv_2d_valid_deep(A&& a, B&& b, C&& c) {
*
* \return an expression representing the 'valid' 1D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B>
conv_2d_valid_deep_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, true> conv_2d_valid_deep_flipped(A&& a, B&& b) {
static_assert(all_etl_expr<A, B>, "Convolution only supported for ETL expressions");

return conv_2d_valid_deep_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1, P2, true>{a, b};
}

Expand All @@ -324,10 +316,8 @@ conv_2d_valid_deep_expr<detail::build_type<A>, detail::build_type<B>, S1, S2, P1
*
* \return an expression representing the 'valid' 1D convolution of a and b
*/
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, typename A, typename B, typename C>
template <size_t S1 = 1, size_t S2 = 1, size_t P1 = 0, size_t P2 = 0, etl_expr A, etl_expr B, etl_expr C>
auto conv_2d_valid_deep_flipped(A&& a, B&& b, C&& c) {
static_assert(all_etl_expr<A, B, C>, "Convolution only supported for ETL expressions");

c = conv_2d_valid_deep_flipped<S1, S2, P1, P2>(a, b);

return c;
Expand Down
5 changes: 2 additions & 3 deletions include/etl/expr/convmtx_2d_expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace etl {
* \brief A transposition expression.
* \tparam A The transposed type
*/
template <typename A, size_t K1, size_t K2>
template <etl_expr A, size_t K1, size_t K2>
struct convmtx_2d_expr : base_temporary_expr_un<convmtx_2d_expr<A, K1, K2>, A> {
using value_type = value_t<A>; ///< The type of value of the expression
using this_type = convmtx_2d_expr<A, K1, K2>; ///< The type of this expression
Expand Down Expand Up @@ -47,9 +47,8 @@ struct convmtx_2d_expr : base_temporary_expr_un<convmtx_2d_expr<A, K1, K2>, A> {
* \brief Assign to a matrix of the same storage order
* \param c The expression to which assign
*/
template <typename C>
template <etl_expr C>
void assign_to(C&& c) const {
static_assert(all_etl_expr<A, C>, "max_pool_2d only supported for ETL expressions");
static_assert(etl::dimensions<A>() == etl::dimensions<C>(), "max_pool_2d must be applied on matrices of same dimensionality");

inc_counter("temp:assign");
Expand Down
Loading

0 comments on commit ed873da

Please sign in to comment.