Skip to content

Commit

Permalink
Merge Reorder krylov bases of gmres
Browse files Browse the repository at this point in the history
This PR reorders the Krylov bases of GMRES and makes some improvements.

Ki = the i-th vector of the Krylov basis for each rhs.
Previously K = [K1 K2 ... Kn] (bases placed side by side), which keeps the stride of the Krylov bases larger than 1 even when num_rhs = 1.
Now K = [K1; K2; ...; Kn] (bases stacked vertically) is used to get coalesced memory access when num_rhs = 1.
Also, use the vendor's dot function when num_rhs = 1.

Related PR: #523
  • Loading branch information
yhmtsai authored May 15, 2020
2 parents bebd252 + 966568c commit 2fb46e1
Show file tree
Hide file tree
Showing 11 changed files with 405 additions and 441 deletions.
161 changes: 65 additions & 96 deletions common/solver/gmres_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ __global__ __launch_bounds__(block_size) void initialize_2_1_kernel(
const auto global_id = thread::get_thread_id_flat();
const auto row_idx = global_id / stride_krylov;
const auto col_idx = global_id % stride_krylov;

if (row_idx < num_rows && col_idx < (krylov_dim + 1) * num_rhs) {
const auto krylov_bases_nrows = (krylov_dim + 1) * num_rows;
if (row_idx < krylov_bases_nrows && col_idx < num_rhs) {
krylov_bases[row_idx * stride_krylov + col_idx] = zero<ValueType>();
}

Expand All @@ -96,7 +96,6 @@ __global__ __launch_bounds__(block_size) void initialize_2_2_kernel(
const ValueType *__restrict__ residual_norm,
ValueType *__restrict__ residual_norm_collection,
ValueType *__restrict__ krylov_bases, size_type stride_krylov,
ValueType *__restrict__ next_krylov_basis, size_type stride_next_krylov,
size_type *__restrict__ final_iter_nums)
{
const auto global_id = thread::get_thread_id_flat();
Expand All @@ -112,7 +111,6 @@ __global__ __launch_bounds__(block_size) void initialize_2_2_kernel(
auto value = residual[row_idx * stride_residual + col_idx] /
residual_norm[col_idx];
krylov_bases[row_idx * stride_krylov + col_idx] = value;
next_krylov_basis[row_idx * stride_next_krylov + col_idx] = value;
}
}

Expand All @@ -132,10 +130,9 @@ __global__
template <typename ValueType>
__global__ __launch_bounds__(default_dot_size) void multidot_kernel(
size_type k, size_type num_rows, size_type num_cols,
const ValueType *__restrict__ next_krylov_basis,
size_type stride_next_krylov, const ValueType *__restrict__ krylov_bases,
size_type stride_krylov, ValueType *__restrict__ hessenberg_iter,
size_type stride_hessenberg,
const ValueType *__restrict__ krylov_bases,
const ValueType *__restrict__ next_krylov_basis, size_type stride_krylov,
ValueType *__restrict__ hessenberg_iter, size_type stride_hessenberg,
const stopping_status *__restrict__ stop_status)
{
const auto tidx = threadIdx.x;
Expand All @@ -153,14 +150,12 @@ __global__ __launch_bounds__(default_dot_size) void multidot_kernel(
ValueType *__restrict__ reduction_helper = reduction_helper_array;

ValueType local_res = zero<ValueType>();
const auto krylov_col = k * num_cols + col_idx;
if (col_idx < num_cols && !stop_status[col_idx].has_stopped()) {
for (size_type i = start_row + tidy; i < end_row;
i += default_dot_dim) {
const auto next_krylov_idx = i * stride_next_krylov + col_idx;
const auto krylov_idx = i * stride_krylov + krylov_col;
const auto krylov_idx = i * stride_krylov + col_idx;
local_res +=
next_krylov_basis[next_krylov_idx] * krylov_bases[krylov_idx];
conj(krylov_bases[krylov_idx]) * next_krylov_basis[krylov_idx];
}
}
reduction_helper[tidx * (default_dot_dim + 1) + tidy] = local_res;
Expand All @@ -185,20 +180,19 @@ __global__ __launch_bounds__(default_dot_size) void multidot_kernel(
template <int block_size, typename ValueType>
__global__ __launch_bounds__(block_size) void update_next_krylov_kernel(
size_type k, size_type num_rows, size_type num_cols,
ValueType *__restrict__ next_krylov_basis, size_type stride_next_krylov,
const ValueType *__restrict__ krylov_bases, size_type stride_krylov,
const ValueType *__restrict__ krylov_bases,
ValueType *__restrict__ next_krylov_basis, size_type stride_krylov,
const ValueType *__restrict__ hessenberg_iter, size_type stride_hessenberg,
const stopping_status *__restrict__ stop_status)
{
const auto global_id = thread::get_thread_id_flat();
const auto row_idx = global_id / stride_next_krylov;
const auto col_idx = global_id % stride_next_krylov;
const auto row_idx = global_id / stride_krylov;
const auto col_idx = global_id % stride_krylov;

if (row_idx < num_rows && col_idx < num_cols &&
!stop_status[col_idx].has_stopped()) {
const auto next_krylov_idx = row_idx * stride_next_krylov + col_idx;
const auto krylov_idx =
row_idx * stride_krylov + k * num_cols + col_idx;
const auto next_krylov_idx = row_idx * stride_krylov + col_idx;
const auto krylov_idx = row_idx * stride_krylov + col_idx;
const auto hessenberg_idx = k * stride_hessenberg + col_idx;

next_krylov_basis[next_krylov_idx] -=
Expand Down Expand Up @@ -248,87 +242,69 @@ __global__ __launch_bounds__(block_size) void update_hessenberg_2_kernel(
}


// Must be called with at least `num_rows * stride_next_krylov` threads in
// Must be called with at least `num_rows * stride_krylov` threads in
// total.
template <int block_size, typename ValueType>
__global__ __launch_bounds__(block_size) void update_krylov_next_krylov_kernel(
__global__ __launch_bounds__(block_size) void update_krylov_kernel(
size_type iter, size_type num_rows, size_type num_cols,
ValueType *__restrict__ next_krylov_basis, size_type stride_next_krylov,
ValueType *__restrict__ krylov_bases, size_type stride_krylov,
const ValueType *__restrict__ hessenberg_iter, size_type stride_hessenberg,
const stopping_status *__restrict__ stop_status)
{
const auto global_id = thread::get_thread_id_flat();
const auto row_idx = global_id / stride_next_krylov;
const auto col_idx = global_id % stride_next_krylov;
const auto row_idx = global_id / stride_krylov;
const auto col_idx = global_id % stride_krylov;
const auto hessenberg =
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx];

if (row_idx < num_rows && col_idx < num_cols &&
!stop_status[col_idx].has_stopped()) {
const auto next_krylov_idx = row_idx * stride_next_krylov + col_idx;
const auto krylov_idx =
row_idx * stride_krylov + num_cols * (iter + 1) + col_idx;

const auto next_krylov_value =
next_krylov_basis[next_krylov_idx] / hessenberg;
const auto krylov_idx = row_idx * stride_krylov + col_idx;

next_krylov_basis[next_krylov_idx] = next_krylov_value;
krylov_bases[krylov_idx] = next_krylov_value;
krylov_bases[krylov_idx] /= hessenberg;
}
}


template <typename ValueType>
__device__ void calculate_sin_and_cos_kernel(
size_type col_idx, size_type num_cols, size_type iter,
const ValueType *hessenberg_iter, size_type stride_hessenberg,
const ValueType &this_hess, const ValueType &next_hess,
ValueType *givens_sin, size_type stride_sin, ValueType *givens_cos,
size_type stride_cos)
size_type stride_cos, ValueType &register_sin, ValueType &register_cos)
{
if (hessenberg_iter[iter * stride_hessenberg + col_idx] ==
zero<ValueType>()) {
givens_cos[iter * stride_cos + col_idx] = zero<ValueType>();
givens_sin[iter * stride_sin + col_idx] = one<ValueType>();
if (this_hess == zero<ValueType>()) {
register_cos = zero<ValueType>();
register_sin = one<ValueType>();
} else {
auto hypotenuse =
sqrt(hessenberg_iter[iter * stride_hessenberg + col_idx] *
hessenberg_iter[iter * stride_hessenberg + col_idx] +
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx] *
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx]);
givens_cos[iter * stride_cos + col_idx] =
abs(hessenberg_iter[iter * stride_hessenberg + col_idx]) /
hypotenuse;
givens_sin[iter * stride_sin + col_idx] =
givens_cos[iter * stride_cos + col_idx] *
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx] /
hessenberg_iter[iter * stride_hessenberg + col_idx];
const auto scale = abs(this_hess) + abs(next_hess);
const auto hypotenuse =
scale * sqrt(abs(this_hess / scale) * abs(this_hess / scale) +
abs(next_hess / scale) * abs(next_hess / scale));
register_cos = conj(this_hess) / hypotenuse;
register_sin = conj(next_hess) / hypotenuse;
}
givens_cos[iter * stride_cos + col_idx] = register_cos;
givens_sin[iter * stride_sin + col_idx] = register_sin;
}


template <typename ValueType>
__device__ void calculate_residual_norm_kernel(
size_type col_idx, size_type num_cols, size_type iter,
const ValueType *givens_sin, size_type stride_sin,
const ValueType *givens_cos, size_type stride_cos, ValueType *residual_norm,
ValueType *residual_norm_collection,
const ValueType &register_sin, const ValueType &register_cos,
ValueType *residual_norm, ValueType *residual_norm_collection,
size_type stride_residual_norm_collection, const ValueType *b_norm)
{
residual_norm_collection[(iter + 1) * stride_residual_norm_collection +
col_idx] =
-givens_sin[iter * stride_sin + col_idx] *
const auto this_rnc =
residual_norm_collection[iter * stride_residual_norm_collection +
col_idx];
const auto next_rnc = -conj(register_sin) * this_rnc;
residual_norm_collection[iter * stride_residual_norm_collection + col_idx] =
givens_cos[iter * stride_cos + col_idx] *
residual_norm_collection[iter * stride_residual_norm_collection +
col_idx];
residual_norm[col_idx] =
abs(residual_norm_collection[(iter + 1) *
stride_residual_norm_collection +
col_idx]) /
b_norm[col_idx];
register_cos * this_rnc;
residual_norm[col_idx] = abs(next_rnc) / b_norm[col_idx];
residual_norm_collection[(iter + 1) * stride_residual_norm_collection +
col_idx] = next_rnc;
}


Expand All @@ -351,51 +327,42 @@ __global__ __launch_bounds__(block_size) void givens_rotation_kernel(
return;
}

const auto current_thread_block = group::this_thread_block();

auto this_hess = hessenberg_iter[col_idx];
auto next_hess = hessenberg_iter[stride_hessenberg + col_idx];
for (size_type i = 0; i < iter; ++i) {
const auto tmp =
givens_cos[i * stride_cos + col_idx] *
hessenberg_iter[i * stride_hessenberg + col_idx] +
givens_sin[i * stride_sin + col_idx] *
hessenberg_iter[(i + 1) * stride_hessenberg + col_idx];
current_thread_block.sync();
hessenberg_iter[(i + 1) * stride_hessenberg + col_idx] =
givens_cos[i * stride_cos + col_idx] *
hessenberg_iter[(i + 1) * stride_hessenberg + col_idx] -
givens_sin[i * stride_sin + col_idx] *
hessenberg_iter[i * stride_hessenberg + col_idx];
hessenberg_iter[i * stride_hessenberg + col_idx] = tmp;
current_thread_block.sync();
const auto cos = givens_cos[i * stride_cos + col_idx];
const auto sin = givens_sin[i * stride_sin + col_idx];
hessenberg_iter[i * stride_hessenberg + col_idx] =
cos * this_hess + sin * next_hess;
this_hess = conj(cos) * next_hess - conj(sin) * this_hess;
next_hess = hessenberg_iter[(i + 2) * stride_hessenberg + col_idx];
}
// for j in 1:iter - 1
// for j in 0:iter - 1
// temp = cos(j)*hessenberg(j) +
// sin(j)*hessenberg(j+1)
// hessenberg(j+1) = -sin(j)*hessenberg(j) +
// cos(j)*hessenberg(j+1)
// hessenberg(j) = temp;
// end

calculate_sin_and_cos_kernel(col_idx, num_cols, iter, hessenberg_iter,
stride_hessenberg, givens_sin, stride_sin,
givens_cos, stride_cos);
// Calculate sin and cos
ValueType register_sin;
ValueType register_cos;
calculate_sin_and_cos_kernel(col_idx, num_cols, iter, this_hess, next_hess,
givens_sin, stride_sin, givens_cos, stride_cos,
register_sin, register_cos);
// Calculate sin and cos on hessenberg(iter) and hessenberg(iter+1)

hessenberg_iter[iter * stride_hessenberg + col_idx] =
givens_cos[iter * stride_cos + col_idx] *
hessenberg_iter[iter * stride_hessenberg + col_idx] +
givens_sin[iter * stride_sin + col_idx] *
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx];
register_cos * this_hess + register_sin * next_hess;
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx] =
zero<ValueType>();
// hessenberg(iter) = cos(iter)*hessenberg(iter) +
// sin(iter)*hessenberg(iter)
// sin(iter)*hessenberg(iter+1)
// hessenberg(iter+1) = 0

calculate_residual_norm_kernel(col_idx, num_cols, iter, givens_sin,
stride_sin, givens_cos, stride_cos,
residual_norm, residual_norm_collection,
stride_residual_norm_collection, b_norm);
calculate_residual_norm_kernel(
col_idx, num_cols, iter, register_sin, register_cos, residual_norm,
residual_norm_collection, stride_residual_norm_collection, b_norm);
// Calculate residual norm
}

Expand Down Expand Up @@ -449,11 +416,13 @@ __global__ __launch_bounds__(block_size) void calculate_Qy_kernel(
const auto col_id = global_id % stride_preconditioner;

if (row_id < num_rows && col_id < num_cols) {
before_preconditioner[global_id] = zero<ValueType>();
ValueType temp = zero<ValueType>();

for (size_type j = 0; j < final_iter_nums[col_id]; ++j) {
before_preconditioner[global_id] +=
krylov_bases[row_id * stride_krylov + j * num_rhs + col_id] *
temp +=
krylov_bases[(row_id + j * num_rows) * stride_krylov + col_id] *
y[j * stride_y + col_id];
}
before_preconditioner[global_id] = temp;
}
}
Loading

0 comments on commit 2fb46e1

Please sign in to comment.