Skip to content

Commit

Permalink
Restore BLAS-1 MV paths for 1 column
Browse files Browse the repository at this point in the history
Also: test these paths, test nrm2w, and use 3-arg (async) deep copies in
the >1 column paths of these kernels.
  • Loading branch information
brian-kelley committed Mar 1, 2022
1 parent 56c2398 commit 3cafac3
Show file tree
Hide file tree
Showing 17 changed files with 370 additions and 37 deletions.
5 changes: 3 additions & 2 deletions src/blas/impl/KokkosBlas1_dot_mv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ void MV_Dot_Invoke(
}
// Zero out the result vector
Kokkos::deep_copy(
r, Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
execution_space(), r,
Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
size_type teamsPerDot;
KokkosBlas::Impl::multipleReductionWorkDistribution<execution_space,
size_type>(
Expand All @@ -156,7 +157,7 @@ void MV_Dot_Invoke(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "Dot_MV temp result"),
r.extent(0));
MV_Dot_Invoke<decltype(tempResult), XV, YV, size_type>(tempResult, x, y);
Kokkos::deep_copy(r, tempResult);
Kokkos::deep_copy(typename XV::execution_space(), r, tempResult);
}

} // namespace Impl
Expand Down
45 changes: 38 additions & 7 deletions src/blas/impl/KokkosBlas1_dot_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,20 @@ struct Dot<RV, XV, YV, X_Rank, Y_Rank, false,

typedef typename YV::size_type size_type;

// First-column helper, rank-2 overload: yields the subview of v that spans
// its whole first dimension (Kokkos::ALL()) at index 0 of the second.
// SFINAE restricts this overload to view types with V::rank == 2; a rank-1
// overload exists alongside it so the single-column dot path can accept
// either kind of input.
template <typename V>
static auto getFirstColumn(const V& v,
                           std::enable_if_t<V::rank == 2>* = nullptr) {
  auto firstColumn = Kokkos::subview(v, Kokkos::ALL(), 0);
  return firstColumn;
}

// First-column helper, rank-1 overload: a rank-1 view is already a single
// column, so it is handed back unchanged (returned by value as a copy of v).
// SFINAE restricts this overload to view types with V::rank == 1.
template <typename V>
static V getFirstColumn(const V& v,
                        std::enable_if_t<V::rank == 1>* = nullptr) {
  return v;
}

static void dot(const RV& R, const XV& X, const YV& Y) {
Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
? "KokkosBlas::dot[ETI]"
Expand All @@ -392,14 +406,31 @@ struct Dot<RV, XV, YV, X_Rank, Y_Rank, false,
#endif

const size_type numRows = X.extent(0);
const size_type numCols = X.extent(1);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
typedef int index_type;
MV_Dot_Invoke<RV, XV, YV, index_type>(R, X, Y);
const size_type numDots = std::max(X.extent(1), Y.extent(1));
if (numDots == Kokkos::ArithTraits<size_type>::one()) {
auto R0 = Kokkos::subview(R, 0);
auto X0 = getFirstColumn(X);
auto Y0 = getFirstColumn(Y);
if (numRows < static_cast<size_type>(INT_MAX)) {
typedef int index_type;
DotFunctor<decltype(R0), decltype(X0), decltype(Y0), index_type> f(X0,
Y0);
f.run("KokkosBlas::dot<1D>", R0);
} else {
typedef int64_t index_type;
DotFunctor<decltype(R0), decltype(X0), decltype(Y0), index_type> f(X0,
Y0);
f.run("KokkosBlas::dot<1D>", R0);
}
} else {
typedef std::int64_t index_type;
MV_Dot_Invoke<RV, XV, YV, index_type>(R, X, Y);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numDots < static_cast<size_type>(INT_MAX)) {
typedef int index_type;
MV_Dot_Invoke<RV, XV, YV, index_type>(R, X, Y);
} else {
typedef std::int64_t index_type;
MV_Dot_Invoke<RV, XV, YV, index_type>(R, X, Y);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
5 changes: 3 additions & 2 deletions src/blas/impl/KokkosBlas1_nrm1_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ void MV_Nrm1_Invoke(
}
// Zero out the result vector
Kokkos::deep_copy(
r, Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
execution_space(), r,
Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
size_type teamsPerVec;
KokkosBlas::Impl::multipleReductionWorkDistribution<execution_space,
size_type>(
Expand All @@ -195,7 +196,7 @@ void MV_Nrm1_Invoke(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "Nrm1 temp result"),
r.extent(0));
MV_Nrm1_Invoke<decltype(tempResult), XV, size_type>(tempResult, x);
Kokkos::deep_copy(r, tempResult);
Kokkos::deep_copy(typename XV::execution_space(), r, tempResult);
}

} // namespace Impl
Expand Down
21 changes: 16 additions & 5 deletions src/blas/impl/KokkosBlas1_nrm1_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,23 @@ struct Nrm1<RV, XMV, 2, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
: "KokkosBlas::nrm1[noETI]");
const size_type numRows = X.extent(0);
const size_type numCols = X.extent(1);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Nrm1_Invoke<RV, XMV, int>(R, X);
if (numCols == Kokkos::ArithTraits<size_type>::one()) {
auto R0 = Kokkos::subview(R, 0);
auto X0 = Kokkos::subview(X, Kokkos::ALL(), 0);
if (numRows < static_cast<size_type>(INT_MAX)) {
V_Nrm1_Invoke<decltype(R0), decltype(X0), int>(R0, X0);
} else {
typedef std::int64_t index_type;
V_Nrm1_Invoke<decltype(R0), decltype(X0), index_type>(R0, X0);
}
} else {
typedef std::int64_t index_type;
MV_Nrm1_Invoke<RV, XMV, index_type>(R, X);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Nrm1_Invoke<RV, XMV, int>(R, X);
} else {
typedef std::int64_t index_type;
MV_Nrm1_Invoke<RV, XMV, index_type>(R, X);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
5 changes: 3 additions & 2 deletions src/blas/impl/KokkosBlas1_nrm2_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ void MV_Nrm2_Invoke(
}
// Zero out the result vector
Kokkos::deep_copy(
r, Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
execution_space(), r,
Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
size_type teamsPerVec;
KokkosBlas::Impl::multipleReductionWorkDistribution<execution_space,
size_type>(
Expand Down Expand Up @@ -230,7 +231,7 @@ void MV_Nrm2_Invoke(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "Nrm2 temp result"),
r.extent(0));
MV_Nrm2_Invoke<decltype(tempResult), XV, size_type>(tempResult, x, take_sqrt);
Kokkos::deep_copy(r, tempResult);
Kokkos::deep_copy(typename XV::execution_space(), r, tempResult);
}

} // namespace Impl
Expand Down
22 changes: 17 additions & 5 deletions src/blas/impl/KokkosBlas1_nrm2_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,24 @@ struct Nrm2<RV, XMV, 2, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {

const size_type numRows = X.extent(0);
const size_type numCols = X.extent(1);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Nrm2_Invoke<RV, XMV, int>(R, X, take_sqrt);
if (numCols == Kokkos::ArithTraits<size_type>::one()) {
auto R0 = Kokkos::subview(R, 0);
auto X0 = Kokkos::subview(X, Kokkos::ALL(), 0);
if (numRows < static_cast<size_type>(INT_MAX)) {
V_Nrm2_Invoke<decltype(R0), decltype(X0), int>(R0, X0, take_sqrt);
} else {
typedef std::int64_t index_type;
V_Nrm2_Invoke<decltype(R0), decltype(X0), index_type>(R0, X0,
take_sqrt);
}
} else {
typedef std::int64_t index_type;
MV_Nrm2_Invoke<RV, XMV, index_type>(R, X, take_sqrt);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Nrm2_Invoke<RV, XMV, int>(R, X, take_sqrt);
} else {
typedef std::int64_t index_type;
MV_Nrm2_Invoke<RV, XMV, index_type>(R, X, take_sqrt);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
5 changes: 3 additions & 2 deletions src/blas/impl/KokkosBlas1_nrm2w_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ void MV_Nrm2w_Invoke(
}
// Zero out the result vector
Kokkos::deep_copy(
r, Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
execution_space(), r,
Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
size_type teamsPerVec;
KokkosBlas::Impl::multipleReductionWorkDistribution<execution_space,
size_type>(
Expand Down Expand Up @@ -230,7 +231,7 @@ void MV_Nrm2w_Invoke(
r.extent(0));
MV_Nrm2w_Invoke<decltype(tempResult), XV, size_type>(tempResult, x, w,
take_sqrt);
Kokkos::deep_copy(r, tempResult);
Kokkos::deep_copy(typename XV::execution_space(), r, tempResult);
}

} // namespace Impl
Expand Down
23 changes: 18 additions & 5 deletions src/blas/impl/KokkosBlas1_nrm2w_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,25 @@ struct Nrm2w<RV, XMV, 2, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {

const size_type numRows = X.extent(0);
const size_type numCols = X.extent(1);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Nrm2w_Invoke<RV, XMV, int>(R, X, W, take_sqrt);
if (numCols == 1) {
auto R0 = Kokkos::subview(R, 0);
auto X0 = Kokkos::subview(X, Kokkos::ALL(), 0);
auto W0 = Kokkos::subview(W, Kokkos::ALL(), 0);
if (numRows < static_cast<size_type>(INT_MAX)) {
V_Nrm2w_Invoke<decltype(R0), decltype(X0), int>(R0, X0, W0, take_sqrt);
} else {
typedef std::int64_t index_type;
V_Nrm2w_Invoke<decltype(R0), decltype(X0), index_type>(R0, X0, W0,
take_sqrt);
}
} else {
typedef std::int64_t index_type;
MV_Nrm2w_Invoke<RV, XMV, index_type>(R, X, W, take_sqrt);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Nrm2w_Invoke<RV, XMV, int>(R, X, W, take_sqrt);
} else {
typedef std::int64_t index_type;
MV_Nrm2w_Invoke<RV, XMV, index_type>(R, X, W, take_sqrt);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
5 changes: 3 additions & 2 deletions src/blas/impl/KokkosBlas1_sum_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ void MV_Sum_Invoke(
}
// Zero out the result vector
Kokkos::deep_copy(
r, Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
execution_space(), r,
Kokkos::ArithTraits<typename RV::non_const_value_type>::zero());
size_type teamsPerVec;
KokkosBlas::Impl::multipleReductionWorkDistribution<execution_space,
size_type>(
Expand All @@ -187,7 +188,7 @@ void MV_Sum_Invoke(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "Sum temp result"),
r.extent(0));
MV_Sum_Invoke<decltype(tempResult), XV, size_type>(tempResult, x);
Kokkos::deep_copy(r, tempResult);
Kokkos::deep_copy(typename XV::execution_space(), r, tempResult);
}

} // namespace Impl
Expand Down
21 changes: 16 additions & 5 deletions src/blas/impl/KokkosBlas1_sum_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,23 @@ struct Sum<RV, XMV, 2, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {

const size_type numRows = X.extent(0);
const size_type numCols = X.extent(1);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Sum_Invoke<RV, XMV, int>(R, X);
if (numCols == Kokkos::ArithTraits<size_type>::one()) {
auto R0 = Kokkos::subview(R, 0);
auto X0 = Kokkos::subview(X, Kokkos::ALL(), 0);
if (numRows < static_cast<size_type>(INT_MAX)) {
V_Sum_Invoke<decltype(R0), decltype(X0), int>(R0, X0);
} else {
typedef std::int64_t index_type;
V_Sum_Invoke<decltype(R0), decltype(X0), index_type>(R0, X0);
}
} else {
typedef std::int64_t index_type;
MV_Sum_Invoke<RV, XMV, index_type>(R, X);
if (numRows < static_cast<size_type>(INT_MAX) &&
numRows * numCols < static_cast<size_type>(INT_MAX)) {
MV_Sum_Invoke<RV, XMV, int>(R, X);
} else {
typedef std::int64_t index_type;
MV_Sum_Invoke<RV, XMV, index_type>(R, X);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
1 change: 1 addition & 0 deletions unit_test/blas/Test_Blas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "Test_Blas1_nrm1.hpp"
#include "Test_Blas1_nrm2_squared.hpp"
#include "Test_Blas1_nrm2.hpp"
#include "Test_Blas1_nrm2w.hpp"
#include "Test_Blas1_nrminf.hpp"
#include "Test_Blas1_reciprocal.hpp"
#include "Test_Blas1_scal.hpp"
Expand Down
3 changes: 3 additions & 0 deletions unit_test/blas/Test_Blas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ int test_dot_mv() {
Test::impl_test_dot_mv<view_type_a_ll, view_type_b_ll, Device>(0, 5);
Test::impl_test_dot_mv<view_type_a_ll, view_type_b_ll, Device>(13, 5);
Test::impl_test_dot_mv<view_type_a_ll, view_type_b_ll, Device>(1024, 5);
Test::impl_test_dot_mv<view_type_a_ll, view_type_b_ll, Device>(789, 1);
// Test::impl_test_dot_mv<view_type_a_ll, view_type_b_ll, Device>(132231,5);
#endif

Expand All @@ -207,6 +208,7 @@ int test_dot_mv() {
Test::impl_test_dot_mv<view_type_a_lr, view_type_b_lr, Device>(0, 5);
Test::impl_test_dot_mv<view_type_a_lr, view_type_b_lr, Device>(13, 5);
Test::impl_test_dot_mv<view_type_a_lr, view_type_b_lr, Device>(1024, 5);
Test::impl_test_dot_mv<view_type_a_ll, view_type_b_lr, Device>(789, 1);
// Test::impl_test_dot_mv<view_type_a_lr, view_type_b_lr, Device>(132231,5);
#endif

Expand All @@ -218,6 +220,7 @@ int test_dot_mv() {
Test::impl_test_dot_mv<view_type_a_ls, view_type_b_ls, Device>(0, 5);
Test::impl_test_dot_mv<view_type_a_ls, view_type_b_ls, Device>(13, 5);
Test::impl_test_dot_mv<view_type_a_ls, view_type_b_ls, Device>(1024, 5);
Test::impl_test_dot_mv<view_type_a_ll, view_type_b_ls, Device>(789, 1);
// Test::impl_test_dot_mv<view_type_a_ls, view_type_b_ls, Device>(132231,5);
#endif

Expand Down
3 changes: 3 additions & 0 deletions unit_test/blas/Test_Blas1_nrm1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ int test_nrm1_mv() {
Test::impl_test_nrm1_mv<view_type_a_ll, Device>(0, 5);
Test::impl_test_nrm1_mv<view_type_a_ll, Device>(13, 5);
Test::impl_test_nrm1_mv<view_type_a_ll, Device>(1024, 5);
Test::impl_test_nrm1_mv<view_type_a_ll, Device>(789, 1);
Test::impl_test_nrm1_mv<view_type_a_ll, Device>(132231, 5);
#endif

Expand All @@ -159,6 +160,7 @@ int test_nrm1_mv() {
Test::impl_test_nrm1_mv<view_type_a_lr, Device>(0, 5);
Test::impl_test_nrm1_mv<view_type_a_lr, Device>(13, 5);
Test::impl_test_nrm1_mv<view_type_a_lr, Device>(1024, 5);
Test::impl_test_nrm1_mv<view_type_a_lr, Device>(789, 1);
Test::impl_test_nrm1_mv<view_type_a_lr, Device>(132231, 5);
#endif

Expand All @@ -169,6 +171,7 @@ int test_nrm1_mv() {
Test::impl_test_nrm1_mv<view_type_a_ls, Device>(0, 5);
Test::impl_test_nrm1_mv<view_type_a_ls, Device>(13, 5);
Test::impl_test_nrm1_mv<view_type_a_ls, Device>(1024, 5);
Test::impl_test_nrm1_mv<view_type_a_ls, Device>(789, 1);
Test::impl_test_nrm1_mv<view_type_a_ls, Device>(132231, 5);
#endif

Expand Down
3 changes: 3 additions & 0 deletions unit_test/blas/Test_Blas1_nrm2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ int test_nrm2_mv() {
Test::impl_test_nrm2_mv<view_type_a_ll, Device>(0, 5);
Test::impl_test_nrm2_mv<view_type_a_ll, Device>(13, 5);
Test::impl_test_nrm2_mv<view_type_a_ll, Device>(1024, 5);
Test::impl_test_nrm2_mv<view_type_a_ll, Device>(789, 1);
// Test::impl_test_nrm2_mv<view_type_a_ll, Device>(132231,5);
#endif

Expand All @@ -154,6 +155,7 @@ int test_nrm2_mv() {
Test::impl_test_nrm2_mv<view_type_a_lr, Device>(0, 5);
Test::impl_test_nrm2_mv<view_type_a_lr, Device>(13, 5);
Test::impl_test_nrm2_mv<view_type_a_lr, Device>(1024, 5);
Test::impl_test_nrm2_mv<view_type_a_lr, Device>(789, 1);
// Test::impl_test_nrm2_mv<view_type_a_lr, Device>(132231,5);
#endif

Expand All @@ -164,6 +166,7 @@ int test_nrm2_mv() {
Test::impl_test_nrm2_mv<view_type_a_ls, Device>(0, 5);
Test::impl_test_nrm2_mv<view_type_a_ls, Device>(13, 5);
Test::impl_test_nrm2_mv<view_type_a_ls, Device>(1024, 5);
Test::impl_test_nrm2_mv<view_type_a_ls, Device>(789, 1);
// Test::impl_test_nrm2_mv<view_type_a_ls, Device>(132231,5);
#endif

Expand Down
3 changes: 3 additions & 0 deletions unit_test/blas/Test_Blas1_nrm2_squared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ int test_nrm2_squared_mv() {
Test::impl_test_nrm2_squared_mv<view_type_a_ll, Device>(0, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_ll, Device>(13, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_ll, Device>(1024, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_ll, Device>(789, 1);
// Test::impl_test_nrm2_squared_mv<view_type_a_ll, Device>(132231,5);
#endif

Expand All @@ -170,6 +171,7 @@ int test_nrm2_squared_mv() {
Test::impl_test_nrm2_squared_mv<view_type_a_lr, Device>(0, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_lr, Device>(13, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_lr, Device>(1024, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_lr, Device>(789, 1);
// Test::impl_test_nrm2_squared_mv<view_type_a_lr, Device>(132231,5);
#endif

Expand All @@ -180,6 +182,7 @@ int test_nrm2_squared_mv() {
Test::impl_test_nrm2_squared_mv<view_type_a_ls, Device>(0, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_ls, Device>(13, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_ls, Device>(1024, 5);
Test::impl_test_nrm2_squared_mv<view_type_a_ls, Device>(789, 1);
// Test::impl_test_nrm2_squared_mv<view_type_a_ls, Device>(132231,5);
#endif

Expand Down
Loading

0 comments on commit 3cafac3

Please sign in to comment.