Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove templated matrix constructors #1433

Merged
merged 33 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/utils/overhead_linop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class Overhead : public EnableLinOp<Overhead<ValueType>>,
parameters_.preconditioner->generate(system_matrix_));
} else {
set_preconditioner(matrix::Identity<ValueType>::create(
this->get_executor(), this->get_size()));
this->get_executor(), this->get_size()[0]));
}
stop_criterion_factory_ =
stop::combine(std::move(parameters_.criteria));
Expand Down
32 changes: 32 additions & 0 deletions core/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ MultiVector<ValueType>::MultiVector(std::shared_ptr<const Executor> exec,
{}


template <typename ValueType>
MultiVector<ValueType>::MultiVector(std::shared_ptr<const Executor> exec,
const batch_dim<2>& size,
array<value_type> values)
: EnablePolymorphicObject<MultiVector<ValueType>>(exec),
batch_size_(size),
values_{exec, std::move(values)}
{
// Ensure that the values array has the correct size
auto num_elems = compute_num_elems(size);
GKO_ENSURE_IN_BOUNDS(num_elems, values_.get_size() + 1);
}


template <typename ValueType>
std::unique_ptr<MultiVector<ValueType>>
MultiVector<ValueType>::create_with_config_of(
Expand All @@ -112,6 +126,24 @@ MultiVector<ValueType>::create_with_config_of(
}


template <typename ValueType>
std::unique_ptr<MultiVector<ValueType>> MultiVector<ValueType>::create(
std::shared_ptr<const Executor> exec, const batch_dim<2>& size)
{
return std::unique_ptr<MultiVector<ValueType>>{new MultiVector{exec, size}};
}


template <typename ValueType>
std::unique_ptr<MultiVector<ValueType>> MultiVector<ValueType>::create(
std::shared_ptr<const Executor> exec, const batch_dim<2>& size,
array<value_type> values)
{
return std::unique_ptr<MultiVector<ValueType>>{
new MultiVector{exec, size, std::move(values)}};
}


template <typename ValueType>
std::unique_ptr<const MultiVector<ValueType>>
MultiVector<ValueType>::create_const(
Expand Down
15 changes: 15 additions & 0 deletions core/base/device_matrix_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ device_matrix_data<ValueType, IndexType>::device_matrix_data(
{}


template <typename ValueType, typename IndexType>
device_matrix_data<ValueType, IndexType>::device_matrix_data(
std::shared_ptr<const Executor> exec, dim<2> size,
array<index_type> row_idxs, array<index_type> col_idxs,
array<value_type> values)
: size_{size},
row_idxs_{exec, std::move(row_idxs)},
col_idxs_{exec, std::move(col_idxs)},
values_{exec, std::move(values)}
{
GKO_ASSERT_EQ(values_.get_size(), row_idxs_.get_size());
GKO_ASSERT_EQ(values_.get_size(), col_idxs_.get_size());
}


template <typename ValueType, typename IndexType>
matrix_data<ValueType, IndexType>
device_matrix_data<ValueType, IndexType>::copy_to_host() const
Expand Down
66 changes: 66 additions & 0 deletions core/base/perturbation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,72 @@ Perturbation<ValueType>::Perturbation(Perturbation&& other)
}


template <typename ValueType>
Perturbation<ValueType>::Perturbation(std::shared_ptr<const Executor> exec)
: EnableLinOp<Perturbation>(std::move(exec))
{}


template <typename ValueType>
Perturbation<ValueType>::Perturbation(std::shared_ptr<const LinOp> scalar,
std::shared_ptr<const LinOp> basis)
: Perturbation(std::move(scalar),
// basis can not be std::move(basis). Otherwise, Program
// deletes basis before applying conjugate transpose
basis,
std::move((as<gko::Transposable>(basis))->conj_transpose()))
{}


template <typename ValueType>
Perturbation<ValueType>::Perturbation(std::shared_ptr<const LinOp> scalar,
std::shared_ptr<const LinOp> basis,
std::shared_ptr<const LinOp> projector)
: EnableLinOp<Perturbation>(basis->get_executor(),
gko::dim<2>{basis->get_size()[0]}),
scalar_{std::move(scalar)},
basis_{std::move(basis)},
projector_{std::move(projector)}
{
this->validate_perturbation();
}


template <typename ValueType>
std::unique_ptr<Perturbation<ValueType>> Perturbation<ValueType>::create(
std::shared_ptr<const Executor> exec)
{
return std::unique_ptr<Perturbation>{new Perturbation{exec}};
}


template <typename ValueType>
std::unique_ptr<Perturbation<ValueType>> Perturbation<ValueType>::create(
std::shared_ptr<const LinOp> scalar, std::shared_ptr<const LinOp> basis)
{
return std::unique_ptr<Perturbation>{new Perturbation{scalar, basis}};
}


template <typename ValueType>
std::unique_ptr<Perturbation<ValueType>> Perturbation<ValueType>::create(
std::shared_ptr<const LinOp> scalar, std::shared_ptr<const LinOp> basis,
std::shared_ptr<const LinOp> projector)
{
return std::unique_ptr<Perturbation>{
new Perturbation{scalar, basis, projector}};
}


template <typename ValueType>
void Perturbation<ValueType>::validate_perturbation()
{
GKO_ASSERT_CONFORMANT(basis_, projector_);
GKO_ASSERT_CONFORMANT(projector_, basis_);
GKO_ASSERT_EQUAL_DIMENSIONS(scalar_, dim<2>(1, 1));
}


template <typename ValueType>
void Perturbation<ValueType>::apply_impl(const LinOp* b, LinOp* x) const
{
Expand Down
43 changes: 34 additions & 9 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,9 @@ GKO_REGISTER_OPERATION(build_local_nonlocal,
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
std::shared_ptr<const Executor> exec, mpi::communicator comm)
: Matrix(exec, comm, with_matrix_type<gko::matrix::Csr>())
{}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
ptr_param<const LinOp> local_matrix_template)
: Matrix(exec, comm, local_matrix_template, local_matrix_template)
: Matrix(exec, comm,
gko::matrix::Csr<ValueType, LocalIndexType>::create(exec),
gko::matrix::Csr<ValueType, LocalIndexType>::create(exec))
{}


Expand Down Expand Up @@ -72,6 +66,37 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
std::unique_ptr<Matrix<ValueType, LocalIndexType, GlobalIndexType>>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm)
{
return std::unique_ptr<Matrix>{new Matrix{exec, comm}};
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
std::unique_ptr<Matrix<ValueType, LocalIndexType, GlobalIndexType>>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
ptr_param<const LinOp> matrix_template)
{
return create(exec, comm, matrix_template, matrix_template);
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
std::unique_ptr<Matrix<ValueType, LocalIndexType, GlobalIndexType>>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
ptr_param<const LinOp> local_matrix_template,
ptr_param<const LinOp> non_local_matrix_template)
{
return std::unique_ptr<Matrix>{new Matrix{exec, comm, local_matrix_template,
non_local_matrix_template}};
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
Matrix<next_precision<value_type>, local_index_type, global_index_type>*
Expand Down
31 changes: 31 additions & 0 deletions core/distributed/partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,37 @@ GKO_REGISTER_OPERATION(has_ordered_parts, partition::has_ordered_parts);
} // namespace partition


template <typename LocalIndexType, typename GlobalIndexType>
Partition<LocalIndexType, GlobalIndexType>::Partition(
std::shared_ptr<const Executor> exec, comm_index_type num_parts,
size_type num_ranges)
: EnablePolymorphicObject<Partition>{exec},
num_parts_{num_parts},
num_empty_parts_{0},
size_{0},
offsets_{exec, num_ranges + 1},
starting_indices_{exec, num_ranges},
part_sizes_{exec, static_cast<size_type>(num_parts)},
part_ids_{exec, num_ranges}
{
offsets_.fill(0);
starting_indices_.fill(0);
part_sizes_.fill(0);
part_ids_.fill(0);
}


template <typename LocalIndexType, typename GlobalIndexType>
std::unique_ptr<Partition<LocalIndexType, GlobalIndexType>>
Partition<LocalIndexType, GlobalIndexType>::create(
std::shared_ptr<const Executor> exec, comm_index_type num_parts,
size_type num_ranges)
{
return std::unique_ptr<Partition>{
new Partition{exec, num_parts, num_ranges}};
}


template <typename LocalIndexType, typename GlobalIndexType>
std::unique_ptr<Partition<LocalIndexType, GlobalIndexType>>
Partition<LocalIndexType, GlobalIndexType>::build_from_mapping(
Expand Down
39 changes: 39 additions & 0 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,45 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
local_vector->move_to(&local_);
}

template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
dim<2> global_size, dim<2> local_size, size_type stride)
{
return std::unique_ptr<Vector>{
new Vector{exec, comm, global_size, local_size, stride}};
}


template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
dim<2> global_size, dim<2> local_size)
{
return std::unique_ptr<Vector>{
new Vector{exec, comm, global_size, local_size}};
}


template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
dim<2> global_size, std::unique_ptr<local_vector_type> local_vector)
{
return std::unique_ptr<Vector>{
new Vector{exec, comm, global_size, std::move(local_vector)}};
}


template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
std::unique_ptr<local_vector_type> local_vector)
{
return std::unique_ptr<Vector>{
new Vector{exec, comm, std::move(local_vector)}};
}


template <typename ValueType>
std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
Expand Down
2 changes: 1 addition & 1 deletion core/factorization/symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void symbolic_lu_near_symm(
// compute A + A^T symbolically
const auto scalar = gko::initialize<scalar_type>({one<float>()}, exec);
const auto symm_mtx = as<float_matrix_type>(float_mtx->transpose());
const auto id = id_type::create(exec, size);
const auto id = id_type::create(exec, size[0]);
float_mtx->apply(scalar, id, scalar, symm_mtx);
// compute Cholesky factorization
std::unique_ptr<elimination_forest<IndexType>> forest;
Expand Down
21 changes: 21 additions & 0 deletions core/matrix/batch_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ Csr<ValueType, IndexType>::create_const_view_for_item(size_type item_id) const
}


template <typename ValueType, typename IndexType>
std::unique_ptr<Csr<ValueType, IndexType>> Csr<ValueType, IndexType>::create(
std::shared_ptr<const Executor> exec, const batch_dim<2>& size,
size_type num_nonzeros_per_item)
{
return std::unique_ptr<Csr>{new Csr{exec, size, num_nonzeros_per_item}};
}


template <typename ValueType, typename IndexType>
std::unique_ptr<Csr<ValueType, IndexType>> Csr<ValueType, IndexType>::create(
std::shared_ptr<const Executor> exec, const batch_dim<2>& size,
array<value_type> values, array<index_type> col_idxs,
array<index_type> row_ptrs)
{
return std::unique_ptr<Csr>{new Csr{exec, size, std::move(values),
std::move(col_idxs),
std::move(row_ptrs)}};
}


template <typename ValueType, typename IndexType>
std::unique_ptr<const Csr<ValueType, IndexType>>
Csr<ValueType, IndexType>::create_const(
Expand Down
28 changes: 28 additions & 0 deletions core/matrix/batch_dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ Dense<ValueType>::create_const_view_for_item(size_type item_id) const
}


template <typename ValueType>
std::unique_ptr<Dense<ValueType>> Dense<ValueType>::create(
std::shared_ptr<const Executor> exec, const batch_dim<2>& size)
{
return std::unique_ptr<Dense>(new Dense{exec, size});
}


template <typename ValueType>
std::unique_ptr<Dense<ValueType>> Dense<ValueType>::create(
std::shared_ptr<const Executor> exec, const batch_dim<2>& size,
array<value_type> values)
{
return std::unique_ptr<Dense>(new Dense{exec, size, std::move(values)});
}


template <typename ValueType>
std::unique_ptr<const Dense<ValueType>> Dense<ValueType>::create_const(
std::shared_ptr<const Executor> exec, const batch_dim<2>& sizes,
Expand All @@ -92,6 +109,17 @@ Dense<ValueType>::Dense(std::shared_ptr<const Executor> exec,
{}


template <typename ValueType>
Dense<ValueType>::Dense(std::shared_ptr<const Executor> exec,
const batch_dim<2>& size, array<value_type> values)
: EnableBatchLinOp<Dense>(exec, size), values_{exec, std::move(values)}
{
// Ensure that the values array has the correct size
auto num_elems = compute_num_elems(size);
GKO_ENSURE_IN_BOUNDS(num_elems, values_.get_size() + 1);
}


template <typename ValueType>
Dense<ValueType>* Dense<ValueType>::apply(
ptr_param<const MultiVector<ValueType>> b,
Expand Down
Loading
Loading