diff --git a/core/test/matrix/permutation.cpp b/core/test/matrix/permutation.cpp index 6545dcc5d6b..1681b2f9e51 100644 --- a/core/test/matrix/permutation.cpp +++ b/core/test/matrix/permutation.cpp @@ -100,6 +100,15 @@ TEST_F(Permutation, ReturnsNullValuesArrayWhenEmpty) } +TEST_F(Permutation, FactorySetsCorrectPermuteMask) +{ + auto m = gko::matrix::Permutation::create(exec); + auto mask = m->get_permute_mask(); + + ASSERT_EQ(mask, gko::matrix::row_permute); +} + + TEST_F(Permutation, CanBeConstructedWithSize) { auto m = gko::matrix::Permutation::create(exec, gko::dim<2>{2, 3}); @@ -109,6 +118,33 @@ TEST_F(Permutation, CanBeConstructedWithSize) } +TEST_F(Permutation, CanBeConstructedWithSizeAndMask) +{ + auto m = gko::matrix::Permutation::create( + exec, gko::dim<2>{2, 3}, gko::matrix::column_permute); + + ASSERT_EQ(m->get_size(), gko::dim<2>(2, 3)); + ASSERT_EQ(m->get_permutation_size(), 2); + ASSERT_EQ(m->get_permute_mask(), gko::matrix::column_permute); +} + + +TEST_F(Permutation, CanExplicitlyOverrideSetPermuteMask) +{ + auto m = gko::matrix::Permutation::create( + exec, gko::dim<2>{2, 3}, gko::matrix::column_permute); + + auto mask = m->get_permute_mask(); + ASSERT_EQ(mask, gko::matrix::column_permute); + + m->set_permute_mask(gko::matrix::row_permute | + gko::matrix::inverse_permute); + + auto s_mask = m->get_permute_mask(); + ASSERT_EQ(s_mask, gko::matrix::row_permute | gko::matrix::inverse_permute); +} + + TEST_F(Permutation, PermutationCanBeConstructedFromExistingData) { i_type data[] = {1, 0, 2}; @@ -120,6 +156,26 @@ TEST_F(Permutation, PermutationCanBeConstructedFromExistingData) } +TEST_F(Permutation, SettingMaskDoesNotModifyData) +{ + i_type data[] = {1, 0, 2}; + + auto m = gko::matrix::Permutation::create( + exec, gko::dim<2>{3, 5}, gko::Array::view(exec, 3, data)); + + auto mask = m->get_permute_mask(); + ASSERT_EQ(m->get_const_permutation(), data); + ASSERT_EQ(mask, gko::matrix::row_permute); + + m->set_permute_mask(gko::matrix::row_permute | + gko::matrix::inverse_permute); + + auto s_mask = m->get_permute_mask(); + ASSERT_EQ(s_mask, gko::matrix::row_permute | gko::matrix::inverse_permute); + ASSERT_EQ(m->get_const_permutation(), data); +} + + TEST_F(Permutation, PermutationThrowsforWrongRowPermDimensions) { i_type data[] = {0, 2, 1}; @@ -163,6 +219,31 @@ TEST_F(Permutation, CanBeCopied) } +TEST_F(Permutation, CopyingPreservesMask) +{ + auto mtx_copy = gko::matrix::Permutation::create(exec); + + mtx_copy->copy_from(mtx.get()); + + auto o_mask = mtx->get_permute_mask(); + auto n_mask = mtx_copy->get_permute_mask(); + ASSERT_EQ(o_mask, gko::matrix::row_permute); + ASSERT_EQ(o_mask, n_mask); + + mtx->set_permute_mask(gko::matrix::column_permute); + + o_mask = mtx->get_permute_mask(); + n_mask = mtx_copy->get_permute_mask(); + ASSERT_EQ(o_mask, gko::matrix::column_permute); + ASSERT_NE(o_mask, n_mask); + + mtx_copy->copy_from(mtx.get()); + + n_mask = mtx_copy->get_permute_mask(); + ASSERT_EQ(o_mask, n_mask); +} + + TEST_F(Permutation, CanBeMoved) { auto mtx_copy = gko::matrix::Permutation::create(exec); diff --git a/include/ginkgo/core/matrix/permutation.hpp b/include/ginkgo/core/matrix/permutation.hpp index f8c01ab24e6..a273d250017 100644 --- a/include/ginkgo/core/matrix/permutation.hpp +++ b/include/ginkgo/core/matrix/permutation.hpp @@ -114,6 +114,23 @@ class Permutation : public EnableLinOp>, return permutation_.get_num_elems(); } + /** + * Get the permute masks + * + * @param permute_mask the permute masks + */ + mask_type get_permute_mask() { return enabled_permute_; } + + /** + * Set the permute masks + * + * @param permute_mask the permute masks + */ + void set_permute_mask(mask_type permute_mask) + { + enabled_permute_ = permute_mask; + } + protected: /** @@ -122,7 +139,7 @@ class Permutation : public EnableLinOp>, * @param exec Executor associated to the LinOp */ Permutation(std::shared_ptr exec) - : Permutation(std::move(exec), dim<2>{}) + : Permutation(std::move(exec), dim<2>{}, row_permute) {} /**