Skip to content

Commit

Permalink
Merge pull request #1131 from lucbv/gemm_on_streams
Browse files Browse the repository at this point in the history
Stream interface: adding stream support in GEMV and GEMM
  • Loading branch information
lucbv authored Oct 25, 2021
2 parents f631acb + 064249b commit 7f72a92
Show file tree
Hide file tree
Showing 12 changed files with 842 additions and 529 deletions.
39 changes: 34 additions & 5 deletions src/blas/KokkosBlas2_gemv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ namespace KokkosBlas {
/// \tparam AlphaCoeffType Type of input coefficient alpha
/// \tparam BetaCoeffType Type of input coefficient beta
///
/// \param space [in] execution space instance on which to run the
/// kernel. This may contain information about which stream to
/// run on.
/// \param trans [in] "N" for non-transpose, "T" for transpose, "C"
/// for conjugate transpose. All characters after the first are
/// ignored. This works just like the BLAS routines.
Expand All @@ -72,9 +75,10 @@ namespace KokkosBlas {
/// \param beta [in] Input coefficient of y
/// \param y [in/out] Output vector, as a nonconst 1-D Kokkos::View
template <class AViewType, class XViewType, class YViewType>
void gemv(const char trans[], typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta, const YViewType& y) {
void gemv(const typename AViewType::execution_space& space, const char trans[],
typename AViewType::const_value_type& alpha, const AViewType& A,
const XViewType& x, typename YViewType::const_value_type& beta,
const YViewType& y) {
static_assert(Kokkos::Impl::is_view<AViewType>::value,
"AViewType must be a Kokkos::View.");
static_assert(Kokkos::Impl::is_view<XViewType>::value,
Expand Down Expand Up @@ -144,13 +148,38 @@ void gemv(const char trans[], typename AViewType::const_value_type& alpha,
const bool eti_spec_avail =
KokkosBlas::Impl::gemv_eti_spec_avail<AVT, XVT, YVT>::value;
typedef Impl::GEMV<AVT, XVT, YVT, false, eti_spec_avail> fallback_impl_type;
fallback_impl_type::gemv(trans, alpha, A, x, beta, y);
fallback_impl_type::gemv(space, trans, alpha, A, x, beta, y);
} else {
typedef Impl::GEMV<AVT, XVT, YVT> impl_type;
impl_type::gemv(trans, alpha, A, x, beta, y);
impl_type::gemv(space, trans, alpha, A, x, beta, y);
}
}

/// \brief Dense matrix-vector multiply: y = beta*y + alpha*A*x.
///
/// \tparam AViewType Input matrix, as a 2-D Kokkos::View
/// \tparam XViewType Input vector, as a 1-D Kokkos::View
/// \tparam YViewType Output vector, as a nonconst 1-D Kokkos::View
/// \tparam AlphaCoeffType Type of input coefficient alpha
/// \tparam BetaCoeffType Type of input coefficient beta
///
/// \param trans [in] "N" for non-transpose, "T" for transpose, "C"
/// for conjugate transpose. All characters after the first are
/// ignored. This works just like the BLAS routines.
/// \param alpha [in] Input coefficient of A*x
/// \param A [in] Input matrix, as a 2-D Kokkos::View
/// \param x [in] Input vector, as a 1-D Kokkos::View
/// \param beta [in] Input coefficient of y
/// \param y [in/out] Output vector, as a nonconst 1-D Kokkos::View
template <class AViewType, class XViewType, class YViewType>
void gemv(const char trans[], typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta, const YViewType& y) {
const typename AViewType::execution_space space =
typename AViewType::execution_space();
gemv(space, trans, alpha, A, x, beta, y);
}

} // namespace KokkosBlas

#endif // KOKKOS_BLAS2_MV_HPP_
52 changes: 41 additions & 11 deletions src/blas/KokkosBlas3_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ namespace Impl {
// cuBLAS.
template <class AViewType, class BViewType, class CViewType>
bool gemv_based_gemm(
const char transA[], const char transB[],
typename AViewType::const_value_type& alpha, const AViewType& A,
const BViewType& B, typename CViewType::const_value_type& beta,
const CViewType& C,
const typename CViewType::execution_space& space, const char transA[],
const char transB[], typename AViewType::const_value_type& alpha,
const AViewType& A, const BViewType& B,
typename CViewType::const_value_type& beta, const CViewType& C,
typename std::enable_if<!std::is_same<typename BViewType::array_layout,
Kokkos::LayoutStride>::value &&
!std::is_same<typename CViewType::array_layout,
Expand All @@ -91,7 +91,7 @@ bool gemv_based_gemm(
typename CViewType::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>
Cvec(C.data(), C.extent(0));
KokkosBlas::gemv("N", alpha, A, Bvec, beta, Cvec);
KokkosBlas::gemv(space, "N", alpha, A, Bvec, beta, Cvec);
return true;
}
return false;
Expand All @@ -102,6 +102,7 @@ bool gemv_based_gemm(
// tests.
template <class AViewType, class BViewType, class CViewType>
bool gemv_based_gemm(
const typename CViewType::execution_space& /*space*/,
const char /*transA*/[], const char /*transB*/[],
typename AViewType::const_value_type& /*alpha*/, const AViewType& /*A*/,
const BViewType& /*B*/, typename CViewType::const_value_type& /*beta*/,
Expand All @@ -121,6 +122,7 @@ bool gemv_based_gemm(
/// \tparam BViewType Input matrix, as a 2-D Kokkos::View
/// \tparam CViewType Output matrix, as a nonconst 2-D Kokkos::View
///
/// \param space [in] an execution space instance
/// \param transA [in] "N" for non-transpose, "T" for transpose, "C"
/// for conjugate transpose. All characters after the first are
/// ignored. This works just like the BLAS routines.
Expand All @@ -133,10 +135,10 @@ bool gemv_based_gemm(
/// \param beta [in] Input coefficient of C
/// \param C [in/out] Output vector, as a nonconst 2-D Kokkos::View
template <class AViewType, class BViewType, class CViewType>
void gemm(const char transA[], const char transB[],
typename AViewType::const_value_type& alpha, const AViewType& A,
const BViewType& B, typename CViewType::const_value_type& beta,
const CViewType& C) {
void gemm(const typename CViewType::execution_space& space, const char transA[],
const char transB[], typename AViewType::const_value_type& alpha,
const AViewType& A, const BViewType& B,
typename CViewType::const_value_type& beta, const CViewType& C) {
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
static_assert(Kokkos::Impl::is_view<AViewType>::value,
"AViewType must be a Kokkos::View.");
Expand Down Expand Up @@ -203,7 +205,8 @@ void gemm(const char transA[], const char transB[],
}

// Check if gemv code path is allowed and profitable, and if so run it.
if (Impl::gemv_based_gemm(transA, transB, alpha, A, B, beta, C)) return;
if (Impl::gemv_based_gemm(space, transA, transB, alpha, A, B, beta, C))
return;

// Minimize the number of Impl::GEMM instantiations, by
// standardizing on particular View specializations for its template
Expand All @@ -222,7 +225,34 @@ void gemm(const char transA[], const char transB[],
Kokkos::MemoryTraits<Kokkos::Unmanaged>>
CVT;
typedef Impl::GEMM<AVT, BVT, CVT> impl_type;
impl_type::gemm(transA, transB, alpha, A, B, beta, C);
impl_type::gemm(space, transA, transB, alpha, A, B, beta, C);
}

/// \brief Dense matrix-matrix multiply: C = beta*C + alpha*op(A)*op(B).
///
/// \tparam AViewType Input matrix, as a 2-D Kokkos::View
/// \tparam BViewType Input matrix, as a 2-D Kokkos::View
/// \tparam CViewType Output matrix, as a nonconst 2-D Kokkos::View
///
/// \param transA [in] "N" for non-transpose, "T" for transpose, "C"
/// for conjugate transpose. All characters after the first are
/// ignored. This works just like the BLAS routines.
/// \param transB [in] "N" for non-transpose, "T" for transpose, "C"
/// for conjugate transpose. All characters after the first are
/// ignored. This works just like the BLAS routines.
/// \param alpha [in] Input coefficient of A*x
/// \param A [in] Input matrix, as a 2-D Kokkos::View
/// \param B [in] Input matrix, as a 2-D Kokkos::View
/// \param beta [in] Input coefficient of C
/// \param C [in/out] Output vector, as a nonconst 2-D Kokkos::View
template <class AViewType, class BViewType, class CViewType>
void gemm(const char transA[], const char transB[],
typename AViewType::const_value_type& alpha, const AViewType& A,
const BViewType& B, typename CViewType::const_value_type& beta,
const CViewType& C) {
const typename CViewType::execution_space space =
typename CViewType::execution_space();
gemm(space, transA, transB, alpha, A, B, beta, C);
}

} // namespace KokkosBlas
Expand Down
28 changes: 16 additions & 12 deletions src/blas/impl/KokkosBlas2_gemv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ struct SingleLevelTransposeGEMV {
// Single-level parallel version of GEMV.
template <class AViewType, class XViewType, class YViewType,
class IndexType = typename AViewType::size_type>
void singleLevelGemv(const char trans[],
void singleLevelGemv(const typename AViewType::execution_space& space,
const char trans[],
typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta,
Expand All @@ -255,7 +256,7 @@ void singleLevelGemv(const char trans[],
using AlphaCoeffType = typename AViewType::non_const_value_type;
using BetaCoeffType = typename YViewType::non_const_value_type;

policy_type range(0, A.extent(0));
policy_type range(space, 0, A.extent(0));
const char tr = trans[0];

// The transpose and conjugate transpose cases where A has zero rows
Expand Down Expand Up @@ -666,7 +667,8 @@ struct TwoLevelTransposeGEMV {
// Two-level parallel version of GEMV.
template <class AViewType, class XViewType, class YViewType,
class IndexType = typename AViewType::size_type>
void twoLevelGemv(const char trans[],
void twoLevelGemv(const typename AViewType::execution_space& space,
const char trans[],
typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta,
Expand Down Expand Up @@ -717,7 +719,7 @@ void twoLevelGemv(const char trans[],
IndexType>;
functor_type functor(alpha, A, x, beta, y);
Kokkos::parallel_for("KokkosBlas::gemv[SingleLevel]",
range_policy_type(0, y.extent(0)), functor);
range_policy_type(space, 0, y.extent(0)), functor);
}
return;
}
Expand Down Expand Up @@ -747,11 +749,11 @@ void twoLevelGemv(const char trans[],
if ((size_t)teamSize > 32 * A.extent(1)) teamSize = 32 * A.extent(1);
int numBlocks = teamSize / 32;
functor.columnsPerThread = (A.extent(1) + numBlocks - 1) / numBlocks;
team = tagged_policy(numTeams, teamSize)
team = tagged_policy(space, numTeams, teamSize)
.set_scratch_size(0, Kokkos::PerTeam(sharedPerTeam));
} else {
// LayoutRight: one team per row
team = tagged_policy(A.extent(0), Kokkos::AUTO);
team = tagged_policy(space, A.extent(0), Kokkos::AUTO);
}
Kokkos::parallel_for("KokkosBlas::gemv[twoLevel]", team, functor);
} else {
Expand All @@ -762,15 +764,15 @@ void twoLevelGemv(const char trans[],
// Do nothing (y := 1 * y)
} else if (tr == 'T') {
// transpose, and not conj transpose
team_policy_type team(A.extent(1), Kokkos::AUTO);
team_policy_type team(space, A.extent(1), Kokkos::AUTO);
using functor_type = TwoLevelTransposeGEMV<AViewType, XViewType,
YViewType, false, IndexType>;
functor_type functor(alpha, A, x, beta, y);
Kokkos::parallel_for("KokkosBlas::gemv[twoLevelTranspose]", team,
functor);
} else if (tr == 'C' || tr == 'H') {
// conjugate transpose
team_policy_type team(A.extent(1), Kokkos::AUTO);
team_policy_type team(space, A.extent(1), Kokkos::AUTO);
using functor_type = TwoLevelTransposeGEMV<AViewType, XViewType,
YViewType, true, IndexType>;
functor_type functor(alpha, A, x, beta, y);
Expand All @@ -786,23 +788,25 @@ void twoLevelGemv(const char trans[],
template <class AViewType, class XViewType, class YViewType, class IndexType,
typename std::enable_if<!KokkosKernels::Impl::kk_is_gpu_exec_space<
typename AViewType::execution_space>()>::type* = nullptr>
void generalGemvImpl(const char trans[],
void generalGemvImpl(const typename AViewType::execution_space& space,
const char trans[],
typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta,
const YViewType& y) {
singleLevelGemv(trans, alpha, A, x, beta, y);
singleLevelGemv(space, trans, alpha, A, x, beta, y);
}

template <class AViewType, class XViewType, class YViewType, class IndexType,
typename std::enable_if<KokkosKernels::Impl::kk_is_gpu_exec_space<
typename AViewType::execution_space>()>::type* = nullptr>
void generalGemvImpl(const char trans[],
void generalGemvImpl(const typename AViewType::execution_space& space,
const char trans[],
typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta,
const YViewType& y) {
twoLevelGemv(trans, alpha, A, x, beta, y);
twoLevelGemv(space, trans, alpha, A, x, beta, y);
}

} // namespace Impl
Expand Down
11 changes: 6 additions & 5 deletions src/blas/impl/KokkosBlas2_gemv_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ template <class AViewType, class XViewType, class YViewType,
bool eti_spec_avail =
gemv_eti_spec_avail<AViewType, XViewType, YViewType>::value>
struct GEMV {
static void gemv(const char trans[],
static void gemv(const typename AViewType::execution_space& space,
const char trans[],
typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta,
Expand Down Expand Up @@ -130,11 +131,11 @@ struct GEMV {
// Prefer int as the index type, but use a larger type if needed.
if (numRows < static_cast<size_type>(INT_MAX) &&
numCols < static_cast<size_type>(INT_MAX)) {
generalGemvImpl<AViewType, XViewType, YViewType, int>(trans, alpha, A, x,
beta, y);
generalGemvImpl<AViewType, XViewType, YViewType, int>(space, trans, alpha,
A, x, beta, y);
} else {
generalGemvImpl<AViewType, XViewType, YViewType, int64_t>(trans, alpha, A,
x, beta, y);
generalGemvImpl<AViewType, XViewType, YViewType, int64_t>(
space, trans, alpha, A, x, beta, y);
}
Kokkos::Profiling::popRegion();
}
Expand Down
7 changes: 4 additions & 3 deletions src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct DotBasedGEMM {
numCcols(C.extent(1)),
dotSize(A.extent(0)) {}

void run(bool conjugateTranspose) {
void run(const typename CV::execution_space& space, bool conjugateTranspose) {
// NOTE: these workPerTeam and approxNumTeams were used for TPL CUBLAS,
// and may need to be retuned for other architectures
constexpr size_C workPerTeam = 4096; // Amount of work per team
Expand Down Expand Up @@ -143,11 +143,12 @@ struct DotBasedGEMM {

// Multiply alpha*A^TB and add it to beta*C
if (conjugateTranspose) {
Kokkos::TeamPolicy<TagMultCT, ExecSpace> policyMult(numTeams,
Kokkos::TeamPolicy<TagMultCT, ExecSpace> policyMult(space, numTeams,
Kokkos::AUTO);
Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this);
} else {
Kokkos::TeamPolicy<TagMult, ExecSpace> policyMult(numTeams, Kokkos::AUTO);
Kokkos::TeamPolicy<TagMult, ExecSpace> policyMult(space, numTeams,
Kokkos::AUTO);
Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this);
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/blas/impl/KokkosBlas3_gemm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ struct GEMMImpl {
beta = beta_;
}

void run(int team_size, int vector_length, int scr_level) {
void run(const ExecSpace& space, int team_size, int vector_length,
int scr_level) {
scratch_level = scr_level;
int scratch_memory_size = ViewTypeAScratch::shmem_size() +
ViewTypeBScratch::shmem_size() +
Expand All @@ -645,10 +646,10 @@ struct GEMMImpl {
// that problem but I'm not sure if that it a good perf
// parameter or why it is set to 2 for Cuda?
Kokkos::TeamPolicy<ExecSpace, Kokkos::LaunchBounds<384, 0>> policy(
num_blocks_0 * num_blocks_1, team_size, vector_length);
space, num_blocks_0 * num_blocks_1, team_size, vector_length);
#else
Kokkos::TeamPolicy<ExecSpace, Kokkos::LaunchBounds<384, 2>> policy(
num_blocks_0 * num_blocks_1, team_size, vector_length);
space, num_blocks_0 * num_blocks_1, team_size, vector_length);
#endif

Kokkos::parallel_for(
Expand Down
Loading

0 comments on commit 7f72a92

Please sign in to comment.