Skip to content

Commit

Permalink
Fix some final review comments.
Browse files Browse the repository at this point in the history
+ Fix GKO_COMMA and generate method parameter issues thanks to Thomas.
  • Loading branch information
pratikvn committed Aug 15, 2019
1 parent 58e59f9 commit 3f08cf0
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 34 deletions.
8 changes: 3 additions & 5 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,10 @@ GKO_REGISTER_OPERATION(solve, lower_trs::solve);


template <typename ValueType, typename IndexType>
void LowerTrs<ValueType, IndexType>::generate(
const matrix::Csr<ValueType, IndexType> *system_matrix,
const matrix::Dense<ValueType> *b)
void LowerTrs<ValueType, IndexType>::generate()
{
this->get_executor()->run(
lower_trs::make_generate(gko::lend(system_matrix), gko::lend(b)));
lower_trs::make_generate(gko::lend(system_matrix_), gko::lend(b_)));
}


Expand Down Expand Up @@ -95,7 +93,7 @@ void LowerTrs<ValueType, IndexType>::apply_impl(const LinOp *alpha,
auto x_clone = dense_x->clone();
this->apply(b, x_clone.get());
dense_x->scale(beta);
dense_x->add_scaled(alpha, x_clone.get());
dense_x->add_scaled(alpha, gko::lend(x_clone));
}


Expand Down
5 changes: 4 additions & 1 deletion include/ginkgo/core/base/polymorphic_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_CORE_BASE_POLYMORPHIC_OBJECT_HPP_


#include <memory>


#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/utils.hpp>
#include <ginkgo/core/log/logger.hpp>
Expand Down Expand Up @@ -484,7 +487,7 @@ std::unique_ptr<const R, std::function<void(const R *)>> copy_and_convert_to(

/**
* Converts the object to R and places it on Executor exec. This is the version
* that takes in the shared_ptr and returns a shared_ptr
* that takes in the std::shared_ptr and returns a std::shared_ptr
*
* If the object is already of the requested type and on the requested executor,
* the copy and conversion is avoided and a reference to the original object is
Expand Down
20 changes: 5 additions & 15 deletions include/ginkgo/core/solver/lower_trs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER(
preconditioner, nullptr);
};
#define GKO_COMMA ,
GKO_ENABLE_LOWER_TRS_FACTORY(LowerTrs<ValueType GKO_COMMA IndexType>,
parameters, Factory);
#undef GKO_COMMA
GKO_ENABLE_LOWER_TRS_FACTORY(LowerTrs, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

protected:
Expand All @@ -236,17 +233,10 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
LinOp *x) const override;

/**
* Generates the solver.
*
* @param system_matrix the source matrix used to generate the
* solver.
* @param b the right hand side used to generate the solver.
*
* @note the system_matrix to be passed in has to be convertible to CSR.
* Otherwise an exception is thrown.
* Generates the analysis structure from the system matrix and the right
* hand side needed for the level solver.
*/
void generate(const matrix::Csr<ValueType, IndexType> *system_matrix,
const matrix::Dense<ValueType> *b);
void generate();

explicit LowerTrs(std::shared_ptr<const Executor> exec)
: EnableLinOp<LowerTrs>(std::move(exec))
Expand Down Expand Up @@ -279,7 +269,7 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
preconditioner_ = matrix::Identity<ValueType>::create(
this->get_executor(), this->get_size()[0]);
}
this->generate(gko::lend(system_matrix_), gko::lend(b_));
this->generate();
}

private:
Expand Down
26 changes: 13 additions & 13 deletions reference/test/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class LowerTrs : public ::testing::Test {
mtx(gko::initialize<Mtx>(
{{2, 0.0, 0.0}, {3.0, 1, 0.0}, {1.0, 2.0, 3}}, exec)),
b(gko::initialize<Mtx>({{2, 0.0, 0.0}}, exec)),
csr_mtx(gko::copy_and_convert_to<CsrMtx>(exec, mtx.get())),
csr_mtx(gko::copy_and_convert_to<CsrMtx>(exec, gko::lend(mtx))),
lower_trs_factory(Solver::build().on(exec)),
lower_trs_solver(lower_trs_factory->generate(mtx, b))
{}
Expand Down Expand Up @@ -94,14 +94,14 @@ TEST_F(LowerTrs, CanBeCopied)
{
auto copy = Solver::build().on(exec)->generate(Mtx::create(exec),
Mtx::create(exec));
copy->copy_from(gko::lend(lower_trs_solver));

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
auto copy_mtx = copy.get()->get_system_matrix();
copy->copy_from(gko::lend(lower_trs_solver));
auto copy_mtx = copy->get_system_matrix();
auto d_copy_mtx = Mtx::create(exec);
copy_mtx->convert_to(gko::lend(d_copy_mtx));
auto copy_b = copy.get()->get_rhs();
auto copy_b = copy->get_rhs();

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
GKO_ASSERT_MTX_NEAR(d_copy_mtx, mtx, 0);
GKO_ASSERT_MTX_NEAR(copy_b, b, 0);
}
Expand All @@ -111,14 +111,14 @@ TEST_F(LowerTrs, CanBeMoved)
{
auto copy =
lower_trs_factory->generate(Mtx::create(exec), Mtx::create(exec));
copy->copy_from(std::move(lower_trs_solver));

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
auto copy_mtx = copy.get()->get_system_matrix();
copy->copy_from(std::move(lower_trs_solver));
auto copy_mtx = copy->get_system_matrix();
auto d_copy_mtx = Mtx::create(exec);
copy_mtx->convert_to(gko::lend(d_copy_mtx));
auto copy_b = copy.get()->get_rhs();
auto copy_b = copy->get_rhs();

ASSERT_EQ(copy->get_size(), gko::dim<2>(3, 3));
GKO_ASSERT_MTX_NEAR(d_copy_mtx, mtx, 0);
GKO_ASSERT_MTX_NEAR(copy_b, b, 0);
}
Expand All @@ -128,10 +128,10 @@ TEST_F(LowerTrs, CanBeCloned)
{
auto clone = lower_trs_solver->clone();

auto clone_mtx = clone.get()->get_system_matrix();
auto clone_mtx = clone->get_system_matrix();
auto d_clone_mtx = Mtx::create(exec);
clone_mtx->convert_to(gko::lend(d_clone_mtx));
auto clone_b = clone.get()->get_rhs();
auto clone_b = clone->get_rhs();

ASSERT_EQ(clone->get_size(), gko::dim<2>(3, 3));
GKO_ASSERT_MTX_NEAR(d_clone_mtx, mtx, 0);
Expand All @@ -143,8 +143,8 @@ TEST_F(LowerTrs, CanBeCleared)
{
lower_trs_solver->clear();

auto solver_mtx = lower_trs_solver.get()->get_system_matrix();
auto solver_b = lower_trs_solver.get()->get_rhs();
auto solver_mtx = lower_trs_solver->get_system_matrix();
auto solver_b = lower_trs_solver->get_rhs();

ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0));
ASSERT_EQ(solver_mtx, nullptr);
Expand Down

0 comments on commit 3f08cf0

Please sign in to comment.