diff --git a/src/OptimalFilter.jl b/src/OptimalFilter.jl index 1a30108c..25557759 100644 --- a/src/OptimalFilter.jl +++ b/src/OptimalFilter.jl @@ -139,47 +139,44 @@ end # Normalized two-dimension discrete Fourier transofrm normalized by sqrt(n1_bar). # Operates on the 2d data stored as a matrix. +# The argument `f` lets you switch between forward transform (`f=identity`) and +# inverse transform (`f=inv`). # From Dietrich & Newsam 96 in text following equation 12 -function normalized_2d_fft!(transformed_array::AbstractMatrix{T}, array::AbstractMatrix{S}, grid_ext) where {T,S} +function normalized_2d_fft!(transformed_array::AbstractMatrix{<:Complex}, array::AbstractMatrix{S}, fft_plan::FFTW.FFTWPlan, fft_plan!::FFTW.FFTWPlan, grid_ext, f=identity) where {S} - normalization_factor = 1.0 / sqrt(grid_ext.nx * grid_ext.ny) - transformed_array .= fft(array) .* normalization_factor + normalization_factor = f(sqrt(grid_ext.nx * grid_ext.ny)) + if pointer(transformed_array) == pointer(array) + mul!(transformed_array, f(fft_plan!), array) + else + mul!(transformed_array, f(fft_plan), array) + end + transformed_array ./= normalization_factor end # Normalized two-dimension discrete Fourier transofrm normalized by sqrt(n1_bar). # Operates on the 2d data stored as a vector. +# The argument `f` lets you switch between forward transform (`f=identity`) and +# inverse transform (`f=inv`). # From Dietrich & Newsam 96 in text following equation 12 -function normalized_2d_fft!(transformed_vector::AbstractVector{T}, vector::AbstractVector{S}, grid_ext) where {T,S} - - normalization_factor = 1.0 / sqrt(grid_ext.nx * grid_ext.ny) - transformed_vector .= fft(reshape(vector, grid_ext.nx, grid_ext.ny))[:] .* normalization_factor - -end - -# Normalized inverse two-dimension discrete Fourier transofrm normalized by sqrt(n1_bar). -# Operates on the 2d data stored as a matrix. -# From Dietrich & Newsam 96 in text following equation 12 -function normalized_inverse_2d_fft!(transformed_array::AbstractMatrix{T}, array::AbstractMatrix{S}, grid_ext) where {T,S} - - normalization_factor = sqrt(grid_ext.nx * grid_ext.ny) - transformed_array .= ifft(array) .* normalization_factor - -end - -# Normalized inverse two-dimension discrete Fourier transofrm normalized by sqrt(n1_bar). -# Operates on the 2d data stored as a vector. -# From Dietrich & Newsam 96 in text following equation 12 -function normalized_inverse_2d_fft!(transformed_vector::AbstractVector{T}, vector::AbstractVector{S}, grid_ext) where {T,S} +function normalized_2d_fft!(transformed_vector::AbstractVector{<:Complex}, vector::AbstractVector{S}, fft_plan::FFTW.FFTWPlan, fft_plan!::FFTW.FFTWPlan, grid_ext, f=identity) where {S} - normalization_factor = sqrt(grid_ext.nx * grid_ext.ny) - transformed_vector .= ifft(reshape(vector, grid_ext.nx, grid_ext.ny))[:] .* normalization_factor + normalization_factor = f(sqrt(grid_ext.nx * grid_ext.ny)) + tmp_array = complex(reshape(vector, grid_ext.nx, grid_ext.ny)) + mul!(tmp_array, f(fft_plan!), tmp_array) + transformed_vector .= @view(tmp_array[:]) ./ normalization_factor end # Decomposition of R11, equation 12 of Deitrich and Newsam -function WΛWH_decomposition!(transformed_array::AbstractMatrix{T}, array::AbstractMatrix{T}, - offline_matrices::OfflineMatrices, grid::Grid, grid_ext::Grid) where T +function WΛWH_decomposition!(transformed_array::AbstractMatrix{T}, + array::AbstractMatrix{T}, + offline_matrices::OfflineMatrices, + grid::Grid, + grid_ext::Grid, + fft_plan::FFTW.FFTWPlan, + fft_plan!::FFTW.FFTWPlan, + ) where T @assert size(array) == (grid.nx, grid.ny) @@ -187,11 +184,11 @@ function WΛWH_decomposition!(transformed_array::AbstractMatrix{T}, array::Abstr extended_array[1:grid.nx, 1:grid.ny] .= array - normalized_2d_fft!(extended_array, extended_array, grid_ext) + normalized_2d_fft!(extended_array, extended_array, fft_plan, fft_plan!, grid_ext) # Here we do an element-wise multiplication of the extended_array with the vector Lambda. This is identical to # Diagonal(Lambda) * extended_array[:], but avoids flattening and reshaping extended_array. - normalized_inverse_2d_fft!(extended_array, reshape(offline_matrices.Lambda, grid_ext.nx, grid_ext.ny).*extended_array, grid_ext) + normalized_2d_fft!(extended_array, reshape(offline_matrices.Lambda, grid_ext.nx, grid_ext.ny).*extended_array, fft_plan, fft_plan!, grid_ext, inv) transformed_array .= real.(@view(extended_array[1:grid.nx, 1:grid.ny])) @@ -211,7 +208,15 @@ function get_values_at_stations(field::AbstractMatrix{T}, stations) where T end # Allocate and compute matrices that do not depend on time-dependent variables (height and observations). -function init_offline_matrices(grid::Grid, grid_ext::Grid, stations::NamedTuple, noise_params::NamedTuple, obs_noise_std::T, F::Type) where T +function init_offline_matrices(grid::Grid, + grid_ext::Grid, + stations::NamedTuple, + noise_params::NamedTuple, + obs_noise_std::T, + fft_plan::FFTW.FFTWPlan, + fft_plan!::FFTW.FFTWPlan, + F::Type, + ) where T n1 = grid.nx * grid.ny # number of elements in original grid n1_bar = grid_ext.nx * grid_ext.ny # number of elements in extended grid @@ -241,12 +246,12 @@ function init_offline_matrices(grid::Grid, grid_ext::Grid, stations::NamedTuple, matrices.R12_invR22 .= matrices.R12 * matrices.R22_inv fourier_coeffs = Vector{C}(undef, n1_bar) - normalized_inverse_2d_fft!(fourier_coeffs, matrices.rho_bar, grid_ext) + normalized_2d_fft!(fourier_coeffs, matrices.rho_bar, fft_plan, fft_plan!, grid_ext, inv) matrices.Lambda .= sqrt(n1_bar) .* real.(fourier_coeffs) WHbar_R12 = Matrix{C}(undef, n1_bar, stations.nst) for i in 1:stations.nst - normalized_2d_fft!(@view(WHbar_R12[:,i]), @view(matrices.R21_bar[i,:]), grid_ext) + normalized_2d_fft!(@view(WHbar_R12[:,i]), @view(matrices.R21_bar[i,:]), fft_plan, fft_plan!, grid_ext) end KH = Diagonal(matrices.Lambda)^(-1/2)*WHbar_R12 matrices.K .= KH' @@ -285,10 +290,18 @@ function init_online_matrices(grid::Grid, grid_ext::Grid, stations::NamedTuple, end # Calculate the mean for the optimal proposal of the height -function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{T,3}, - offline_matrices::OfflineMatrices, observations::AbstractVector{T}, - stations::NamedTuple, grid::Grid, grid_ext::Grid, - filter_params, obs_noise_std::T) where T +function calculate_mean_height!(mean::AbstractArray{T,3}, + height::AbstractArray{T,3}, + offline_matrices::OfflineMatrices, + observations::AbstractVector{T}, + stations::NamedTuple, + grid::Grid, + grid_ext::Grid, + fft_plan::FFTW.FFTWPlan, + fft_plan!::FFTW.FFTWPlan, + filter_params, + obs_noise_std::T, + ) where T # The arguments for the WΛWH decompositions are matrices that only have nonzero values # at the indices of the stations. Store them as sparse arrays to save space. @@ -299,8 +312,8 @@ function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{ # Compute WΛWH decompositions, results are dense matrices, store them in buffers # These correspond to mu21 and mu22 in Alex's code - WΛWH_decomposition!(offline_matrices.buf1, mu21, offline_matrices, grid, grid_ext) - WΛWH_decomposition!(offline_matrices.buf2, mu22, offline_matrices, grid, grid_ext) + WΛWH_decomposition!(offline_matrices.buf1, mu21, offline_matrices, grid, grid_ext, fft_plan, fft_plan!) + WΛWH_decomposition!(offline_matrices.buf2, mu22, offline_matrices, grid, grid_ext, fft_plan, fft_plan!) # Compute the difference of the decomposition results, store in offline_matrices.buf1 # This corresponds to mu2 in Alex's code. @@ -317,7 +330,7 @@ function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{ # Compute decomposition of height values at stations times the inverse covariance matrix # The argument corresponds to mu10 and the outcome to mu11 in Alex's code - WΛWH_decomposition!(offline_matrices.buf1, mu10_sparse, offline_matrices, grid, grid_ext) + WΛWH_decomposition!(offline_matrices.buf1, mu10_sparse, offline_matrices, grid, grid_ext, fft_plan, fft_plan!) # Compute the mean for the ith particle using mu2 and mu11 # Skip storing the temporary mu1 in Alex's code @@ -328,13 +341,22 @@ function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{ end function sample_height_proposal!(height::AbstractArray{T,3}, - offline_matrices::OfflineMatrices, online_matrices::OnlineMatrices, - observations::AbstractVector{T}, stations::NamedTuple, grid::Grid, - grid_ext::Grid, filter_params, rng::Random.AbstractRNG, obs_noise_std::T) where T + offline_matrices::OfflineMatrices, + online_matrices::OnlineMatrices, + observations::AbstractVector{T}, + stations::NamedTuple, + grid::Grid, + grid_ext::Grid, + fft_plan::FFTW.FFTWPlan, + fft_plan!::FFTW.FFTWPlan, + filter_params, + rng::Random.AbstractRNG, + obs_noise_std::T, + ) where T @assert iseven(filter_params.nprt) "Number of particles must be even" - calculate_mean_height!(online_matrices.mean, height, offline_matrices, observations, stations, grid, grid_ext, filter_params, obs_noise_std) + calculate_mean_height!(online_matrices.mean, height, offline_matrices, observations, stations, grid, grid_ext, fft_plan, fft_plan!, filter_params, obs_noise_std) i_n1 = LinearIndices((grid.nx, grid.ny)) i_n1_bar = LinearIndices((grid_ext.nx, grid_ext.ny)) @@ -345,7 +367,7 @@ function sample_height_proposal!(height::AbstractArray{T,3}, e2 = complex.(randn(rng, stations.nst), randn(rng, stations.nst)) # This gives the vector z1_bar - normalized_inverse_2d_fft!(online_matrices.z1_bar, Diagonal(offline_matrices.Lambda)^(1/2) * e1, grid_ext) + normalized_2d_fft!(online_matrices.z1_bar, Diagonal(offline_matrices.Lambda)^(1/2) * e1, fft_plan, fft_plan!, grid_ext, inv) # This is the vector z2 online_matrices.z2 .= offline_matrices.K * e1 .+ offline_matrices.L * e2 diff --git a/src/ParticleDA.jl b/src/ParticleDA.jl index f889cc6c..39db3be1 100644 --- a/src/ParticleDA.jl +++ b/src/ParticleDA.jl @@ -366,12 +366,16 @@ function init_filter(filter_params::FilterParameters, model_data, nprt_per_rank: model_noise_params = get_model_noise_params(model_data) obs_noise_std = get_obs_noise_std(model_data) + # Precompute two FFT plans, one in-place and the other out-of-place + C = complex(T) + tmp_array = Matrix{C}(undef, grid_ext.nx, grid_ext.ny) + fft_plan, fft_plan! = FFTW.plan_fft(tmp_array), FFTW.plan_fft!(tmp_array) - offline_matrices = init_offline_matrices(grid, grid_ext, stations, model_noise_params, obs_noise_std, T) + offline_matrices = init_offline_matrices(grid, grid_ext, stations, model_noise_params, obs_noise_std, fft_plan, fft_plan!, T) online_matrices = init_online_matrices(grid, grid_ext, stations, filter_params, T) rng = get_rng(model_data) - return (; filter_data..., offline_matrices, online_matrices, stations, grid, grid_ext, rng, obs_noise_std) + return (; filter_data..., offline_matrices, online_matrices, stations, grid, grid_ext, rng, obs_noise_std, fft_plan, fft_plan!) end function update_particle_proposal!(model_data, filter_data, filter_params, truth_observations, nprt_per_rank, filter_type::BootstrapFilter) @@ -397,6 +401,8 @@ function update_particle_proposal!(model_data, filter_data, filter_params, truth filter_data.stations, filter_data.grid, filter_data.grid_ext, + filter_data.fft_plan, + filter_data.fft_plan!, filter_params, filter_data.rng[threadid()], filter_data.obs_noise_std) diff --git a/test/runtests.jl b/test/runtests.jl index c852572f..6e087bbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using ParticleDA using LinearAlgebra, Test, HDF5, Random, YAML using MPI using StableRNGs +using FFTW using ParticleDA: FilterParameters @@ -290,8 +291,9 @@ end arr = rand(ComplexF64,10,10) arr2 = zeros(ComplexF64,10,10) arr3 = zeros(ComplexF64,10,10) - ParticleDA.normalized_2d_fft!(arr2,arr,grid_ext) - ParticleDA.normalized_inverse_2d_fft!(arr3,arr2,grid_ext) + fft_plan, fft_plan! = FFTW.plan_fft(arr), FFTW.plan_fft!(arr) + ParticleDA.normalized_2d_fft!(arr2, arr, fft_plan, fft_plan!, grid_ext) + ParticleDA.normalized_2d_fft!(arr3, arr2, fft_plan, fft_plan!, grid_ext, inv) @test arr ≈ arr3 cov_1 = zeros(stations.nst, grid_ext.nx * grid_ext.ny) @@ -307,14 +309,16 @@ end height = rand(grid.nx, grid.ny, filter_params.nprt) obs = randn(stations.nst) - mat_off = ParticleDA.init_offline_matrices(grid, grid_ext, stations, noise_params, model_params.obs_noise_std, Float64) + tmp_array = Matrix{ComplexF64}(undef, grid_ext.nx, grid_ext.ny) + fft_plan, fft_plan! = FFTW.plan_fft(tmp_array), FFTW.plan_fft!(tmp_array) + mat_off = ParticleDA.init_offline_matrices(grid, grid_ext, stations, noise_params, model_params.obs_noise_std, fft_plan, fft_plan!, Float64) mat_on = ParticleDA.init_online_matrices(grid, grid_ext, stations, filter_params, Float64) @test minimum(mat_off.Lambda) > 0.0 - ParticleDA.calculate_mean_height!(mat_on.mean, height, mat_off, obs, stations, grid, grid_ext, filter_params, model_params.obs_noise_std) + ParticleDA.calculate_mean_height!(mat_on.mean, height, mat_off, obs, stations, grid, grid_ext, fft_plan, fft_plan!, filter_params, model_params.obs_noise_std) @test all(isfinite, mat_on.mean) rng = Random.MersenneTwister(seed) - ParticleDA.sample_height_proposal!(height, mat_off, mat_on, obs, stations, grid, grid_ext, filter_params, rng, model_params.obs_noise_std) + ParticleDA.sample_height_proposal!(height, mat_off, mat_on, obs, stations, grid, grid_ext, fft_plan, fft_plan!, filter_params, rng, model_params.obs_noise_std) @test all(isfinite, mat_on.samples) end @@ -361,10 +365,12 @@ end obs[i] = height[stations.ist[i], stations.jst[i],1] + rand() end - mat_off = ParticleDA.init_offline_matrices(grid, grid_ext, stations, noise_params, model_params.obs_noise_std, Float64) + tmp_array = Matrix{ComplexF64}(undef, grid_ext.nx, grid_ext.ny) + fft_plan, fft_plan! = FFTW.plan_fft(tmp_array), FFTW.plan_fft!(tmp_array) + mat_off = ParticleDA.init_offline_matrices(grid, grid_ext, stations, noise_params, model_params.obs_noise_std, fft_plan, fft_plan!, Float64) mat_on = ParticleDA.init_online_matrices(grid, grid_ext, stations, filter_params, Float64) - ParticleDA.sample_height_proposal!(height, mat_off, mat_on, obs, stations, grid, grid_ext, filter_params, rng, model_params.obs_noise_std) + ParticleDA.sample_height_proposal!(height, mat_off, mat_on, obs, stations, grid, grid_ext, fft_plan, fft_plan!, filter_params, rng, model_params.obs_noise_std) Yobs_t = copy(obs) FH_t = copy(reshape(permutedims(height, [3 1 2]), filter_params.nprt, (model_params.nx)*(model_params.ny)))