Skip to content

Commit

Permalink
Use rocsparse_*bsrmv for BsrMatrix SpMV when rocSparse enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Apr 5, 2023
1 parent eb5ac42 commit eb1b04e
Show file tree
Hide file tree
Showing 7 changed files with 370 additions and 22 deletions.
39 changes: 39 additions & 0 deletions common/src/KokkosKernels_AlwaysFalse.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSKERNELS_ALWAYSFALSE_HPP
#define KOKKOSKERNELS_ALWAYSFALSE_HPP

#include <type_traits>

/*! \file KokkosKernels_AlwaysFalse.hpp
\brief A convenience type to be used in a static_assert that should always
fail
*/

namespace KokkosKernels {
namespace Impl {

template <typename T>
using always_false = std::false_type;

template <typename T>
inline constexpr bool always_false_v = always_false<T>::value;

} // namespace Impl
} // namespace KokkosKernels

#endif //
22 changes: 10 additions & 12 deletions sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,13 @@ struct SPMV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM, false,
const YScalar &alpha, const AMatrix &A, const XVector &X,
const YScalar &beta, const YVector &Y) {
//
if ((mode[0] == KokkosSparse::NoTranspose[0]) ||
(mode[0] == KokkosSparse::Conjugate[0])) {
bool useConjugate = (mode[0] == KokkosSparse::Conjugate[0]);
if ((mode[0] == NoTranspose[0]) || (mode[0] == Conjugate[0])) {
bool useConjugate = (mode[0] == Conjugate[0]);
return Bsr::spMatVec_no_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
} else if ((mode[0] == KokkosSparse::Transpose[0]) ||
(mode[0] == KokkosSparse::ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == KokkosSparse::ConjugateTranspose[0]);
} else if ((mode[0] == Transpose[0]) ||
(mode[0] == ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == ConjugateTranspose[0]);
return Bsr::spMatVec_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
Expand Down Expand Up @@ -292,14 +291,13 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
}
#endif // KOKKOS_ARCH

if ((mode[0] == KokkosSparse::NoTranspose[0]) ||
(mode[0] == KokkosSparse::Conjugate[0])) {
bool useConjugate = (mode[0] == KokkosSparse::Conjugate[0]);
if ((mode[0] == NoTranspose[0]) || (mode[0] == Conjugate[0])) {
bool useConjugate = (mode[0] == Conjugate[0]);
return Bsr::spMatMultiVec_no_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
} else if ((mode[0] == KokkosSparse::Transpose[0]) ||
(mode[0] == KokkosSparse::ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == KokkosSparse::ConjugateTranspose[0]);
} else if ((mode[0] == Transpose[0]) ||
(mode[0] == ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == ConjugateTranspose[0]);
return Bsr::spMatMultiVec_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
Expand Down
16 changes: 13 additions & 3 deletions sparse/src/KokkosSparse_Utils_rocsparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#ifndef _KOKKOSKERNELS_SPARSEUTILS_ROCSPARSE_HPP
#define _KOKKOSKERNELS_SPARSEUTILS_ROCSPARSE_HPP

#include <type_traits>

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
#include <rocm_version.h>
#include "rocsparse/rocsparse.h"
Expand Down Expand Up @@ -150,21 +152,29 @@ inline rocsparse_datatype rocsparse_compute_type<Kokkos::complex<double>>() {
return rocsparse_datatype_f64_c;
}

template <typename Scalar>
struct kokkos_to_rocsparse_type {
using type = Scalar;
template <typename T, typename E = void>
struct kokkos_to_rocsparse_type;

// for floats, rocsparse uses c++ builtin types
template <typename T>
struct kokkos_to_rocsparse_type<T,
std::enable_if_t<std::is_floating_point_v<T>>> {
using type = T;
};

// translate complex float
template <>
struct kokkos_to_rocsparse_type<Kokkos::complex<float>> {
using type = rocsparse_float_complex;
};

// translate complex double
template <>
struct kokkos_to_rocsparse_type<Kokkos::complex<double>> {
using type = rocsparse_double_complex;
};

// e.g. 5.4 -> 50400
#define KOKKOSSPARSE_IMPL_ROCM_VERSION \
ROCM_VERSION_MAJOR * 10000 + ROCM_VERSION_MINOR * 100 + ROCM_VERSION_PATCH

Expand Down
8 changes: 8 additions & 0 deletions sparse/src/KokkosSparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,14 @@ void spmv(KokkosKernels::Experimental::Controls controls, const char mode[],
}
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
// rocSparse does not support the modes (C), (T), (H)
if constexpr (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::HIPSpace>) {
useFallback = useFallback || (mode[0] != NoTranspose[0]);
}
#endif

if (useFallback) {
// Explicitly call the non-TPL SPMV_BSRMATRIX implementation
std::string label =
Expand Down
65 changes: 64 additions & 1 deletion sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_MKL(Kokkos::complex<double>,
Kokkos::OpenMP)
#endif

#endif
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

// Specialization struct which defines whether a specialization exists
template <class AT, class AO, class AD, class AM, class AS, class XT, class XL,
Expand Down Expand Up @@ -248,6 +248,69 @@ KOKKOSSPARSE_SPMV_MV_BSRMATRIX_TPL_SPEC_AVAIL_MKL(Kokkos::complex<double>,

#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE)

#include "KokkosSparse_Utils_rocsparse.hpp"

#define KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE( \
SCALAR, ORDINAL, OFFSET, LAYOUT, MEMSPACE) \
template <> \
struct spmv_bsrmatrix_tpl_spec_avail< \
const SCALAR, const ORDINAL, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>, const OFFSET, const SCALAR*, \
LAYOUT, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged | Kokkos::RandomAccess>, SCALAR*, \
LAYOUT, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > { \
enum : bool { value = true }; \
};

// These things may also be valid before 5.4, but I haven't tested it.
#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400

KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(float, rocsparse_int,
rocsparse_int,
Kokkos::LayoutLeft,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(double, rocsparse_int,
rocsparse_int,
Kokkos::LayoutLeft,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(float, rocsparse_int,
rocsparse_int,
Kokkos::LayoutRight,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(double, rocsparse_int,
rocsparse_int,
Kokkos::LayoutRight,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(Kokkos::complex<float>,
rocsparse_int,
rocsparse_int,
Kokkos::LayoutLeft,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(Kokkos::complex<double>,
rocsparse_int,
rocsparse_int,
Kokkos::LayoutLeft,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(Kokkos::complex<float>,
rocsparse_int,
rocsparse_int,
Kokkos::LayoutRight,
Kokkos::HIPSpace)
KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(Kokkos::complex<double>,
rocsparse_int,
rocsparse_int,
Kokkos::LayoutRight,
Kokkos::HIPSpace)

#endif // KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400

#undef KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE

#endif // defined(KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE)

} // namespace Impl
} // namespace Experimental
} // namespace KokkosSparse
Expand Down
Loading

0 comments on commit eb1b04e

Please sign in to comment.