Skip to content

Commit

Permalink
Fix inference errors when not debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Feb 20, 2024
1 parent 202ed63 commit 3c6eccb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 32 deletions.
65 changes: 39 additions & 26 deletions src/MatrixFields/field_matrix_iterative_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ There are 4 values that can be included in `kwargs...`:
"""
struct StationaryIterativeSolve{
correlated_solves,
debug,
P <: Union{Nothing, PreconditionerAlgorithm},
K <: NamedTuple,
} <: LazyFieldMatrixSolverAlgorithm
Expand All @@ -375,6 +376,7 @@ end
function StationaryIterativeSolve(;
P_alg = nothing,
n_iters = 1,
debug = false,
correlated_solves = false,
eigsolve_kwargs = (;),
)
Expand All @@ -385,9 +387,10 @@ function StationaryIterativeSolve(;
eigsolve_kwargs′ =
(; krylovdim = 4, maxiter = 20, tol = 0.01, eigsolve_kwargs...)
# Make correlated_solves into a type parameter to ensure type-stability.
params = (correlated_solves, typeof(P_alg), typeof(eigsolve_kwargs′))
params = (correlated_solves, typeof(P_alg), typeof(eigsolve_kwargs′), debug)
return StationaryIterativeSolve{params...}(P_alg, n_iters, eigsolve_kwargs′)
end
get_debug(::StationaryIterativeSolve{CS, debug}) where {CS, debug} = debug

# Extract correlated_solves as if it were a regular field, not a type parameter.
Base.getproperty(
Expand All @@ -411,23 +414,29 @@ check_field_matrix_solver(alg::StationaryIterativeSolve, cache, A, b) =

function run_field_matrix_solver!(alg::StationaryIterativeSolve, cache, x, A, b)
P = lazy_or_concrete_preconditioner(alg.P_alg, cache.P_cache, A)
using_cuda = ClimaComms.array_type(concrete_field_vector(b)) <: CUDA.CuArray
!using_cuda && @debug begin
e₀ = concrete_field_vector(b) # Initialize e to any nonzero vector.
λs, _, info = KrylovKit.eigsolve(e₀, 1; alg.eigsolve_kwargs...) do e
e_view = field_vector_view(e, keys(b).name_tree)
lazy_Ae = lazy_mul(A, e_view)
lazy_Δe = apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_Ae)
concrete_field_vector(@. e_view - lazy_Δe) # (I - inv(P) * A) * e
end
if info.converged == 0
(; tol, maxiter) = alg.eigsolve_kwargs
"Unable to approximate ρ(I - inv(P) * A) to within a tolerance \
of $(100 * tol) % in $maxiter or fewer iterations"
else
"ρ(I - inv(P) * A) ≈ $(abs(λs[1]))"
end
end _group = :spectral_radius
if get_debug(alg)
@debug begin
e₀ = concrete_field_vector(b) # Initialize e to any nonzero vector.
λs, _, info = KrylovKit.eigsolve(e₀, 1; alg.eigsolve_kwargs...) do e
e_view = field_vector_view(e, keys(b).name_tree)
lazy_Ae = lazy_mul(A, e_view)
lazy_Δe = apply_preconditioner(
alg.P_alg,
cache.P_cache,
P,
lazy_Ae,
)
concrete_field_vector(@. e_view - lazy_Δe) # (I - inv(P) * A) * e
end
if info.converged == 0
(; tol, maxiter) = alg.eigsolve_kwargs
"Unable to approximate ρ(I - inv(P) * A) to within a tolerance \
of $(100 * tol) % in $maxiter or fewer iterations"
else
"ρ(I - inv(P) * A) ≈ $(abs(λs[1]))"
end
end _group = :spectral_radius
end
if alg.correlated_solves
@. x = cache.previous_x
else
Expand All @@ -436,18 +445,22 @@ function run_field_matrix_solver!(alg::StationaryIterativeSolve, cache, x, A, b)
for iter in 1:(alg.n_iters)
lazy_Δb = lazy_sub(b, lazy_mul(A, x))
lazy_Δx = apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_Δb)
if get_debug(alg)
@debug begin
norm_Δx = norm(concrete_field_vector(Base.materialize(lazy_Δx)))
"||x[$(iter - 1)] - x'||₂ ≈ $norm_Δx"
end _group = :error_norm
end
@. x += lazy_Δx
end
if get_debug(alg)
@debug begin
lazy_Δb = lazy_sub(b, lazy_mul(A, x))
lazy_Δx = apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_Δb)
norm_Δx = norm(concrete_field_vector(Base.materialize(lazy_Δx)))
"||x[$(iter - 1)] - x'||₂ ≈ $norm_Δx"
"||x[$(alg.n_iters)] - x'||₂ ≈ $norm_Δx"
end _group = :error_norm
@. x += lazy_Δx
end
@debug begin
lazy_Δb = lazy_sub(b, lazy_mul(A, x))
lazy_Δx = apply_preconditioner(alg.P_alg, cache.P_cache, P, lazy_Δb)
norm_Δx = norm(concrete_field_vector(Base.materialize(lazy_Δx)))
"||x[$(alg.n_iters)] - x'||₂ ≈ $norm_Δx"
end _group = :error_norm
if alg.correlated_solves
@. cache.previous_x = x
end
Expand Down
8 changes: 2 additions & 6 deletions test/MatrixFields/field_matrix_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)

# In addition to ignoring the type instabilities from CUDA, ignore those
# from CUBLAS (norm), KrylovKit (eigsolve), and CoreLogging (@debug).
ignored = (
ignore_cuda...,
using_cuda ? AnyFrameModule(CUDA.CUBLAS) :
AnyFrameModule(MatrixFields.KrylovKit),
AnyFrameModule(Base.CoreLogging),
)
ignored =
(ignore_cuda..., (using_cuda ? AnyFrameModule(CUDA.CUBLAS) : ())...)
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
@test_opt ignored_modules = ignored field_matrix_mul!(b, A, x)
Expand Down

0 comments on commit 3c6eccb

Please sign in to comment.