Skip to content

Commit

Permalink
Fix switching issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
uhetmaniuk committed Jan 10, 2022
1 parent b661667 commit 5b7a0f4
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 49 deletions.
6 changes: 1 addition & 5 deletions src/sparse/KokkosSparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,7 @@ void spmv(KokkosKernels::Experimental::Controls controls, const char mode[],
Kokkos::CudaSpace>::value ||
std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::CudaUVMSpace>::value) {
#if defined(CUSPARSE_VERSION)
useFallback = useFallback || (mode[0] != NoTranspose[0]);
#endif
}
#endif

Expand All @@ -465,7 +463,7 @@ void spmv(KokkosKernels::Experimental::Controls controls, const char mode[],
if (useFallback) {
// Explicitly call the non-TPL SPMV_BSRMATRIX implementation
std::string label =
"KokkosSparse::spmv[NATIVE,BSMATRIX," +
"KokkosSparse::spmv[NATIVE,BSRMATRIX," +
Kokkos::ArithTraits<
typename AMatrix_Internal::non_const_value_type>::name() +
"]";
Expand Down Expand Up @@ -844,9 +842,7 @@ void spmv(KokkosKernels::Experimental::Controls controls, const char mode[],
Kokkos::CudaSpace>::value ||
std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::CudaUVMSpace>::value) {
#if defined(CUSPARSE_VERSION)
useFallback = useFallback || (mode[0] != NoTranspose[0]);
#endif
}
#endif

Expand Down
104 changes: 60 additions & 44 deletions src/sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,85 +215,101 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
requestDouble = true;
}
}
//
bool use_tc = false;
if ((controls.isParameter("algorithm")) && (controls.getParameter("algorithm") == "experim
if (Kokkos::Details::ArithTraits<YScalar>::is_complex == false)
use_tc = true;
}
#endif
#if defined(KOKKOS_ARCH_AMPERE)
typedef typename XVector::non_const_value_type XScalar;
typedef typename AMatrix::non_const_value_type AScalar;
typedef Kokkos::Experimental::half_t Half;
typedef typename XVector::non_const_value_type XScalar;
typedef typename AMatrix::non_const_value_type AScalar;
typedef Kokkos::Experimental::half_t Half;
/* Ampere has double += double * double and float += half * half
/* Ampere has double += double * double and float += half * half
use whichever is requested.
If none requested, use mixed precision if the inputs are mixed, otherwise
use double
*/
use whichever is requested.
If none requested, use mixed precision if the inputs are mixed, otherwise
use double
*/
// input precision matches a tensor core fragment type
constexpr bool operandsHalfHalfFloat = std::is_same<AScalar, Half>::value &&
std::is_same<XScalar, Half>::value &&
std::is_same<YScalar, float>::value;
// input precision matches a tensor core fragment type
constexpr bool operandsHalfHalfFloat = std::is_same<AScalar, Half>::value &&
std::is_same<XScalar, Half>::value &&
std::is_same<YScalar, float>::value;
if (use_tc) {
if (requestMixed) {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, half, XVector, half, YVector,
float, 16, 16, 16>::dispatch(alpha, A,
X, beta,
Y);
return;
} else if (requestDouble) {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, double, XVector, double,
YVector, double, 8, 8,
4>::dispatch(alpha, A, X, beta, Y);
return;
} else if (operandsHalfHalfFloat) {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, half, XVector, half, YVector,
float, 16, 16, 16>::dispatch(alpha, A,
X, beta,
Y);
return;
} else {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, double, XVector, double,
YVector, double, 8, 8,
4>::dispatch(alpha, A, x, beta, y);
return;
}
}
#elif defined(KOKKOS_ARCH_VOLTA)
/* Volta has float += half * half
use it for all matrices
*/
if (requestDouble) {
Kokkos::Impl::throw_runtime_exception(
"KokkosSparse::spmv[algorithm=experimental_bsr_tc] "
"tc_precision=double unsupported KOKKOS_ARCH_VOLTA");
if (use_tc) {
if (requestDouble) {
Kokkos::Impl::throw_runtime_exception(
"KokkosSparse::spmv[algorithm=experimental_bsr_tc] "
"tc_precision=double unsupported KOKKOS_ARCH_VOLTA");
}
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, half, XVector, half, YVector,
float, 16, 16, 16>::dispatch(alpha, A,
X, beta,
Y);
(void)requestMixed; // unused
return;
}
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, half, XVector, half, YVector,
float, 16, 16, 16>::dispatch(alpha, A, X,
beta, Y);
(void)requestMixed; // unused
#endif // KOKKOS_ARCH
if ((mode[0] == KokkosSparse::NoTranspose[0]) ||
(mode[0] == KokkosSparse::Conjugate[0])) {
bool useConjugate = (mode[0] == KokkosSparse::Conjugate[0]);
if (X.extent(1) == 1) {
const auto x0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto y0 = Kokkos::subview(Y, Kokkos::ALL(), 0);
return Bsr::spMatVec_no_transpose(controls, alpha, A, x0, beta, y0,
if ((mode[0] == KokkosSparse::NoTranspose[0]) ||
(mode[0] == KokkosSparse::Conjugate[0])) {
bool useConjugate = (mode[0] == KokkosSparse::Conjugate[0]);
if (X.extent(1) == 1) {
const auto x0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto y0 = Kokkos::subview(Y, Kokkos::ALL(), 0);
return Bsr::spMatVec_no_transpose(controls, alpha, A, x0, beta, y0,
useConjugate);
} else {
return Bsr::spMatMultiVec_no_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
} else if ((mode[0] == KokkosSparse::Transpose[0]) ||
(mode[0] == KokkosSparse::ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == KokkosSparse::ConjugateTranspose[0]);
if (X.extent(1) == 1) {
const auto x0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto y0 = Kokkos::subview(Y, Kokkos::ALL(), 0);
return Bsr::spMatVec_transpose(controls, alpha, A, x0, beta, y0,
useConjugate);
} else {
return Bsr::spMatMultiVec_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
} else {
return Bsr::spMatMultiVec_no_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
} else if ((mode[0] == KokkosSparse::Transpose[0]) ||
(mode[0] == KokkosSparse::ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == KokkosSparse::ConjugateTranspose[0]);
if (X.extent(1) == 1) {
const auto x0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto y0 = Kokkos::subview(Y, Kokkos::ALL(), 0);
return Bsr::spMatVec_transpose(controls, alpha, A, x0, beta, y0,
useConjugate);
} else {
return Bsr::spMatMultiVec_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
}
}
}
};
template <class AT, class AO, class AD, class AM, class AS, class XT, class XL,
Expand Down Expand Up @@ -324,9 +340,9 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
};
#endif // !defined(KOKKOSKERNELS_ETI_ONLY) ||
// KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
} // namespace Impl
} // namespace Experimental
} // namespace KokkosSparse
} // namespace KokkosSparse
// declare / instantiate the vector version
// Instantiate with A,x,y are all the requested Scalar type (no instantiation of
Expand Down

0 comments on commit 5b7a0f4

Please sign in to comment.