Skip to content

Commit

Permalink
ParIlut: create and destroy spgemm handle for each usage (#1736)
Browse files Browse the repository at this point in the history
* ParIlut: create and destroy spgemm handle for each usage

This fixes memory errors on Cuda

* Formatting

(cherry picked from commit bf9ed2a)
  • Loading branch information
jgfouca authored and ndellingwood committed Mar 28, 2023
1 parent 8aab5f9 commit e5f7e88
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions sparse/impl/KokkosSparse_par_ilut_numeric_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ struct IlutWrap {
const URowMapType& U_row_map, const UEntriesType& U_entries,
const UValuesType& U_values, LURowMapType& LU_row_map,
LUEntriesType& LU_entries, LUValuesType& LU_values) {
std::string myalg("SPGEMM_KK_MEMORY");
KokkosSparse::SPGEMMAlgorithm spgemm_algorithm =
KokkosSparse::StringToSPGEMMAlgorithm(myalg);
kh.create_spgemm_handle(spgemm_algorithm);

const size_type nrows = ih.get_nrows();

KokkosSparse::Experimental::spgemm_symbolic(
Expand All @@ -95,6 +100,8 @@ struct IlutWrap {

// Need to sort LU CRS if on CUDA!
sort_crs_matrix<execution_space>(LU_row_map, LU_entries, LU_values);

kh.destroy_spgemm_handle();
}

/**
Expand Down Expand Up @@ -716,6 +723,8 @@ struct IlutWrap {
RRowMapType& R_row_map, REntriesType& R_entries, RValuesType& R_values,
LURowMapType& LU_row_map, LUEntriesType& LU_entries,
LUValuesType& LU_values) {
scalar_t result;

multiply_matrices(kh, ih, L_row_map, L_entries, L_values, U_row_map,
U_entries, U_values, LU_row_map, LU_entries, LU_values);

Expand All @@ -731,8 +740,6 @@ struct IlutWrap {
&kh, A_row_map, A_entries, A_values, 1., LU_row_map, LU_entries,
LU_values, -1., R_row_map, R_entries, R_values);

scalar_t result;

auto policy = ih.get_default_team_policy();

Kokkos::parallel_reduce(
Expand Down Expand Up @@ -852,11 +859,6 @@ struct IlutWrap {
thandle.get_residual_norm_delta_stop();
const size_type max_iter = thandle.get_max_iter();

std::string myalg("SPGEMM_KK_MEMORY");
KokkosSparse::SPGEMMAlgorithm spgemm_algorithm =
KokkosSparse::StringToSPGEMMAlgorithm(myalg);
kh.create_spgemm_handle(spgemm_algorithm);

kh.create_spadd_handle(true /*we expect inputs to be sorted*/);

//
Expand Down Expand Up @@ -969,7 +971,6 @@ struct IlutWrap {
++itr;
}

kh.destroy_spgemm_handle();
kh.destroy_spadd_handle();
} // end ilut_numeric

Expand Down

0 comments on commit e5f7e88

Please sign in to comment.