Skip to content

Commit

Permalink
Explicitly cast zero to the y vector's scalar type
Browse files: browse the repository at this point in the history
  • Loading branch information
cwpearson committed May 6, 2024
1 parent 20f9fde commit aa8440c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
18 changes: 10 additions & 8 deletions sparse/impl/KokkosSparse_spmv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,17 +450,18 @@ static void spmv_beta_transpose(const execution_space& exec,
const AMatrix& A, const XVector& x,
typename YVector::const_value_type& beta,
const YVector& y) {
using ordinal_type = typename AMatrix::non_const_ordinal_type;
using size_type = typename AMatrix::non_const_size_type;
using ordinal_type = typename AMatrix::non_const_ordinal_type;
using size_type = typename AMatrix::non_const_size_type;
using y_scalar_type = typename YVector::non_const_value_type;

if (A.numRows() <= static_cast<ordinal_type>(0)) {
return;
}

// We need to scale y first ("scaling" by zero just means filling
// with zeros), since the functor works by atomic-adding into y.
if (0 == dobeta || 0 == beta) {
Kokkos::deep_copy(exec, y, 0);
if (0 == dobeta || y_scalar_type(0) == beta) {
Kokkos::deep_copy(exec, y, y_scalar_type(0));
} else if (dobeta != 1) {
KokkosBlas::scal(exec, y, beta, y);
}
Expand Down Expand Up @@ -542,17 +543,18 @@ static void spmv_beta_transpose(const execution_space& exec,
const AMatrix& A, const XVector& x,
typename YVector::const_value_type& beta,
const YVector& y) {
using ordinal_type = typename AMatrix::non_const_ordinal_type;
using size_type = typename AMatrix::non_const_size_type;
using ordinal_type = typename AMatrix::non_const_ordinal_type;
using size_type = typename AMatrix::non_const_size_type;
using y_scalar_type = typename YVector::non_const_value_type;

if (A.numRows() <= static_cast<ordinal_type>(0)) {
return;
}

// We need to scale y first ("scaling" by zero just means filling
// with zeros), since the functor works by atomic-adding into y.
if (0 == dobeta || 0 == beta) {
Kokkos::deep_copy(exec, y, 0);
if (0 == dobeta || y_scalar_type(0) == beta) {
Kokkos::deep_copy(exec, y, y_scalar_type(0));
} else if (dobeta != 1) {
KokkosBlas::scal(exec, y, beta, y);
}
Expand Down
4 changes: 2 additions & 2 deletions sparse/impl/KokkosSparse_spmv_impl_merge.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ struct SpmvMergeHierarchical {
static_assert(XVector::rank == 1, "");
static_assert(YVector::rank == 1, "");

if (0 == beta) {
Kokkos::deep_copy(space, y, 0);
if (y_value_type(0) == beta) {
Kokkos::deep_copy(space, y, y_value_type(0));
} else {
KokkosBlas::scal(space, y, beta, y);
}
Expand Down
5 changes: 3 additions & 2 deletions sparse/unit_test/Test_Sparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ void sequential_spmv(crsMat_t input_mat, x_vector_type x, y_vector_type y,
using size_type_view_t = typename graph_t::row_map_type;
using lno_view_t = typename graph_t::entries_type;
using scalar_view_t = typename crsMat_t::values_type::non_const_type;
using y_scalar_t = typename y_vector_type::non_const_value_type;

using size_type = typename size_type_view_t::non_const_value_type;
using lno_t = typename lno_view_t::non_const_value_type;
Expand Down Expand Up @@ -149,8 +150,8 @@ void sequential_spmv(crsMat_t input_mat, x_vector_type x, y_vector_type y,

// first, scale y by beta
for (size_t i = 0; i < h_y.extent(0); i++) {
if (beta == 0) {
h_y(i) = 0;
if (beta == y_scalar_t(0)) {
h_y(i) = y_scalar_t(0);
} else {
h_y(i) *= beta;
}
Expand Down

0 comments on commit aa8440c

Please sign in to comment.