Skip to content

Commit

Permalink
fix complex and related algorithm comment
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed May 7, 2020
1 parent cebe937 commit c6ad309
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 40 deletions.
9 changes: 4 additions & 5 deletions common/solver/gmres_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ __device__ void calculate_sin_and_cos_kernel(
const auto hypotenuse =
scale * sqrt(abs(this_hess / scale) * abs(this_hess / scale) +
abs(next_hess / scale) * abs(next_hess / scale));
register_cos = abs(this_hess) / hypotenuse;
register_sin =
this_hess / abs(this_hess) * conj(next_hess) / hypotenuse;
register_cos = conj(this_hess) / hypotenuse;
register_sin = conj(next_hess) / hypotenuse;
}
givens_cos[iter * stride_cos + col_idx] = register_cos;
givens_sin[iter * stride_sin + col_idx] = register_sin;
Expand All @@ -300,7 +299,7 @@ __device__ void calculate_residual_norm_kernel(
const auto this_rnc =
residual_norm_collection[iter * stride_residual_norm_collection +
col_idx];
const auto next_rnc = -register_sin * this_rnc;
const auto next_rnc = -conj(register_sin) * this_rnc;
residual_norm_collection[iter * stride_residual_norm_collection + col_idx] =
register_cos * this_rnc;
residual_norm[col_idx] = abs(next_rnc) / b_norm[col_idx];
Expand Down Expand Up @@ -335,7 +334,7 @@ __global__ __launch_bounds__(block_size) void givens_rotation_kernel(
const auto sin = givens_sin[i * stride_sin + col_idx];
hessenberg_iter[i * stride_hessenberg + col_idx] =
cos * this_hess + sin * next_hess;
this_hess = cos * next_hess - sin * this_hess;
this_hess = conj(cos) * next_hess - conj(sin) * this_hess;
next_hess = hessenberg_iter[(i + 2) * stride_hessenberg + col_idx];
}
// for j in 0:iter - 1
Expand Down
20 changes: 10 additions & 10 deletions core/solver/gmres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,39 +208,39 @@ void Gmres<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
restart_iter, &final_iter_nums, &stop_status));
// final_iter_nums += 1 (unconverged)
// next_krylov_basis is alias for (restart_iter + 1)-th krylov_bases
// for i in 0:restart_iter
// for i in 0:restart_iter(include)
// hessenberg(restart_iter, i) = next_krylov_basis' *
// krylov_bases(:, i)
// next_krylov_basis -= hessenberg(restart_iter, i) *
// krylov_bases(:, i)
// end
// hessenberg(restart_iter, restart_iter + 1) = norm(next_krylov_basis)
// next_krylov_basis /= hessenberg(restart_iter, restart_iter + 1)
// hessenberg(restart_iter+1, restart_iter) = norm(next_krylov_basis)
// next_krylov_basis /= hessenberg(restart_iter + 1, restart_iter)
// End of arnoldi
// Start apply givens rotation
// for j in 0:restart_iter
// for j in 0:restart_iter(exclude)
// temp = cos(j)*hessenberg(j) +
// sin(j)*hessenberg(j+1)
// hessenberg(j+1) = -sin(j)*hessenberg(j) +
// cos(j)*hessenberg(j+1)
// hessenberg(j+1) = -conj(sin(j))*hessenberg(j) +
// conj(cos(j))*hessenberg(j+1)
// hessenberg(j) = temp;
// end
// Calculate sin and cos
// this_hess = hessenberg(restart_iter)
// next_hess = hessenberg(restart_iter+1)
// hypotenuse = sqrt(this_hess * this_hess + next_hess * next_hess);
// cos = abs(this_hess) / hypotenuse;
// sin = cos * next_hess / this_hess
// cos(restart_iter) = conj(this_hess) / hypotenuse;
// sin(restart_iter) = conj(next_hess) / this_hess
// hessenberg(restart_iter) =
// cos(restart_iter)*hessenberg(restart_iter) +
// sin(restart_iter)*hessenberg(restart_iter)
// hessenberg(restart_iter+1) = 0
// End apply givens rotation
// Calculate residual norm
// this_rnc = residual_norm_collection(restart_iter)
// next_rnc = -sin(restart_iter) * this_rnc
// next_rnc = -conj(sin(restart_iter)) * this_rnc
// residual_norm_collection(restart_iter) = cos(restart_iter) * this_rnc
// residual = abs(next_rnc)/b_norm
// residual_norm = abs(next_rnc)/b_norm
// residual_norm_collection(restart_iter + 1) = next_rnc

restart_iter++;
Expand Down
22 changes: 11 additions & 11 deletions omp/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ void finish_arnoldi(size_type num_rows, matrix::Dense<ValueType> *krylov_bases,
}
}
// for i in 1:iter
// hessenberg(iter, i) = krylov_bases' * krylov_bases(:, i)
// krylov_bases -= hessenberg(iter, i) * krylov_bases(:, i)
// hessenberg(iter, i) = next_krylov_basis' * krylov_bases(:, i)
// next_krylov_basis -= hessenberg(iter, i) * krylov_bases(:, i)
// end

ValueType hessenberg_iter_entry = zero<ValueType>();
Expand All @@ -100,14 +100,13 @@ void finish_arnoldi(size_type num_rows, matrix::Dense<ValueType> *krylov_bases,
krylov_bases->at(j + next_krylov_rowoffset, i);
}
hessenberg_iter->at(iter + 1, i) = sqrt(hessenberg_iter_entry);
// hessenberg(iter, iter + 1) = norm(krylov_bases)
// hessenberg(iter + 1, iter) = norm(krylov_bases)
#pragma omp parallel for
for (size_type j = 0; j < num_rows; ++j) {
krylov_bases->at(j + next_krylov_rowoffset, i) /=
hessenberg_iter->at(iter + 1, i);
}
// krylov_bases /= hessenberg(iter, iter + 1)
// krylov_bases(:, iter + 1) = krylov_bases
// next_krylov_basis /= hessenberg(iter, iter + 1)
// End of arnoldi
}
}
Expand Down Expand Up @@ -151,13 +150,13 @@ void givens_rotation(matrix::Dense<ValueType> *givens_sin,
auto temp = givens_cos->at(j, i) * hessenberg_iter->at(j, i) +
givens_sin->at(j, i) * hessenberg_iter->at(j + 1, i);
hessenberg_iter->at(j + 1, i) =
-givens_sin->at(j, i) * hessenberg_iter->at(j, i) +
givens_cos->at(j, i) * hessenberg_iter->at(j + 1, i);
-conj(givens_sin->at(j, i)) * hessenberg_iter->at(j, i) +
conj(givens_cos->at(j, i)) * hessenberg_iter->at(j + 1, i);
hessenberg_iter->at(j, i) = temp;
// temp = cos(j)*hessenberg(j) +
// sin(j)*hessenberg(j+1)
// hessenberg(j+1) = -sin(j)*hessenberg(j) +
// cos(j)*hessenberg(j+1)
// hessenberg(j+1) = -conj(sin(j))*hessenberg(j) +
// conj(cos(j))*hessenberg(j+1)
// hessenberg(j) = temp;
}

Expand All @@ -168,7 +167,7 @@ void givens_rotation(matrix::Dense<ValueType> *givens_sin,
givens_sin->at(iter, i) * hessenberg_iter->at(iter + 1, i);
hessenberg_iter->at(iter + 1, i) = zero<ValueType>();
// hessenberg(iter) = cos(iter)*hessenberg(iter) +
// sin(iter)*hessenberg(iter)
// sin(iter)*hessenberg(iter + 1)
// hessenberg(iter+1) = 0
}
}
Expand All @@ -188,7 +187,8 @@ void calculate_next_residual_norm(
continue;
}
residual_norm_collection->at(iter + 1, i) =
-givens_sin->at(iter, i) * residual_norm_collection->at(iter, i);
-conj(givens_sin)->at(iter, i) *
residual_norm_collection->at(iter, i);
residual_norm_collection->at(iter, i) =
givens_cos->at(iter, i) * residual_norm_collection->at(iter, i);
residual_norm->at(0, i) =
Expand Down
27 changes: 13 additions & 14 deletions reference/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ void finish_arnoldi(size_type num_rows, matrix::Dense<ValueType> *krylov_bases,
}
}
// for i in 1:iter
// hessenberg(iter, i) = krylov_bases' * krylov_bases(:, i)
// krylov_bases -= hessenberg(iter, i) * krylov_bases(:, i)
// hessenberg(iter, i) = next_krylov_basis' * krylov_bases(:, i)
// next_krylov_basis -= hessenberg(iter, i) * krylov_bases(:, i)
// end

hessenberg_iter->at(iter + 1, i) = 0;
Expand All @@ -91,13 +91,12 @@ void finish_arnoldi(size_type num_rows, matrix::Dense<ValueType> *krylov_bases,
}
hessenberg_iter->at(iter + 1, i) =
sqrt(hessenberg_iter->at(iter + 1, i));
// hessenberg(iter, iter + 1) = norm(krylov_bases)
// hessenberg(iter + 1, iter) = norm(krylov_bases)
for (size_type j = 0; j < num_rows; ++j) {
krylov_bases->at(j + next_krylov_rowoffset, i) /=
hessenberg_iter->at(iter + 1, i);
}
// krylov_bases /= hessenberg(iter, iter + 1)
// krylov_bases(:, iter + 1) = krylov_bases
// next_krylov_basis /= hessenberg(iter, iter + 1)
// End of arnoldi
}
}
Expand All @@ -119,9 +118,8 @@ void calculate_sin_and_cos(matrix::Dense<ValueType> *givens_sin,
const auto hypotenuse =
scale * sqrt(abs(this_hess / scale) * abs(this_hess / scale) +
abs(next_hess / scale) * abs(next_hess / scale));
givens_cos->at(iter, rhs) = abs(this_hess) / hypotenuse;
givens_sin->at(iter, rhs) =
this_hess / abs(this_hess) * conj(next_hess) / hypotenuse;
givens_cos->at(iter, rhs) = conj(this_hess) / hypotenuse;
givens_sin->at(iter, rhs) = conj(next_hess) / hypotenuse;
}
}

Expand All @@ -140,13 +138,13 @@ void givens_rotation(matrix::Dense<ValueType> *givens_sin,
auto temp = givens_cos->at(j, i) * hessenberg_iter->at(j, i) +
givens_sin->at(j, i) * hessenberg_iter->at(j + 1, i);
hessenberg_iter->at(j + 1, i) =
-givens_sin->at(j, i) * hessenberg_iter->at(j, i) +
givens_cos->at(j, i) * hessenberg_iter->at(j + 1, i);
-conj(givens_sin->at(j, i)) * hessenberg_iter->at(j, i) +
conj(givens_cos->at(j, i)) * hessenberg_iter->at(j + 1, i);
hessenberg_iter->at(j, i) = temp;
// temp = cos(j)*hessenberg(j) +
// sin(j)*hessenberg(j+1)
// hessenberg(j+1) = -sin(j)*hessenberg(j) +
// cos(j)*hessenberg(j+1)
// hessenberg(j+1) = -conj(sin(j))*hessenberg(j) +
// conj(cos(j))*hessenberg(j+1)
// hessenberg(j) = temp;
}

Expand All @@ -157,7 +155,7 @@ void givens_rotation(matrix::Dense<ValueType> *givens_sin,
givens_sin->at(iter, i) * hessenberg_iter->at(iter + 1, i);
hessenberg_iter->at(iter + 1, i) = zero<ValueType>();
// hessenberg(iter) = cos(iter)*hessenberg(iter) +
// sin(iter)*hessenberg(iter)
// sin(iter)*hessenberg(iter + 1)
// hessenberg(iter+1) = 0
}
}
Expand All @@ -176,7 +174,8 @@ void calculate_next_residual_norm(
continue;
}
residual_norm_collection->at(iter + 1, i) =
-givens_sin->at(iter, i) * residual_norm_collection->at(iter, i);
-conj(givens_sin->at(iter, i)) *
residual_norm_collection->at(iter, i);
residual_norm_collection->at(iter, i) =
givens_cos->at(iter, i) * residual_norm_collection->at(iter, i);
residual_norm->at(0, i) =
Expand Down

0 comments on commit c6ad309

Please sign in to comment.