Fix typo
uhetmaniuk committed Jan 10, 2022
1 parent 5b7a0f4 commit a20a543
Showing 1 changed file with 63 additions and 62 deletions.

src/sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp
@@ -217,54 +217,55 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
}
//
bool use_tc = false;
if ((controls.isParameter("algorithm")) &&
(controls.getParameter("algorithm") == "experimental_bsr_tc")) {
if (Kokkos::Details::ArithTraits<YScalar>::is_complex == false)
use_tc = true;
}
}
#endif

#if defined(KOKKOS_ARCH_AMPERE)
typedef typename XVector::non_const_value_type XScalar;
typedef typename AMatrix::non_const_value_type AScalar;
typedef Kokkos::Experimental::half_t Half;

/* Ampere has double += double * double and float += half * half
use whichever is requested.
If none requested, used mixed precision if the inputs are mixed, otherwise
use double
*/

// input precision matches a tensor core fragment type
constexpr bool operandsHalfHalfFloat = std::is_same<AScalar, Half>::value &&
std::is_same<XScalar, Half>::value &&
std::is_same<YScalar, float>::value;

if (use_tc) {
if (requestMixed) {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, half, XVector, half, YVector,
float, 16, 16, 16>::dispatch(alpha, A,
X, beta,
Y);
return;
} else if (requestDouble) {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, double, XVector, double,
YVector, double, 8, 8,
4>::dispatch(alpha, A, X, beta, Y);
return;
} else if (operandsHalfHalfFloat) {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, half, XVector, half, YVector,
float, 16, 16, 16>::dispatch(alpha, A,
X, beta,
Y);
return;
} else {
BsrMatrixSpMVTensorCoreDispatcher<AMatrix, double, XVector, double,
YVector, double, 8, 8,
4>::dispatch(alpha, A, x, beta, y);
return;
}
}
}
#elif defined(KOKKOS_ARCH_VOLTA)
/* Volta has float += half * half
use it for all matrices
@@ -284,32 +285,32 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
}
#endif // KOKKOS_ARCH

if ((mode[0] == KokkosSparse::NoTranspose[0]) ||
(mode[0] == KokkosSparse::Conjugate[0])) {
bool useConjugate = (mode[0] == KokkosSparse::Conjugate[0]);
if (X.extent(1) == 1) {
const auto x0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto y0 = Kokkos::subview(Y, Kokkos::ALL(), 0);
return Bsr::spMatVec_no_transpose(controls, alpha, A, x0, beta, y0,
useConjugate);
} else {
return Bsr::spMatMultiVec_no_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
} else if ((mode[0] == KokkosSparse::Transpose[0]) ||
(mode[0] == KokkosSparse::ConjugateTranspose[0])) {
bool useConjugate = (mode[0] == KokkosSparse::ConjugateTranspose[0]);
if (X.extent(1) == 1) {
const auto x0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto y0 = Kokkos::subview(Y, Kokkos::ALL(), 0);
return Bsr::spMatVec_transpose(controls, alpha, A, x0, beta, y0,
useConjugate);
} else {
return Bsr::spMatMultiVec_transpose(controls, alpha, A, X, beta, Y,
useConjugate);
}
}
}
}
};

template <class AT, class AO, class AD, class AM, class AS, class XT, class XL,
@@ -340,9 +341,9 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL,
};
#endif // !defined(KOKKOSKERNELS_ETI_ONLY) ||
// KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
} // namespace Impl
} // namespace Experimental
} // namespace KokkosSparse

// declare / instantiate the vector version
// Instantiate with A,x,y are all the requested Scalar type (no instantiation of
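For context beyond the diff itself, the tensor-core branch is opt-in: a caller reaches the `use_tc` path above by setting the "algorithm" control to "experimental_bsr_tc" before invoking spmv, and the dispatcher only honors it for non-complex scalar types under `KOKKOS_ARCH_AMPERE` / `KOKKOS_ARCH_VOLTA` builds. The sketch below illustrates that call pattern; the header names, the Kokkos initialization scaffolding, and the commented-out spmv invocation are assumptions based on the Controls-enabled KokkosSparse::spmv interface of this era, not code taken from this commit.

```cpp
#include <Kokkos_Core.hpp>
#include <KokkosKernels_Controls.hpp>  // assumed header for KokkosKernels::Experimental::Controls
#include <KokkosSparse_spmv.hpp>

int main(int argc, char* argv[]) {
  Kokkos::initialize(argc, argv);
  {
    // Request the experimental tensor-core algorithm that the dispatcher
    // above checks for via controls.getParameter("algorithm").
    KokkosKernels::Experimental::Controls controls;
    controls.setParameter("algorithm", "experimental_bsr_tc");

    // With a BsrMatrix A and Kokkos::Views x and y already assembled, the
    // call would look like the following (left commented out so the sketch
    // stays self-contained without matrix-assembly boilerplate):
    //
    //   const double alpha = 1.0, beta = 0.0;
    //   KokkosSparse::spmv(controls, KokkosSparse::NoTranspose, alpha, A, x, beta, y);
    //
    // Passing KokkosSparse::Transpose or KokkosSparse::ConjugateTranspose
    // instead routes into the transpose branches shown in the second hunk.
  }
  Kokkos::finalize();
  return 0;
}
```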
