Skip to content

Commit

Permalink
Implement axis_function_mask (#103)
Browse files Browse the repository at this point in the history
Implement axis_function_mask
  • Loading branch information
martinRenou authored and JohanMabille committed Jun 18, 2018
1 parent 93d78fd commit 28dded7
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 84 deletions.
137 changes: 63 additions & 74 deletions include/xframe/xaxis_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define XFRAME_XAXIS_FUNCTION_HPP

#include "xtensor/xoptional.hpp"
#include "xtensor/xgenerator.hpp"

#include "xframe_expression.hpp"
#include "xframe_utils.hpp"
Expand Down Expand Up @@ -59,47 +60,6 @@ namespace xf
functor_type m_f;
};

/**************************
* xaxis_function_wrapper *
**************************/

template <class AF, class DM>
class xaxis_function_wrapper
{
public:

using self_type = xaxis_function_wrapper<AF, DM>;

using axis_function_type = std::remove_reference_t<AF>;

using value_type = typename axis_function_type::value_type;
using reference = typename axis_function_type::reference;
using const_reference = typename axis_function_type::const_reference;
using pointer = typename axis_function_type::pointer;
using const_pointer = typename axis_function_type::const_pointer;
using name_type = typename axis_function_type::name_type;
using size_type = typename axis_function_type::size_type;

template <std::size_t N = dynamic()>
using selector_sequence_type = detail::xselector_sequence_t<std::pair<name_type, size_type>, N>;

xaxis_function_wrapper(AF&& axis_function, DM&& dim_mapping) noexcept;

template <std::size_t N = dynamic()>
const_reference operator()(const selector_sequence_type<N>& selector) const;

template <class... Args>
const_reference operator()(Args... args) const;

private:

template <class... Args, std::size_t... I>
selector_sequence_type<sizeof...(Args)> make_selector(std::index_sequence<I...>, Args&&... args) const;

AF m_axis_function;
DM m_dimension_mapping;
};

/*********************************
* xaxis_function implementation *
*********************************/
Expand Down Expand Up @@ -130,51 +90,80 @@ namespace xf
#endif
}

/*********************************
* xaxis_function implementation *
*********************************/
/**********************
* axis_function_mask *
**********************/

template <class AF, class DM>
inline xaxis_function_wrapper<AF, DM>::xaxis_function_wrapper(AF&& axis_function, DM&& dim_mapping) noexcept
: m_axis_function(std::forward<AF>(axis_function)),
m_dimension_mapping(std::forward<DM>(dim_mapping))
{
}

template <class AF, class DM>
template <std::size_t N>
inline auto xaxis_function_wrapper<AF, DM>::operator()(const selector_sequence_type<N>& selector) const -> const_reference
namespace detail
{
template <class AF, class DM>
class axis_function_mask_impl
{
public:

using axis_function_type = std::remove_reference_t<AF>;

using value_type = typename axis_function_type::value_type;
using name_type = typename axis_function_type::name_type;
using size_type = typename axis_function_type::size_type;

template <std::size_t N = dynamic()>
using selector_sequence_type = detail::xselector_sequence_t<std::pair<name_type, size_type>, N>;

axis_function_mask_impl(AF&& axis_function, DM&& dim_mapping)
: m_axis_function(std::forward<AF>(axis_function)),
m_dimension_mapping(std::forward<DM>(dim_mapping))
{
}

template <class... Args>
inline value_type operator()(Args... args) const
{
auto selector = make_selector(std::make_index_sequence<sizeof...(Args)>(), args...);
#ifdef _MSC_VER
return m_axis_function.operator()<N>(selector);
return m_axis_function.operator()<sizeof...(Args)>(selector);
#else
return m_axis_function.template operator()<N>(selector);
return m_axis_function.template operator()<sizeof...(Args)>(selector);
#endif
}

template <class AF, class DM>
template <class... Args>
inline auto xaxis_function_wrapper<AF, DM>::operator()(Args... args) const -> const_reference
{
auto selector = make_selector(std::make_index_sequence<sizeof...(Args)>(), args...);
}

template <class It>
inline value_type element(It first, It last) const
{
// TODO avoid dynamic allocation
auto selector = selector_sequence_type<dynamic()>();
std::size_t i = 0;
for (It it = first; it != last; ++it)
{
selector.push_back(std::make_pair(m_dimension_mapping.label(i++), static_cast<size_type>(*it)));
}
#ifdef _MSC_VER
return m_axis_function.operator()<sizeof...(Args)>(selector);
return m_axis_function.operator()<dynamic()>(selector);
#else
return m_axis_function.template operator()<sizeof...(Args)>(selector);
return m_axis_function.template operator()<dynamic()>(selector);
#endif
}
}

template <class AF, class DM>
template <class... Args, std::size_t... I>
inline auto xaxis_function_wrapper<AF, DM>::make_selector(std::index_sequence<I...>, Args&&... args) const -> selector_sequence_type<sizeof...(Args)>
{
return {std::make_pair(m_dimension_mapping.label(I), static_cast<size_type>(args))...};
private:

AF m_axis_function;
DM m_dimension_mapping;

template <class... Args, std::size_t... I>
inline selector_sequence_type<sizeof...(Args)> make_selector(std::index_sequence<I...>, Args&&... args) const
{
return {std::make_pair(m_dimension_mapping.label(I), static_cast<size_type>(args))...};
}
};
}

template <class AF, class DM>
inline xaxis_function_wrapper<AF, DM> axis_function_wrapper(AF&& axis_function, DM&& dim_mapping)
template <class AF, class DM, class S>
inline auto axis_function_mask(AF&& axis_function, DM&& dim_mapping, const S& shape) noexcept
{
return xaxis_function_wrapper<AF, DM>(std::forward<AF>(axis_function), std::forward<DM>(dim_mapping));
return xt::detail::make_xgenerator(
detail::axis_function_mask_impl<AF, DM>(std::forward<AF>(axis_function), std::forward<DM>(dim_mapping)),
shape
);
}
}

Expand Down
60 changes: 50 additions & 10 deletions test/test_xaxis_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "gtest/gtest.h"

#include "xtensor/xarray.hpp"
#include "xtensor/xfunction.hpp"

#include "xframe/xnamed_axis.hpp"

#include "test_fixture.hpp"
Expand Down Expand Up @@ -75,19 +78,56 @@ namespace xf
EXPECT_EQ(func2({{"abs", 10}, {"ord", 5}}), 29);
}

TEST(xaxis_function, wrapper)
TEST(xaxis_function, mask)
{
auto axis1 = named_axis(fstring("abs"), axis(15));
auto axis2 = named_axis(fstring("ord"), axis(10, 20, 1));
auto axis3 = named_axis(fstring("alt"), axis('b', 'j'));
auto axis1 = named_axis(fstring("abs"), axis({0, 2, 5}));
auto axis2 = named_axis(fstring("ord"), axis({'a', 'c', 'i'}));

auto array = xt::xarray<bool>({
{true, true, true},
{true, true, true},
{true, true, true}
});

auto mask = axis_function_mask(
equal(axis2, 'c') || equal(axis1, 0),
dimension_type({"abs", "ord"}),
array.shape()
);

auto wrapper = axis_function_wrapper(
equal(axis3, 'b') || axis3 >= 'g' && not_equal(axis3, 'i'),
dimension_type({"abs", "alt", "ord"})
auto expected = xt::xarray<bool>({
{ true, true, true},
{false, true, false},
{false, true, false}
});

EXPECT_EQ(mask, expected);
}

TEST(xaxis_function, mask_op)
{
auto axis1 = named_axis(fstring("abs"), axis({0, 2, 5}));
auto axis2 = named_axis(fstring("ord"), axis({'a', 'c', 'i'}));

auto array = xt::xarray<bool>({
{ true, true, false},
{ true, true, true},
{ true, true, true}
});

auto mask = axis_function_mask(
equal(axis2, 'i') || equal(axis1, 0),
dimension_type({"abs", "ord"}),
array.shape()
);

EXPECT_EQ(wrapper({{"ord", 5}, {"abs", 6}, {"alt", 0}}), wrapper(6, 0, 5));
EXPECT_EQ(wrapper({{"alt", 0}, {"abs", 6}, {"ord", 5}}), wrapper(6, 0, 5));
EXPECT_EQ(wrapper({{"alt", 1}, {"abs", 6}, {"ord", 0}}), wrapper(6, 1, 0));
auto expected = xt::xarray<bool>({
{ true, true, false},
{false, false, true},
{false, false, true}
});

xt::xarray<bool> val = array && mask;
EXPECT_EQ(val, expected);
}
}

0 comments on commit 28dded7

Please sign in to comment.