Skip to content

Commit

Permalink
use crs_matrix view traits for magnitude view
Browse files Browse the repository at this point in the history
  • Loading branch information
tmranse committed Apr 6, 2023
1 parent 1c2105b commit 51ac816
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 37 deletions.
49 changes: 16 additions & 33 deletions sparse/impl/KokkosSparse_mdf_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,16 @@
namespace KokkosSparse {
namespace Impl {

template <typename T, std::size_t N>
struct add_N_pointers {
using type = typename add_N_pointers<std::add_pointer_t<T>, N - 1>::type;
};
template <typename T>
struct add_N_pointers<T, std::size_t(0)> {
using type = T;
template <typename crs_matrix_type>
struct MDF_types {
using scalar_type = typename crs_matrix_type::value_type;
using KAS = typename Kokkos::ArithTraits<scalar_type>;
using scalar_mag_type = typename KAS::mag_type;
using values_mag_type = Kokkos::View<scalar_mag_type*, Kokkos::LayoutRight,
typename crs_matrix_type::device_type,
typename crs_matrix_type::memory_traits>;
};

template <typename T, typename... Args>
auto create_mag_mirror_view(const Kokkos::View<T, Args...>& v) {
using src_view_t = Kokkos::View<T, Args...>;
using KAS = Kokkos::ArithTraits<typename src_view_t::non_const_value_type>;
using mag_type = typename KAS::mag_type;
using data_type = typename add_N_pointers<mag_type, src_view_t::rank()>::type;
return Kokkos::View<data_type, Args...>(
Kokkos::ViewAllocateWithoutInitializing(v.label() + "::Magnitude"),
v.layout());
}

template <typename SrcView>
using mag_mirror_view_t = decltype(create_mag_mirror_view(SrcView()));

template <class crs_matrix_type>
struct MDF_count_lower {
using col_ind_type = typename crs_matrix_type::StaticCrsGraphType::
Expand Down Expand Up @@ -82,16 +69,14 @@ struct MDF_discarded_fill_norm {
using col_ind_type =
typename static_crs_graph_type::entries_type::non_const_type;
using values_type = typename crs_matrix_type::values_type::non_const_type;
using values_mag_type = KokkosSparse::Impl::mag_mirror_view_t<values_type>;
using values_mag_type = typename MDF_types<crs_matrix_type>::values_mag_type;
using size_type = typename crs_matrix_type::size_type;
using ordinal_type = typename crs_matrix_type::ordinal_type;
using scalar_type = typename crs_matrix_type::value_type;
using KAS = typename Kokkos::ArithTraits<scalar_type>;
using scalar_mag_type = typename KAS::mag_type;
using KAM = typename Kokkos::ArithTraits<scalar_mag_type>;

const scalar_mag_type zero = KAM::zero();

crs_matrix_type A, At;
ordinal_type factorization_step;
col_ind_type permutation;
Expand All @@ -116,8 +101,8 @@ struct MDF_discarded_fill_norm {
KOKKOS_INLINE_FUNCTION
void operator()(const ordinal_type i) const {
ordinal_type rowIdx = permutation(i);
scalar_mag_type discard_norm = zero;
scalar_type diag_val = zero;
scalar_mag_type discard_norm = KAM::zero();
scalar_type diag_val = KAS::zero();
bool entryIsDiscarded = true;
ordinal_type numFillEntries = 0;
for (size_type alphaIdx = At.graph.row_map(rowIdx);
Expand Down Expand Up @@ -217,9 +202,7 @@ struct MDF_selective_discarded_fill_norm {
using KAS = typename Kokkos::ArithTraits<scalar_type>;
using scalar_mag_type = typename KAS::mag_type;
using KAM = typename Kokkos::ArithTraits<scalar_mag_type>;
using values_mag_type = KokkosSparse::Impl::mag_mirror_view_t<values_type>;

const scalar_mag_type zero = KAS::abs(KAS::zero());
using values_mag_type = typename MDF_types<crs_matrix_type>::values_mag_type;

crs_matrix_type A, At;
ordinal_type factorization_step;
Expand Down Expand Up @@ -248,8 +231,8 @@ struct MDF_selective_discarded_fill_norm {
KOKKOS_INLINE_FUNCTION
void operator()(const ordinal_type i) const {
ordinal_type rowIdx = permutation(update_list(i));
scalar_mag_type discard_norm = zero;
scalar_type diag_val = zero;
scalar_mag_type discard_norm = KAM::zero();
scalar_type diag_val = KAS::zero();
bool entryIsDiscarded = true;
ordinal_type numFillEntries = 0;
for (size_type alphaIdx = At.graph.row_map(rowIdx);
Expand Down Expand Up @@ -350,7 +333,7 @@ struct MDF_select_row {
using size_type = typename crs_matrix_type::size_type;
using ordinal_type = typename crs_matrix_type::ordinal_type;
using scalar_type = typename crs_matrix_type::value_type;
using values_mag_type = KokkosSparse::Impl::mag_mirror_view_t<values_type>;
using values_mag_type = typename MDF_types<crs_matrix_type>::values_mag_type;

// type used to perform the reduction
// do not confuse it with scalar_type!
Expand Down Expand Up @@ -462,7 +445,7 @@ struct MDF_factorize_row {
using ordinal_type = typename crs_matrix_type::ordinal_type;
using size_type = typename crs_matrix_type::size_type;
using value_type = typename crs_matrix_type::value_type;
using values_mag_type = KokkosSparse::Impl::mag_mirror_view_t<values_type>;
using values_mag_type = typename MDF_types<crs_matrix_type>::values_mag_type;
using value_mag_type = typename values_mag_type::value_type;

crs_matrix_type A, At;
Expand Down
8 changes: 4 additions & 4 deletions sparse/src/KokkosSparse_mdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ template <class crs_matrix_type, class MDF_handle>
void mdf_numeric(const crs_matrix_type& A, MDF_handle& handle) {
using col_ind_type = typename crs_matrix_type::StaticCrsGraphType::
entries_type::non_const_type;
using values_type = typename crs_matrix_type::values_type::non_const_type;
using values_mag_type = KokkosSparse::Impl::mag_mirror_view_t<values_type>;
using ordinal_type = typename crs_matrix_type::ordinal_type;
using value_mag_type = typename values_mag_type::value_type;
using values_mag_type =
typename KokkosSparse::Impl::MDF_types<crs_matrix_type>::values_mag_type;
using ordinal_type = typename crs_matrix_type::ordinal_type;
using value_mag_type = typename values_mag_type::value_type;

using execution_space = typename crs_matrix_type::execution_space;
using range_policy_type = Kokkos::RangePolicy<ordinal_type, execution_space>;
Expand Down

0 comments on commit 51ac816

Please sign in to comment.