Skip to content

Commit

Permalink
Allow setting and getting of permutation masks.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Feb 28, 2020
1 parent a0a0a9e commit f7b6e80
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
81 changes: 81 additions & 0 deletions core/test/matrix/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ TEST_F(Permutation, ReturnsNullValuesArrayWhenEmpty)
}


TEST_F(Permutation, FactorySetsCorrectPermuteMask)
{
auto m = gko::matrix::Permutation<i_type>::create(exec);
auto mask = m->get_permute_mask();

ASSERT_EQ(mask, gko::matrix::row_permute);
}


TEST_F(Permutation, CanBeConstructedWithSize)
{
auto m = gko::matrix::Permutation<i_type>::create(exec, gko::dim<2>{2, 3});
Expand All @@ -109,6 +118,33 @@ TEST_F(Permutation, CanBeConstructedWithSize)
}


TEST_F(Permutation, CanBeConstructedWithSizeAndMask)
{
auto m = gko::matrix::Permutation<i_type>::create(
exec, gko::dim<2>{2, 3}, gko::matrix::column_permute);

ASSERT_EQ(m->get_size(), gko::dim<2>(2, 3));
ASSERT_EQ(m->get_permutation_size(), 2);
ASSERT_EQ(m->get_permute_mask(), gko::matrix::column_permute);
}


TEST_F(Permutation, CanExplicitlyOverrideSetPermuteMask)
{
auto m = gko::matrix::Permutation<i_type>::create(
exec, gko::dim<2>{2, 3}, gko::matrix::column_permute);

auto mask = m->get_permute_mask();
ASSERT_EQ(mask, gko::matrix::column_permute);

m->set_permute_mask(gko::matrix::row_permute |
gko::matrix::inverse_permute);

auto s_mask = m->get_permute_mask();
ASSERT_EQ(s_mask, gko::matrix::row_permute | gko::matrix::inverse_permute);
}


TEST_F(Permutation, PermutationCanBeConstructedFromExistingData)
{
i_type data[] = {1, 0, 2};
Expand All @@ -120,6 +156,26 @@ TEST_F(Permutation, PermutationCanBeConstructedFromExistingData)
}


TEST_F(Permutation, SettingMaskDoesNotModifyData)
{
i_type data[] = {1, 0, 2};

auto m = gko::matrix::Permutation<i_type>::create(
exec, gko::dim<2>{3, 5}, gko::Array<i_type>::view(exec, 3, data));

auto mask = m->get_permute_mask();
ASSERT_EQ(m->get_const_permutation(), data);
ASSERT_EQ(mask, gko::matrix::row_permute);

m->set_permute_mask(gko::matrix::row_permute |
gko::matrix::inverse_permute);

auto s_mask = m->get_permute_mask();
ASSERT_EQ(s_mask, gko::matrix::row_permute | gko::matrix::inverse_permute);
ASSERT_EQ(m->get_const_permutation(), data);
}


TEST_F(Permutation, PermutationThrowsforWrongRowPermDimensions)
{
i_type data[] = {0, 2, 1};
Expand Down Expand Up @@ -163,6 +219,31 @@ TEST_F(Permutation, CanBeCopied)
}


TEST_F(Permutation, CopyingPreservesMask)
{
auto mtx_copy = gko::matrix::Permutation<i_type>::create(exec);

mtx_copy->copy_from(mtx.get());

auto o_mask = mtx->get_permute_mask();
auto n_mask = mtx_copy->get_permute_mask();
ASSERT_EQ(o_mask, gko::matrix::row_permute);
ASSERT_EQ(o_mask, n_mask);

mtx->set_permute_mask(gko::matrix::column_permute);

o_mask = mtx->get_permute_mask();
n_mask = mtx_copy->get_permute_mask();
ASSERT_EQ(o_mask, gko::matrix::column_permute);
ASSERT_NE(o_mask, n_mask);

mtx_copy->copy_from(mtx.get());

n_mask = mtx_copy->get_permute_mask();
ASSERT_EQ(o_mask, n_mask);
}


TEST_F(Permutation, CanBeMoved)
{
auto mtx_copy = gko::matrix::Permutation<i_type>::create(exec);
Expand Down
19 changes: 18 additions & 1 deletion include/ginkgo/core/matrix/permutation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ class Permutation : public EnableLinOp<Permutation<IndexType>>,
return permutation_.get_num_elems();
}

/**
* Get the permute masks
*
* @param permute_mask the permute masks
*/
mask_type get_permute_mask() { return enabled_permute_; }

/**
* Set the permute masks
*
* @param permute_mask the permute masks
*/
void set_permute_mask(mask_type permute_mask)
{
enabled_permute_ = permute_mask;
}


protected:
/**
Expand All @@ -122,7 +139,7 @@ class Permutation : public EnableLinOp<Permutation<IndexType>>,
* @param exec Executor associated to the LinOp
*/
Permutation(std::shared_ptr<const Executor> exec)
: Permutation(std::move(exec), dim<2>{})
: Permutation(std::move(exec), dim<2>{}, row_permute)
{}

/**
Expand Down

0 comments on commit f7b6e80

Please sign in to comment.