Skip to content

Commit

Permalink
Fix spadd (simplied interface) C dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
brian-kelley committed Sep 21, 2020
1 parent 039831b commit 8d98bfb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
21 changes: 14 additions & 7 deletions src/sparse/KokkosSparse_spadd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,15 @@ void spadd_symbolic(KernelHandle* handle, const AMatrix& A, const BMatrix& B,
using entries_type = typename CMatrix::index_type::non_const_type;
using values_type = typename CMatrix::values_type::non_const_type;

//Check that A,B dimensions match
if(A.numRows() != B.numRows() || A.numCols() != B.numCols())
{
char msg[256];
sprintf(msg, "SpAdd: A+B is not defined, since A is %dx%d but B is %dx%d.",
(int) A.numRows(), (int) A.numCols(), (int) B.numRows(), (int) B.numCols());
throw std::invalid_argument(std::string(msg));
}

// Create the row_map of C, no need to initialize it
row_map_type row_mapC(Kokkos::ViewAllocateWithoutInitializing("row map"),
A.numRows() + 1);
Expand All @@ -744,17 +753,14 @@ void spadd_symbolic(KernelHandle* handle, const AMatrix& A, const BMatrix& B,
// views so we can build a graph and then matrix C
// and subsequently construct C.
auto addHandle = handle->get_spadd_handle();
entries_type entriesC(Kokkos::ViewAllocateWithoutInitializing("entries"),
addHandle->get_c_nnz());
graph_type graphC(entriesC, row_mapC);
C = CMatrix("matrix", graphC);
auto c_nnz = addHandle->get_c_nnz();
entries_type entriesC(Kokkos::ViewAllocateWithoutInitializing("Centries"), c_nnz);

// Finally since we already have the number of nnz handy
// we can go ahead and allocate C's values and set them.
values_type valuesC(Kokkos::ViewAllocateWithoutInitializing("values"),
addHandle->get_c_nnz());
values_type valuesC(Kokkos::ViewAllocateWithoutInitializing("Cvalues"), c_nnz);

C.values = valuesC;
C = CMatrix("C=Sum", A.numRows(), A.numCols(), c_nnz, valuesC, row_mapC, entriesC);
}

// Symbolic: count entries in each row in C to produce rowmap
Expand Down Expand Up @@ -784,6 +790,7 @@ crsMat_t spadd(const crsMat_t& A, const crsMat_t& B,
using KernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<size_type, lno_t, scalar_t, exec_space, mem_space, mem_space>;
KernelHandle kh;
kh.create_spadd_handle(bothSorted);
crsMat_t C;
spadd_symbolic(&kh, A, B, C);
spadd_numeric(&kh, alpha, A, beta, B, C);
return C;
Expand Down
8 changes: 4 additions & 4 deletions unit_test/sparse/Test_Sparse_gauss_seidel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,13 @@ void create_problem(int numRows, int num_vecs, bool symmetric, crsMat_t& A, vec_
{
//Symmetrize on host, rather than relying on the parallel versions (those can be tested for symmetric=false)
crsMat_t A_trans = KokkosKernels::Impl::transpose_matrix(A);
A = KokkosSparse::Experimental::spadd(A, A_trans);
A = KokkosSparse::spadd(A, A_trans, false);
}
//Create random LHS vector (x)
x = vec_t(Kokkos::ViewAllocateWithoutInitializing("X"), A.numCols());
create_x_vector(x);
//do a SPMV to find the RHS vector (y)
y = vec_t(Kokkos::ViewAllocateWithoutInitializing("Y"), A.numCols());
y = vec_t(Kokkos::ViewAllocateWithoutInitializing("Y"), A.numRows());
create_y_vector(A, x, y);
}

Expand All @@ -268,13 +268,13 @@ void create_problem(int numRows, int num_vecs, bool symmetric, crsMat_t& A, vec_
{
//Symmetrize on host, rather than relying on the parallel versions (those can be tested for symmetric=false)
crsMat_t A_trans = KokkosKernels::Impl::transpose_matrix(A);
A = KokkosSparse::Experimental::spadd(A, A_trans);
A = KokkosSparse::spadd(A, A_trans, false);
}
//Create random LHS vector (x)
x = vec_t(Kokkos::ViewAllocateWithoutInitializing("X"), A.numCols(), num_vecs);
create_x_vector(x);
//do a SPMV to find the RHS vector (y)
y = vec_t(Kokkos::ViewAllocateWithoutInitializing("Y"), A.numCols(), num_vecs);
y = vec_t(Kokkos::ViewAllocateWithoutInitializing("Y"), A.numRows(), num_vecs);
create_y_vector(A, x, y);
}

Expand Down

0 comments on commit 8d98bfb

Please sign in to comment.