Skip to content

Commit

Permalink
use temporary float matrix in AMD
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed May 30, 2023
1 parent c9c0378 commit 5f5590e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions core/reorder/amd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ std::unique_ptr<LinOp> Amd<IndexType>::generate_impl(
const auto exec = this->get_executor();
const auto host_exec = exec->get_master();
const auto num_rows = system_matrix->get_size()[0];
using complex_scalar = matrix::Dense<std::complex<double>>;
using real_scalar = matrix::Dense<double>;
using complex_identity = matrix::Identity<std::complex<double>>;
using real_identity = matrix::Identity<double>;
using complex_mtx = matrix::Csr<std::complex<double>, IndexType>;
using real_mtx = matrix::Csr<double, IndexType>;
using sparsity_mtx = matrix::SparsityCsr<double, IndexType>;
using complex_scalar = matrix::Dense<std::complex<float>>;
using real_scalar = matrix::Dense<float>;
using complex_identity = matrix::Identity<std::complex<float>>;
using real_identity = matrix::Identity<float>;
using complex_mtx = matrix::Csr<std::complex<float>, IndexType>;
using real_mtx = matrix::Csr<float, IndexType>;
using sparsity_mtx = matrix::SparsityCsr<float, IndexType>;
std::unique_ptr<LinOp> converted;
// extract row pointers and column indices
IndexType* d_row_ptrs{};
Expand All @@ -143,7 +143,7 @@ std::unique_ptr<LinOp> Amd<IndexType>::generate_impl(
}
if (!parameters_.skip_symmetrize) {
auto scalar =
initialize<complex_scalar>({one<std::complex<double>>()}, exec);
initialize<complex_scalar>({one<std::complex<float>>()}, exec);
auto id = complex_identity::create(exec, conv_csr->get_size()[0]);
// compute A^T + A
conv_csr->transpose()->apply(scalar, id, scalar, conv_csr);
Expand All @@ -159,7 +159,7 @@ std::unique_ptr<LinOp> Amd<IndexType>::generate_impl(
conv_csr->sort_by_column_index();
}
if (!parameters_.skip_symmetrize) {
auto scalar = initialize<real_scalar>({one<double>()}, exec);
auto scalar = initialize<real_scalar>({one<float>()}, exec);
auto id = real_identity::create(exec, conv_csr->get_size()[0]);
// compute A^T + A
conv_csr->transpose()->apply(scalar, id, scalar, conv_csr);
Expand Down

0 comments on commit 5f5590e

Please sign in to comment.