Skip to content

Commit

Permalink
Use default_size_type as default offset in matrix types (#2149)
Browse files Browse the repository at this point in the history
Now a declaration like CrsMatrix<Scalar, Ordinal, Device>
will by default use an ETI'd type combination (as int is the default
ETI'd offset)
  • Loading branch information
brian-kelley authored Mar 26, 2024
1 parent 363868e commit 8756faa
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 17 deletions.
5 changes: 2 additions & 3 deletions sparse/src/KokkosSparse_BsrMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "Kokkos_ArithTraits.hpp"
#include "KokkosSparse_CrsMatrix.hpp"
#include "KokkosKernels_Error.hpp"
#include "KokkosKernels_default_types.hpp"

namespace KokkosSparse {

Expand Down Expand Up @@ -325,9 +326,7 @@ struct BsrRowViewConst {
/// storage for sparse matrices, as described, for example, in Saad
/// (2nd ed.).
template <class ScalarType, class OrdinalType, class Device,
class MemoryTraits = void,
class SizeType = typename Kokkos::ViewTraits<OrdinalType*, Device,
void, void>::size_type>
class MemoryTraits = void, class SizeType = default_size_type>
class BsrMatrix {
static_assert(
std::is_signed<OrdinalType>::value,
Expand Down
4 changes: 1 addition & 3 deletions sparse/src/KokkosSparse_CrsMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,7 @@ struct SparseRowViewConst {
/// storage for sparse matrices, as described, for example, in Saad
/// (2nd ed.).
template <class ScalarType, class OrdinalType, class Device,
class MemoryTraits = void,
class SizeType = typename Kokkos::ViewTraits<OrdinalType*, Device,
void, void>::size_type>
class MemoryTraits = void, class SizeType = default_size_type>
class CrsMatrix {
static_assert(
std::is_signed<OrdinalType>::value,
Expand Down
8 changes: 8 additions & 0 deletions sparse/src/KokkosSparse_ccs2crs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ template <class OrdinalType, class SizeType, class ValViewType,
class ColMapViewType, class RowIdViewType>
auto ccs2crs(OrdinalType nrows, OrdinalType ncols, SizeType nnz,
ValViewType vals, ColMapViewType col_map, RowIdViewType row_ids) {
static_assert(
std::is_same_v<SizeType, typename ColMapViewType::non_const_value_type>,
"ccs2crs: SizeType (type of nnz) must match the element type of "
"ColMapViewType");
static_assert(
std::is_same_v<OrdinalType, typename RowIdViewType::non_const_value_type>,
"ccs2crs: OrdinalType (type of nrows, ncols) must match the element type "
"of RowIdViewType");
using Ccs2crsType = Impl::Ccs2Crs<OrdinalType, SizeType, ValViewType,
ColMapViewType, RowIdViewType>;
Ccs2crsType ccs2Crs(nrows, ncols, nnz, vals, col_map, row_ids);
Expand Down
10 changes: 9 additions & 1 deletion sparse/src/KokkosSparse_crs2ccs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ template <class OrdinalType, class SizeType, class ValViewType,
class RowMapViewType, class ColIdViewType>
auto crs2ccs(OrdinalType nrows, OrdinalType ncols, SizeType nnz,
ValViewType vals, RowMapViewType row_map, ColIdViewType col_ids) {
static_assert(
std::is_same_v<SizeType, typename RowMapViewType::non_const_value_type>,
"crs2ccs: SizeType (type of nnz) must match the element type of "
"RowMapViewType");
static_assert(
std::is_same_v<OrdinalType, typename ColIdViewType::non_const_value_type>,
"crs2ccs: OrdinalType (type of nrows, ncols) must match the element type "
"of ColIdViewType");
using Crs2ccsType = Impl::Crs2Ccs<OrdinalType, SizeType, ValViewType,
RowMapViewType, ColIdViewType>;
Crs2ccsType crs2Ccs(nrows, ncols, nnz, vals, row_map, col_ids);
Expand Down Expand Up @@ -128,4 +136,4 @@ auto crs2ccs(KokkosSparse::CrsMatrix<ScalarType, OrdinalType, DeviceType,
}
} // namespace KokkosSparse

#endif // _KOKKOSSPARSE_CRS2CCS_HPP
#endif // _KOKKOSSPARSE_CRS2CCS_HPP
10 changes: 9 additions & 1 deletion sparse/src/KokkosSparse_crs2coo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ template <class OrdinalType, class SizeType, class ValViewType,
class RowMapViewType, class ColIdViewType>
auto crs2coo(OrdinalType nrows, OrdinalType ncols, SizeType nnz,
ValViewType vals, RowMapViewType row_map, ColIdViewType col_ids) {
static_assert(
std::is_same_v<SizeType, typename RowMapViewType::non_const_value_type>,
"crs2coo: SizeType (type of nnz) must match the element type of "
"RowMapViewType");
static_assert(
std::is_same_v<OrdinalType, typename ColIdViewType::non_const_value_type>,
"crs2coo: OrdinalType (type of nrows, ncols) must match the element type "
"of ColIdViewType");
using Crs2cooType = Impl::Crs2Coo<OrdinalType, SizeType, ValViewType,
RowMapViewType, ColIdViewType>;
Crs2cooType crs2Coo(nrows, ncols, nnz, vals, row_map, col_ids);
Expand Down Expand Up @@ -161,4 +169,4 @@ auto crs2coo(KokkosSparse::CrsMatrix<ScalarType, OrdinalType, DeviceType,
crsMatrix.graph.entries);
}
} // namespace KokkosSparse
#endif // _KOKKOSSPARSE_CRS2COO_HPP
#endif // _KOKKOSSPARSE_CRS2COO_HPP
8 changes: 5 additions & 3 deletions sparse/unit_test/Test_Sparse_TestUtils_RandCsMat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
namespace Test {
template <class ScalarType, class LayoutType, class ExeSpaceType>
void doCsMat(size_t m, size_t n, ScalarType min_val, ScalarType max_val) {
using RandCs = RandCsMatrix<ScalarType, LayoutType, ExeSpaceType>;
using size_type = typename RandCs::size_type;
auto expected_min = ScalarType(1.0);
size_t expected_nnz = 0;
RandCsMatrix<ScalarType, LayoutType, ExeSpaceType> cm(m, n, min_val, max_val);
RandCs cm(m, n, min_val, max_val);

for (size_t i = 0; i < cm.get_nnz(); ++i)
for (size_type i = 0; i < cm.get_nnz(); ++i)
ASSERT_GE(cm(i), expected_min) << cm.info;

auto map_d = cm.get_map();
Expand Down Expand Up @@ -83,4 +85,4 @@ TEST_F(TestCategory, sparse_randcsmat) {
doAllCsMat<TestDevice>(dim, dim * 3);
}
}
} // namespace Test
} // namespace Test
4 changes: 2 additions & 2 deletions sparse/unit_test/Test_Sparse_spmv_bsr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ Bsr bsr_random(const int blockSize, const int blockRows, const int blockCols) {
rcs(blockRows, blockCols, scalar_type(0), max_a<scalar_type>());

const auto colids = Kokkos::subview(
rcs.get_ids(), Kokkos::make_pair(size_t(0), rcs.get_nnz()));
rcs.get_ids(), Kokkos::make_pair(size_type(0), rcs.get_nnz()));
const auto vals = Kokkos::subview(
rcs.get_vals(), Kokkos::make_pair(size_t(0), rcs.get_nnz()));
rcs.get_vals(), Kokkos::make_pair(size_type(0), rcs.get_nnz()));
Graph graph(colids, rcs.get_map());
Crs crs("crs", blockCols, vals, graph);

Expand Down
6 changes: 2 additions & 4 deletions test_common/KokkosKernels_TestUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,7 @@ class RandCooMat {
/// \tparam LayoutType
/// \tparam Device
template <class ScalarType, class LayoutType, class Device,
typename Ordinal = int64_t,
typename Size = typename Kokkos::ViewTraits<Ordinal*, Device, void,
void>::size_type>
typename Ordinal = int64_t, typename Size = default_size_type>
class RandCsMatrix {
public:
using value_type = ScalarType;
Expand Down Expand Up @@ -765,7 +763,7 @@ class RandCsMatrix {

// O(c), where c is a constant.
ScalarType operator()(Size idx) { return __vals(idx); }
size_t get_nnz() { return size_t(__nnz); }
Size get_nnz() { return __nnz; }
// dimension2: This is either columns for a Crs matrix or rows for a Ccs
// matrix.
Ordinal get_dim2() { return __dim2; }
Expand Down

0 comments on commit 8756faa

Please sign in to comment.