Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add setters and getters for some solver parameters #527

Merged
merged 7 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 111 additions & 58 deletions core/test/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/stop/residual_norm_reduction.hpp>


#include "core/test/utils.hpp"


namespace {


template <typename T>
class Bicg : public ::testing::Test {
protected:
using Mtx = gko::matrix::Dense<>;
using Solver = gko::solver::Bicg<>;
using value_type = T;
using Mtx = gko::matrix::Dense<value_type>;
using Solver = gko::solver::Bicg<value_type>;

Bicg()
: exec(gko::ReferenceExecutor::create()),
Expand All @@ -63,15 +68,15 @@ class Bicg : public ::testing::Test {
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec),
gko::stop::ResidualNormReduction<>::build()
.with_reduction_factor(1e-6)
.with_reduction_factor(gko::remove_complex<T>{1e-6})
.on(exec))
.on(exec)),
solver(bicg_factory->generate(mtx))
{}

std::shared_ptr<const gko::Executor> exec;
std::shared_ptr<Mtx> mtx;
std::unique_ptr<Solver::Factory> bicg_factory;
std::unique_ptr<typename Solver::Factory> bicg_factory;
std::unique_ptr<gko::LinOp> solver;

static void assert_same_matrices(const Mtx *m1, const Mtx *m2)
Expand All @@ -86,154 +91,202 @@ class Bicg : public ::testing::Test {
}
};

TYPED_TEST_CASE(Bicg, gko::test::ValueTypes);


TEST_F(Bicg, BicgFactoryKnowsItsExecutor)
TYPED_TEST(Bicg, BicgFactoryKnowsItsExecutor)
{
ASSERT_EQ(bicg_factory->get_executor(), exec);
ASSERT_EQ(this->bicg_factory->get_executor(), this->exec);
}


TEST_F(Bicg, BicgFactoryCreatesCorrectSolver)
TYPED_TEST(Bicg, BicgFactoryCreatesCorrectSolver)
{
ASSERT_EQ(solver->get_size(), gko::dim<2>(3, 3));
auto bicg_solver = static_cast<Solver *>(solver.get());
using Solver = typename TestFixture::Solver;

ASSERT_EQ(this->solver->get_size(), gko::dim<2>(3, 3));
auto bicg_solver = static_cast<Solver *>(this->solver.get());
ASSERT_NE(bicg_solver->get_system_matrix(), nullptr);
ASSERT_EQ(bicg_solver->get_system_matrix(), mtx);
ASSERT_EQ(bicg_solver->get_system_matrix(), this->mtx);
}


TEST_F(Bicg, CanBeCopied)
TYPED_TEST(Bicg, CanBeCopied)
{
auto copy = bicg_factory->generate(Mtx::create(exec));
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
auto copy = this->bicg_factory->generate(Mtx::create(this->exec));

copy->copy_from(solver.get());
copy->copy_from(this->solver.get());

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
auto copy_mtx = static_cast<Solver *>(copy.get())->get_system_matrix();
assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()), mtx.get());
this->assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()),
this->mtx.get());
}


TEST_F(Bicg, CanBeMoved)
TYPED_TEST(Bicg, CanBeMoved)
{
auto copy = bicg_factory->generate(Mtx::create(exec));
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
auto copy = this->bicg_factory->generate(Mtx::create(this->exec));

copy->copy_from(std::move(solver));
copy->copy_from(std::move(this->solver));

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
auto copy_mtx = static_cast<Solver *>(copy.get())->get_system_matrix();
assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()), mtx.get());
this->assert_same_matrices(static_cast<const Mtx *>(copy_mtx.get()),
this->mtx.get());
}


TEST_F(Bicg, CanBeCloned)
TYPED_TEST(Bicg, CanBeCloned)
{
auto clone = solver->clone();
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
auto clone = this->solver->clone();

ASSERT_EQ(clone->get_size(), gko::dim<2>(3, 3));
auto clone_mtx = static_cast<Solver *>(clone.get())->get_system_matrix();
assert_same_matrices(static_cast<const Mtx *>(clone_mtx.get()), mtx.get());
this->assert_same_matrices(static_cast<const Mtx *>(clone_mtx.get()),
this->mtx.get());
}


TEST_F(Bicg, CanBeCleared)
TYPED_TEST(Bicg, CanBeCleared)
{
solver->clear();
using Solver = typename TestFixture::Solver;
this->solver->clear();

ASSERT_EQ(solver->get_size(), gko::dim<2>(0, 0));
auto solver_mtx = static_cast<Solver *>(solver.get())->get_system_matrix();
ASSERT_EQ(this->solver->get_size(), gko::dim<2>(0, 0));
auto solver_mtx =
static_cast<Solver *>(this->solver.get())->get_system_matrix();
ASSERT_EQ(solver_mtx, nullptr);
}


TEST_F(Bicg, ApplyUsesInitialGuessReturnsTrue)
TYPED_TEST(Bicg, ApplyUsesInitialGuessReturnsTrue)
{
ASSERT_TRUE(solver->apply_uses_initial_guess());
using Solver = typename TestFixture::Solver;
ASSERT_TRUE(this->solver->apply_uses_initial_guess());
}


TEST_F(Bicg, CanSetPreconditionerGenerator)
TYPED_TEST(Bicg, CanSetPreconditionerGenerator)
{
using Solver = typename TestFixture::Solver;
using value_type = typename TestFixture::value_type;
auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec),
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec),
gko::stop::ResidualNormReduction<>::build()
.with_reduction_factor(1e-6)
.on(exec))
.with_preconditioner(Solver::build().on(exec))
.on(exec);
auto solver = bicg_factory->generate(mtx);
auto precond = static_cast<const gko::solver::Bicg<> *>(
static_cast<gko::solver::Bicg<> *>(solver.get())
.with_reduction_factor(
gko::remove_complex<value_type>(1e-6))
.on(this->exec))
.with_preconditioner(Solver::build().on(this->exec))
.on(this->exec);
auto solver = bicg_factory->generate(this->mtx);
auto precond = dynamic_cast<const gko::solver::Bicg<value_type> *>(
static_cast<gko::solver::Bicg<value_type> *>(solver.get())
->get_preconditioner()
.get());

ASSERT_NE(precond, nullptr);
ASSERT_EQ(precond->get_size(), gko::dim<2>(3, 3));
ASSERT_EQ(precond->get_system_matrix(), mtx);
ASSERT_EQ(precond->get_system_matrix(), this->mtx);
}


TEST_F(Bicg, CanSetPreconditionerInFactory)
TYPED_TEST(Bicg, CanSetPreconditionerInFactory)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<Solver> bicg_precond =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec)
->generate(mtx);
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec)
->generate(this->mtx);

auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_generated_preconditioner(bicg_precond)
.on(exec);
auto solver = bicg_factory->generate(mtx);
.on(this->exec);
auto solver = bicg_factory->generate(this->mtx);
auto precond = solver->get_preconditioner();

ASSERT_NE(precond.get(), nullptr);
ASSERT_EQ(precond.get(), bicg_precond.get());
}


TEST_F(Bicg, ThrowsOnWrongPreconditionerInFactory)
TYPED_TEST(Bicg, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto bicg_factory = Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((bicg_factory->get_parameters().criteria).back(), init_crit);

auto solver = bicg_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Bicg, ThrowsOnWrongPreconditionerInFactory)
{
std::shared_ptr<Mtx> wrong_sized_mtx = Mtx::create(exec, gko::dim<2>{1, 3});
using Mtx = typename TestFixture::Mtx;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> wrong_sized_mtx =
Mtx::create(this->exec, gko::dim<2>{1, 3});
std::shared_ptr<Solver> bicg_precond =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec)
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec)
->generate(wrong_sized_mtx);

auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_generated_preconditioner(bicg_precond)
.on(exec);
.on(this->exec);

ASSERT_THROW(bicg_factory->generate(mtx), gko::DimensionMismatch);
ASSERT_THROW(bicg_factory->generate(this->mtx), gko::DimensionMismatch);
}


TEST_F(Bicg, CanSetPreconditioner)
TYPED_TEST(Bicg, CanSetPreconditioner)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<Solver> bicg_precond =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec)
->generate(mtx);
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec)
->generate(this->mtx);

auto bicg_factory =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(exec))
.on(exec);
auto solver = bicg_factory->generate(mtx);
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec);
auto solver = bicg_factory->generate(this->mtx);
solver->set_preconditioner(bicg_precond);
auto precond = solver->get_preconditioner();

Expand Down
25 changes: 25 additions & 0 deletions core/test/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,31 @@ TYPED_TEST(Bicgstab, CanSetPreconditionerGenerator)
}


TYPED_TEST(Bicgstab, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto bicgstab_factory =
Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((bicgstab_factory->get_parameters().criteria).back(), init_crit);

auto solver = bicgstab_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Bicgstab, CanSetPreconditionerInFactory)
{
using Solver = typename TestFixture::Solver;
Expand Down
24 changes: 24 additions & 0 deletions core/test/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,30 @@ TYPED_TEST(Cg, CanSetPreconditionerInFactory)
}


TYPED_TEST(Cg, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto cg_factory = Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((cg_factory->get_parameters().criteria).back(), init_crit);

auto solver = cg_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Cg, ThrowsOnWrongPreconditionerInFactory)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
24 changes: 24 additions & 0 deletions core/test/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,30 @@ TYPED_TEST(Cgs, CanSetPreconditionerGenerator)
}


TYPED_TEST(Cgs, CanSetCriteriaAgain)
{
using Solver = typename TestFixture::Solver;
std::shared_ptr<gko::stop::CriterionFactory> init_crit =
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec);
auto cgs_factory = Solver::build().with_criteria(init_crit).on(this->exec);

ASSERT_EQ((cgs_factory->get_parameters().criteria).back(), init_crit);

auto solver = cgs_factory->generate(this->mtx);
std::shared_ptr<gko::stop::CriterionFactory> new_crit =
gko::stop::Iteration::build().with_max_iters(5u).on(this->exec);

solver->set_stop_criterion_factory(new_crit);
auto new_crit_fac = solver->get_stop_criterion_factory();
auto niter =
static_cast<const gko::stop::Iteration::Factory *>(new_crit_fac.get())
->get_parameters()
.max_iters;

ASSERT_EQ(niter, 5);
}


TYPED_TEST(Cgs, CanSetPreconditionerInFactory)
{
using Solver = typename TestFixture::Solver;
Expand Down
Loading