Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Oct 30, 2023
1 parent edd87fd commit dd93f1a
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions include/etl/impl/std/avg_pooling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include "etl/concepts.hpp"
namespace etl::impl::standard {

/*!
Expand Down Expand Up @@ -122,7 +123,7 @@ struct avg_pool_2d {
* \tparam C1 The first dimension pooling ratio
* \tparam C2 The second dimension pooling ratio
*/
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename M, cpp_enable_iff(is_2d<A>)>
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_2d A, typename M>
static void apply(const A& sub, M&& m) {
const size_t o1 = (etl::dim<0>(sub) - C1 + 2 * P1) / S1 + 1;
const size_t o2 = (etl::dim<1>(sub) - C2 + 2 * P2) / S2 + 1;
Expand Down Expand Up @@ -239,7 +240,7 @@ struct avg_pool_2d {
* \param c1 The first dimension pooling ratio
* \param c2 The second dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(is_2d<A>)>
template <etl_2d A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
const size_t o1 = (etl::dim<0>(sub) - c1 + 2 * p1) / s1 + 1;
const size_t o2 = (etl::dim<1>(sub) - c2 + 2 * p2) / s2 + 1;
Expand Down Expand Up @@ -292,7 +293,7 @@ struct avg_pool_2d {
* \tparam C1 The first dimension pooling ratio
* \tparam C2 The second dimension pooling ratio
*/
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename M, cpp_enable_iff(is_3d<A>)>
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_3d A, typename M>
static void apply(const A& sub, M&& m) {
auto batch_fun_n = [&](const size_t first, const size_t last) {
if (last - first) {
Expand Down Expand Up @@ -325,7 +326,7 @@ struct avg_pool_2d {
* \param c1 The first dimension pooling ratio
* \param c2 The second dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(is_3d<A>)>
template <etl_3d A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
auto batch_fun_n = [&](const size_t first, const size_t last) {
if (last - first) {
Expand Down Expand Up @@ -365,7 +366,7 @@ struct avg_pool_2d {
* \tparam C1 The first dimension pooling ratio
* \tparam C2 The second dimension pooling ratio
*/
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename M, cpp_enable_iff(is_4d<A>)>
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_4d A, typename M>
static void apply(const A& sub, M&& m) {
auto batch_fun_n = [&](const size_t first, const size_t last) {
if (last - first) {
Expand Down Expand Up @@ -401,7 +402,7 @@ struct avg_pool_2d {
* \param c1 The first dimension pooling ratio
* \param c2 The second dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(is_4d<A>)>
template <etl_4d A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
auto batch_fun_n = [&](const size_t first, const size_t last) {
if (last - first) {
Expand Down Expand Up @@ -441,7 +442,7 @@ struct avg_pool_2d {
* \tparam S1 The first dimension stride
* \tparam S2 The second dimension stride
*/
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, typename A, typename M, cpp_enable_iff(!is_2d<A> && !is_3d<A> && !is_4d<A>)>
template <size_t C1, size_t C2, size_t S1, size_t S2, size_t P1, size_t P2, etl_5d_and_plus A, typename M>
static void apply(const A& sub, M&& m) {
for (size_t i = 0; i < etl::dim<0>(sub); ++i) {
apply<C1, C2, S1, S2, P1, P2>(sub(i), m(i));
Expand All @@ -455,7 +456,7 @@ struct avg_pool_2d {
* \param c1 The first dimension pooling ratio
* \param c2 The second dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(!is_2d<A> && !is_3d<A> && !is_4d<A>)>
template <etl_5d_and_plus A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t s1, size_t s2, size_t p1, size_t p2) {
for (size_t i = 0; i < etl::dim<0>(sub); ++i) {
apply(sub(i), m(i), c1, c2, s1, s2, p1, p2);
Expand Down Expand Up @@ -581,9 +582,8 @@ struct avg_pool_3d {
size_t P1,
size_t P2,
size_t P3,
typename A,
typename M,
cpp_enable_iff(is_3d<A>)>
etl_3d A,
typename M>
static void apply(const A& sub, M&& m) {
const size_t o1 = (etl::dim<0>(sub) - C1 + 2 * P1) / S1 + 1;
const size_t o2 = (etl::dim<1>(sub) - C2 + 2 * P2) / S2 + 1;
Expand Down Expand Up @@ -715,7 +715,7 @@ struct avg_pool_3d {
* \param c2 The second dimension pooling ratio
* \param c3 The third dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(is_3d<A>)>
template <etl_3d A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3) {
const size_t o1 = (etl::dim<0>(sub) - c1 + 2 * p1) / s1 + 1;
const size_t o2 = (etl::dim<1>(sub) - c2 + 2 * p2) / s2 + 1;
Expand Down Expand Up @@ -805,9 +805,8 @@ struct avg_pool_3d {
size_t P1,
size_t P2,
size_t P3,
typename A,
typename M,
cpp_enable_iff(is_4d<A>)>
etl_4d A,
typename M>
static void apply(const A& sub, M&& m) {
auto batch_fun_n = [&](const size_t first, const size_t last) {
if (last - first) {
Expand Down Expand Up @@ -843,7 +842,7 @@ struct avg_pool_3d {
* \param c2 The second dimension pooling ratio
* \param c3 The third dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(is_4d<A>)>
template <etl_4d A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3) {
auto batch_fun_n = [&](const size_t first, const size_t last) {
if (last - first) {
Expand Down Expand Up @@ -889,9 +888,8 @@ struct avg_pool_3d {
size_t P1,
size_t P2,
size_t P3,
typename A,
typename M,
cpp_enable_iff(!is_3d<A> && !is_4d<A>)>
etl_5d_and_plus A,
typename M>
static void apply(const A& sub, M&& m) {
for (size_t i = 0; i < etl::dim<0>(sub); ++i) {
apply<C1, C2, C3, S1, S2, S3, P1, P2, P3>(sub(i), m(i));
Expand All @@ -906,7 +904,7 @@ struct avg_pool_3d {
* \param c2 The second dimension pooling ratio
* \param c3 The third dimension pooling ratio
*/
template <typename A, typename M, cpp_enable_iff(!is_3d<A> && !is_4d<A>)>
template <etl_5d_and_plus A, typename M>
static void apply(const A& sub, M&& m, size_t c1, size_t c2, size_t c3, size_t s1, size_t s2, size_t s3, size_t p1, size_t p2, size_t p3) {
for (size_t i = 0; i < etl::dim<0>(sub); ++i) {
apply(sub(i), m(i), c1, c2, c3, s1, s2, s3, p1, p2, p3);
Expand Down

0 comments on commit dd93f1a

Please sign in to comment.