Skip to content

Commit

Permalink
Merge: Add setters and getters for criterion factories for solvers
Browse files Browse the repository at this point in the history
This PR adds setter and getter functions for the stopping criterion factories to the solvers so that the criteria can be reset after the solver has been generated. 

It also adds the missing BiCG typed tests.

Related PR: #527
  • Loading branch information
pratikvn authored May 20, 2020
2 parents fd146da + 48c590d commit 7034730
Show file tree
Hide file tree
Showing 15 changed files with 564 additions and 140 deletions.
169 changes: 111 additions & 58 deletions core/test/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/stop/residual_norm_reduction.hpp>


#include "core/test/utils.hpp"


namespace {


template <typename T>
class Bicg : public ::testing::Test {
protected:
using Mtx = gko::matrix::Dense<>;
using Solver = gko::solver::Bicg<>;
using value_type = T;
using Mtx = gko::matrix::Dense<value_type>;
using Solver = gko::solver::Bicg<value_type>;

Bicg()
: exec(gko::ReferenceExecutor::create()),
Expand All @@ -63,15 +68,15 @@ class Bicg : public ::testing::Test {
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec),
gko::stop::ResidualNormReduction<>::build()
.with_reduction_factor(1e-6)
.with_reduction_factor(gko::remove_complex<T>{1e-6})
.on(exec))
.on(exec)),
solver(bicg_factory->generate(mtx))
{}

std::shared_ptr<const gko::Executor> exec;
std::shared_ptr<Mtx> mtx;
std::unique_ptr<Solver::Factory> bicg_factory;
std::unique_ptr<typename Solver::Factory> bicg_factory;
std::unique_ptr<gko::LinOp> solver;

static void assert_same_matrices(const Mtx *m1, const Mtx *m2)
Expand All @@ -86,154 +91,202 @@ class Bicg : public ::testing::Test {
}
};

TYPED_TEST_CASE(Bicg, gko::test::ValueTypes);


TEST_F(Bicg, BicgFactoryKnowsItsExecutor)
TYPED_TEST(Bicg, BicgFactoryKnowsItsExecutor)
{
ASSERT_EQ(bicg_factory->get_executor(), exec);
ASSERT_EQ(this->bicg_factory->get_executor(), this->exec);
}


TEST_F(Bicg, BicgFactoryCreatesCorrectSolver)
TYPED_TEST(Bicg, BicgFactoryCreatesCorrectSolver)
{
ASSERT_EQ(solver->get_size(), gko::dim<2>(3, 3));
auto bicg_solver = static_cast<Solver *>(solver.get());
using Solver = typename TestFixture::Solver;

ASSERT_EQ(this->solver->get_size(), gko::dim<2>(3, 3));
auto bicg_solver = static_cast<Solver *>(this->solver.get());
ASSERT_NE(bicg_solver->get_system_matrix(), nullptr);
ASSERT_EQ(bicg_solver->get_system_matrix(), mtx);
ASSERT_EQ(bicg_solver->get_system_matrix(), this->mtx);
}


TEST_F(Bicg, CanBeCopied)
TYPED_TEST(Bicg, CanBeCopied)
{
auto copy = bicg_factory->generate(Mtx::create(exec));
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
auto copy = this->bicg_factory->generate(Mtx::create(this->exec));

copy->copy_from(solver.get());
copy->copy_from(this->solver.get());

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
auto copy_mtx = static_cast<Solver *>(copy.get())->get_system_matrix();
assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()), mtx.get());
this->assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()),
this->mtx.get());
}


TEST_F(Bicg, CanBeMoved)
TYPED_TEST(Bicg, CanBeMoved)
{
auto copy = bicg_factory->generate(Mtx::create(exec));
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
auto copy = this->bicg_factory->generate(Mtx::create(this->exec));

copy->copy_from(std::move(solver));
copy->copy_from(std::move(this->solver));

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
auto copy_mtx = static_cast<Solver *>(copy.get())->get_system_matrix();
assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()), mtx.get());
this->assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()),
this->mtx.get());
}


TEST_F(Bicg, CanBeCloned)
TYPED_TEST(Bicg, CanBeCloned)
{
auto clone = solver->clone();
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
auto clone = this->solver->clone();

ASSERT_EQ(clone->get_size(), gko::dim<2>(3, 3));
auto clone_mtx = static_cast<Solver *>(clone.get())->get_system_matrix();
assert_same_matrices(static_cast<const Mtx *>(clone_mtx.get()), mtx.get());
this->assert_same_matrices(static_cast<const Mtx *>(clone_mtx.get()),
this->mtx.get());
}


TEST_F(Bicg, CanBeCleared)
TYPED_TEST(Bicg, CanBeCleared)
{
solver->clear();
using Solver = typename TestFixture::Solver;
this->solver->clear();

ASSERT_EQ(solver->get_size(), gko::dim<2>(0, 0));
auto solver_mtx = static_cast<Solver *>(solver.get())->get_system_matrix();
ASSERT_EQ(this->solver->get_size(), gko::dim<2>(0, 0));
auto solver_mtx =
static_cast<Solver *>(this->solver.get())->get_system_matrix();
ASSERT_EQ(solver_mtx, nullptr);
}


TEST_F(Bicg, ApplyUsesInitialGuessReturnsTrue)
TYPED_TEST(Bicg, ApplyUsesInitialGuessReturnsTrue)
{
ASSERT_TRUE(solver->apply_uses_initial_guess());
using Solver = typename TestFixture::Solver;
ASSERT_TRUE(this->solver->apply_uses_initial_guess());
}


TEST_F(Bicg, CanSetPreconditionerGenerator)
TYPED_TEST(Bicg, CanSetPreconditionerGenerator)
{
using Solver = typename TestFixture::Solver;
using value_type = typename TestFixture::value_type;
auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec),
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec),
gko::stop::ResidualNormReduction<>::build()
.with_reduction_factor(1e-6)
.on(exec))
.with_preconditioner(Solver::build().on(exec))
.on(exec);
auto solver = bicg_factory->generate(mtx);
auto precond = static_cast<const gko::solver::Bicg<> *>(
static_cast<gko::solver::Bicg<> *>(solver.get())
.with_reduction_factor(
gko::remove_complex<value_type>(1e-6))
.on(this->exec))
.with_preconditioner(Solver::build().on(this->exec))
.on(this->exec);
auto solver = bicg_factory->generate(this->mtx);
auto precond = dynamic_cast<const gko::solver::Bicg<value_type> *>(
static_cast<gko::solver::Bicg<value_type> *>(solver.get())
->get_preconditioner()
.get());

ASSERT_NE(precond, nullptr);
ASSERT_EQ(precond->get_size(), gko::dim<2>(3, 3));
ASSERT_EQ(precond->get_system_matrix(), mtx);
ASSERT_EQ(precond->get_system_matrix(), this->mtx);
}


TEST_F(Bicg, CanSetPreconditionerInFactory)
TYPED_TEST(Bicg, CanSetPreconditionerInFactory)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<Solver> bicg_precond =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec)
->generate(mtx);
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec)
->generate(this->mtx);

auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_generated_preconditioner(bicg_precond)
.on(exec);
auto solver = bicg_factory->generate(mtx);
.on(this->exec);
auto solver = bicg_factory->generate(this->mtx);
auto precond = solver->get_preconditioner();

ASSERT_NE(precond.get(), nullptr);
ASSERT_EQ(precond.get(), bicg_precond.get());
}


TEST_F(Bicg, ThrowsOnWrongPreconditionerInFactory)
TYPED_TEST(Bicg, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto bicg_factory = Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((bicg_factory->get_parameters().criteria).back(), init_crit);

auto solver = bicg_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Bicg, ThrowsOnWrongPreconditionerInFactory)
{
std::shared_ptr<Mtx> wrong_sized_mtx = Mtx::create(exec, gko::dim<2>{1, 3});
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> wrong_sized_mtx =
Mtx::create(this->exec, gko::dim<2>{1, 3});
std::shared_ptr<Solver> bicg_precond =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec)
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec)
->generate(wrong_sized_mtx);

auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_generated_preconditioner(bicg_precond)
.on(exec);
.on(this->exec);

ASSERT_THROW(bicg_factory->generate(mtx), gko::DimensionMismatch);
ASSERT_THROW(bicg_factory->generate(this->mtx), gko::DimensionMismatch);
}


TEST_F(Bicg, CanSetPreconditioner)
TYPED_TEST(Bicg, CanSetPreconditioner)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<Solver> bicg_precond =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec)
->generate(mtx);
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec)
->generate(this->mtx);

auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec);
auto solver = bicg_factory->generate(mtx);
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec);
auto solver = bicg_factory->generate(this->mtx);
solver->set_preconditioner(bicg_precond);
auto precond = solver->get_preconditioner();

Expand Down
25 changes: 25 additions & 0 deletions core/test/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,31 @@ TYPED_TEST(Bicgstab, CanSetPreconditionerGenerator)
}


TYPED_TEST(Bicgstab, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto bicgstab_factory =
Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((bicgstab_factory->get_parameters().criteria).back(), init_crit);

auto solver = bicgstab_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Bicgstab, CanSetPreconditionerInFactory)
{
using Solver = typename TestFixture::Solver;
Expand Down
24 changes: 24 additions & 0 deletions core/test/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,30 @@ TYPED_TEST(Cg, CanSetPreconditionerInFactory)
}


TYPED_TEST(Cg, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto cg_factory = Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((cg_factory->get_parameters().criteria).back(), init_crit);

auto solver = cg_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Cg, ThrowsOnWrongPreconditionerInFactory)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
24 changes: 24 additions & 0 deletions core/test/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,30 @@ TYPED_TEST(Cgs, CanSetPreconditionerGenerator)
}


TYPED_TEST(Cgs, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto cgs_factory = Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((cgs_factory->get_parameters().criteria).back(), init_crit);

auto solver = cgs_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Cgs, CanSetPreconditionerInFactory)
{
using Solver = typename TestFixture::Solver;
Expand Down
Loading

0 comments on commit 7034730

Please sign in to comment.