Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorder krylov bases of gmres #523

Merged
merged 10 commits into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 65 additions & 96 deletions common/solver/gmres_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ __global__ __launch_bounds__(block_size) void initialize_2_1_kernel(
const auto global_id = thread::get_thread_id_flat();
const auto row_idx = global_id / stride_krylov;
const auto col_idx = global_id % stride_krylov;

if (row_idx < num_rows && col_idx < (krylov_dim + 1) * num_rhs) {
const auto krylov_bases_nrows = (krylov_dim + 1) * num_rows;
if (row_idx < krylov_bases_nrows && col_idx < num_rhs) {
krylov_bases[row_idx * stride_krylov + col_idx] = zero<ValueType>();
}

Expand All @@ -96,7 +96,6 @@ __global__ __launch_bounds__(block_size) void initialize_2_2_kernel(
const ValueType *__restrict__ residual_norm,
ValueType *__restrict__ residual_norm_collection,
ValueType *__restrict__ krylov_bases, size_type stride_krylov,
ValueType *__restrict__ next_krylov_basis, size_type stride_next_krylov,
size_type *__restrict__ final_iter_nums)
{
const auto global_id = thread::get_thread_id_flat();
Expand All @@ -112,7 +111,6 @@ __global__ __launch_bounds__(block_size) void initialize_2_2_kernel(
auto value = residual[row_idx * stride_residual + col_idx] /
residual_norm[col_idx];
krylov_bases[row_idx * stride_krylov + col_idx] = value;
next_krylov_basis[row_idx * stride_next_krylov + col_idx] = value;
}
}

Expand All @@ -132,10 +130,9 @@ __global__
template <typename ValueType>
__global__ __launch_bounds__(default_dot_size) void multidot_kernel(
size_type k, size_type num_rows, size_type num_cols,
const ValueType *__restrict__ next_krylov_basis,
size_type stride_next_krylov, const ValueType *__restrict__ krylov_bases,
size_type stride_krylov, ValueType *__restrict__ hessenberg_iter,
size_type stride_hessenberg,
const ValueType *__restrict__ krylov_bases,
const ValueType *__restrict__ next_krylov_basis, size_type stride_krylov,
ValueType *__restrict__ hessenberg_iter, size_type stride_hessenberg,
const stopping_status *__restrict__ stop_status)
{
const auto tidx = threadIdx.x;
Expand All @@ -153,14 +150,12 @@ __global__ __launch_bounds__(default_dot_size) void multidot_kernel(
ValueType *__restrict__ reduction_helper = reduction_helper_array;

ValueType local_res = zero<ValueType>();
const auto krylov_col = k * num_cols + col_idx;
if (col_idx < num_cols && !stop_status[col_idx].has_stopped()) {
for (size_type i = start_row + tidy; i < end_row;
i += default_dot_dim) {
const auto next_krylov_idx = i * stride_next_krylov + col_idx;
const auto krylov_idx = i * stride_krylov + krylov_col;
const auto krylov_idx = i * stride_krylov + col_idx;
local_res +=
next_krylov_basis[next_krylov_idx] * krylov_bases[krylov_idx];
conj(krylov_bases[krylov_idx]) * next_krylov_basis[krylov_idx];
}
}
reduction_helper[tidx * (default_dot_dim + 1) + tidy] = local_res;
Expand All @@ -185,20 +180,19 @@ __global__ __launch_bounds__(default_dot_size) void multidot_kernel(
template <int block_size, typename ValueType>
__global__ __launch_bounds__(block_size) void update_next_krylov_kernel(
size_type k, size_type num_rows, size_type num_cols,
ValueType *__restrict__ next_krylov_basis, size_type stride_next_krylov,
const ValueType *__restrict__ krylov_bases, size_type stride_krylov,
const ValueType *__restrict__ krylov_bases,
ValueType *__restrict__ next_krylov_basis, size_type stride_krylov,
const ValueType *__restrict__ hessenberg_iter, size_type stride_hessenberg,
const stopping_status *__restrict__ stop_status)
{
const auto global_id = thread::get_thread_id_flat();
const auto row_idx = global_id / stride_next_krylov;
const auto col_idx = global_id % stride_next_krylov;
const auto row_idx = global_id / stride_krylov;
const auto col_idx = global_id % stride_krylov;

if (row_idx < num_rows && col_idx < num_cols &&
!stop_status[col_idx].has_stopped()) {
const auto next_krylov_idx = row_idx * stride_next_krylov + col_idx;
const auto krylov_idx =
row_idx * stride_krylov + k * num_cols + col_idx;
const auto next_krylov_idx = row_idx * stride_krylov + col_idx;
const auto krylov_idx = row_idx * stride_krylov + col_idx;
const auto hessenberg_idx = k * stride_hessenberg + col_idx;

next_krylov_basis[next_krylov_idx] -=
Expand Down Expand Up @@ -248,87 +242,69 @@ __global__ __launch_bounds__(block_size) void update_hessenberg_2_kernel(
}


// Must be called with at least `num_rows * stride_next_krylov` threads in
// Must be called with at least `num_rows * stride_krylov` threads in
// total.
template <int block_size, typename ValueType>
__global__ __launch_bounds__(block_size) void update_krylov_next_krylov_kernel(
__global__ __launch_bounds__(block_size) void update_krylov_kernel(
size_type iter, size_type num_rows, size_type num_cols,
ValueType *__restrict__ next_krylov_basis, size_type stride_next_krylov,
ValueType *__restrict__ krylov_bases, size_type stride_krylov,
const ValueType *__restrict__ hessenberg_iter, size_type stride_hessenberg,
const stopping_status *__restrict__ stop_status)
{
const auto global_id = thread::get_thread_id_flat();
const auto row_idx = global_id / stride_next_krylov;
const auto col_idx = global_id % stride_next_krylov;
const auto row_idx = global_id / stride_krylov;
const auto col_idx = global_id % stride_krylov;
const auto hessenberg =
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx];

if (row_idx < num_rows && col_idx < num_cols &&
!stop_status[col_idx].has_stopped()) {
const auto next_krylov_idx = row_idx * stride_next_krylov + col_idx;
const auto krylov_idx =
row_idx * stride_krylov + num_cols * (iter + 1) + col_idx;

const auto next_krylov_value =
next_krylov_basis[next_krylov_idx] / hessenberg;
const auto krylov_idx = row_idx * stride_krylov + col_idx;

next_krylov_basis[next_krylov_idx] = next_krylov_value;
krylov_bases[krylov_idx] = next_krylov_value;
krylov_bases[krylov_idx] /= hessenberg;
}
}


template <typename ValueType>
__device__ void calculate_sin_and_cos_kernel(
size_type col_idx, size_type num_cols, size_type iter,
const ValueType *hessenberg_iter, size_type stride_hessenberg,
const ValueType &this_hess, const ValueType &next_hess,
ValueType *givens_sin, size_type stride_sin, ValueType *givens_cos,
size_type stride_cos)
size_type stride_cos, ValueType &register_sin, ValueType &register_cos)
{
if (hessenberg_iter[iter * stride_hessenberg + col_idx] ==
zero<ValueType>()) {
givens_cos[iter * stride_cos + col_idx] = zero<ValueType>();
givens_sin[iter * stride_sin + col_idx] = one<ValueType>();
if (this_hess == zero<ValueType>()) {
register_cos = zero<ValueType>();
register_sin = one<ValueType>();
} else {
auto hypotenuse =
sqrt(hessenberg_iter[iter * stride_hessenberg + col_idx] *
hessenberg_iter[iter * stride_hessenberg + col_idx] +
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx] *
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx]);
givens_cos[iter * stride_cos + col_idx] =
abs(hessenberg_iter[iter * stride_hessenberg + col_idx]) /
hypotenuse;
givens_sin[iter * stride_sin + col_idx] =
givens_cos[iter * stride_cos + col_idx] *
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx] /
hessenberg_iter[iter * stride_hessenberg + col_idx];
const auto scale = abs(this_hess) + abs(next_hess);
const auto hypotenuse =
scale * sqrt(abs(this_hess / scale) * abs(this_hess / scale) +
abs(next_hess / scale) * abs(next_hess / scale));
register_cos = conj(this_hess) / hypotenuse;
register_sin = conj(next_hess) / hypotenuse;
}
givens_cos[iter * stride_cos + col_idx] = register_cos;
givens_sin[iter * stride_sin + col_idx] = register_sin;
}


template <typename ValueType>
__device__ void calculate_residual_norm_kernel(
size_type col_idx, size_type num_cols, size_type iter,
const ValueType *givens_sin, size_type stride_sin,
const ValueType *givens_cos, size_type stride_cos, ValueType *residual_norm,
ValueType *residual_norm_collection,
const ValueType &register_sin, const ValueType &register_cos,
ValueType *residual_norm, ValueType *residual_norm_collection,
size_type stride_residual_norm_collection, const ValueType *b_norm)
{
residual_norm_collection[(iter + 1) * stride_residual_norm_collection +
col_idx] =
-givens_sin[iter * stride_sin + col_idx] *
const auto this_rnc =
residual_norm_collection[iter * stride_residual_norm_collection +
col_idx];
const auto next_rnc = -conj(register_sin) * this_rnc;
residual_norm_collection[iter * stride_residual_norm_collection + col_idx] =
givens_cos[iter * stride_cos + col_idx] *
residual_norm_collection[iter * stride_residual_norm_collection +
col_idx];
residual_norm[col_idx] =
abs(residual_norm_collection[(iter + 1) *
stride_residual_norm_collection +
col_idx]) /
b_norm[col_idx];
register_cos * this_rnc;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

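Applying the Givens rotation to (this_rnc, old_next_rnc), where old_next_rnc is the value of norm_collection(iter+1) before this update (assumed to still be zero at this point):
norm_collection(iter) = cos * this_rnc + sin * old_next_rnc = cos * this_rnc;
norm_collection(iter+1) = -conj(sin) * this_rnc + conj(cos) * old_next_rnc = -conj(sin) * this_rnc;
Is that correct?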

residual_norm[col_idx] = abs(next_rnc) / b_norm[col_idx];
residual_norm_collection[(iter + 1) * stride_residual_norm_collection +
col_idx] = next_rnc;
}


Expand All @@ -351,51 +327,42 @@ __global__ __launch_bounds__(block_size) void givens_rotation_kernel(
return;
}

const auto current_thread_block = group::this_thread_block();

auto this_hess = hessenberg_iter[col_idx];
auto next_hess = hessenberg_iter[stride_hessenberg + col_idx];
for (size_type i = 0; i < iter; ++i) {
const auto tmp =
givens_cos[i * stride_cos + col_idx] *
hessenberg_iter[i * stride_hessenberg + col_idx] +
givens_sin[i * stride_sin + col_idx] *
hessenberg_iter[(i + 1) * stride_hessenberg + col_idx];
current_thread_block.sync();
hessenberg_iter[(i + 1) * stride_hessenberg + col_idx] =
givens_cos[i * stride_cos + col_idx] *
hessenberg_iter[(i + 1) * stride_hessenberg + col_idx] -
givens_sin[i * stride_sin + col_idx] *
hessenberg_iter[i * stride_hessenberg + col_idx];
hessenberg_iter[i * stride_hessenberg + col_idx] = tmp;
current_thread_block.sync();
const auto cos = givens_cos[i * stride_cos + col_idx];
const auto sin = givens_sin[i * stride_sin + col_idx];
hessenberg_iter[i * stride_hessenberg + col_idx] =
cos * this_hess + sin * next_hess;
this_hess = conj(cos) * next_hess - conj(sin) * this_hess;
next_hess = hessenberg_iter[(i + 2) * stride_hessenberg + col_idx];
}
// for j in 1:iter - 1
// for j in 0:iter - 1
// temp = cos(j)*hessenberg(j) +
// sin(j)*hessenberg(j+1)
// hessenberg(j+1) = -sin(j)*hessenberg(j) +
// cos(j)*hessenberg(j+1)
// hessenberg(j) = temp;
// end

calculate_sin_and_cos_kernel(col_idx, num_cols, iter, hessenberg_iter,
stride_hessenberg, givens_sin, stride_sin,
givens_cos, stride_cos);
// Calculate sin and cos
ValueType register_sin;
ValueType register_cos;
calculate_sin_and_cos_kernel(col_idx, num_cols, iter, this_hess, next_hess,
givens_sin, stride_sin, givens_cos, stride_cos,
register_sin, register_cos);
// Calculate sin and cos on hessenberg(iter) and hessenberg(iter+1)

hessenberg_iter[iter * stride_hessenberg + col_idx] =
givens_cos[iter * stride_cos + col_idx] *
hessenberg_iter[iter * stride_hessenberg + col_idx] +
givens_sin[iter * stride_sin + col_idx] *
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx];
register_cos * this_hess + register_sin * next_hess;
hessenberg_iter[(iter + 1) * stride_hessenberg + col_idx] =
zero<ValueType>();
// hessenberg(iter) = cos(iter)*hessenberg(iter) +
// sin(iter)*hessenberg(iter)
// sin(iter)*hessenberg(iter+1)
// hessenberg(iter+1) = 0

calculate_residual_norm_kernel(col_idx, num_cols, iter, givens_sin,
stride_sin, givens_cos, stride_cos,
residual_norm, residual_norm_collection,
stride_residual_norm_collection, b_norm);
calculate_residual_norm_kernel(
col_idx, num_cols, iter, register_sin, register_cos, residual_norm,
residual_norm_collection, stride_residual_norm_collection, b_norm);
// Calculate residual norm
}

Expand Down Expand Up @@ -449,11 +416,13 @@ __global__ __launch_bounds__(block_size) void calculate_Qy_kernel(
const auto col_id = global_id % stride_preconditioner;

if (row_id < num_rows && col_id < num_cols) {
before_preconditioner[global_id] = zero<ValueType>();
ValueType temp = zero<ValueType>();

for (size_type j = 0; j < final_iter_nums[col_id]; ++j) {
before_preconditioner[global_id] +=
krylov_bases[row_id * stride_krylov + j * num_rhs + col_id] *
temp +=
krylov_bases[(row_id + j * num_rows) * stride_krylov + col_id] *
y[j * stride_y + col_id];
}
before_preconditioner[global_id] = temp;
}
}
Loading