-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Precompute FFTW plan #159
Precompute FFTW plan #159
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -139,59 +139,56 @@ end | |
|
||
# Normalized two-dimensional discrete Fourier transform normalized by sqrt(n1_bar). | ||
# Operates on the 2d data stored as a matrix. | ||
# The argument `f` lets you switch between forward transform (`f=identity`) and | ||
# inverse transform (`f=inv`). | ||
# From Dietrich & Newsam 96 in text following equation 12 | ||
function normalized_2d_fft!(transformed_array::AbstractMatrix{T}, array::AbstractMatrix{S}, grid_ext) where {T,S} | ||
function normalized_2d_fft!(transformed_array::AbstractMatrix{<:Complex}, array::AbstractMatrix{S}, fft_plan::FFTW.FFTWPlan, fft_plan!::FFTW.FFTWPlan, grid_ext, f=identity) where {S} | ||
|
||
normalization_factor = 1.0 / sqrt(grid_ext.nx * grid_ext.ny) | ||
transformed_array .= fft(array) .* normalization_factor | ||
normalization_factor = f(sqrt(grid_ext.nx * grid_ext.ny)) | ||
if pointer(transformed_array) == pointer(array) | ||
mul!(transformed_array, f(fft_plan!), array) | ||
else | ||
mul!(transformed_array, f(fft_plan), array) | ||
Comment on lines
+149
to
+151
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the difference between There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The out-of-place mul!(B, fft_plan, A) with mul!(A, fft_plan!, A) We do have cases where we overwrite the input array. |
||
end | ||
transformed_array ./= normalization_factor | ||
|
||
end | ||
|
||
# Normalized two-dimensional discrete Fourier transform normalized by sqrt(n1_bar). | ||
# Operates on the 2d data stored as a vector. | ||
# The argument `f` lets you switch between forward transform (`f=identity`) and | ||
# inverse transform (`f=inv`). | ||
# From Dietrich & Newsam 96 in text following equation 12 | ||
function normalized_2d_fft!(transformed_vector::AbstractVector{T}, vector::AbstractVector{S}, grid_ext) where {T,S} | ||
|
||
normalization_factor = 1.0 / sqrt(grid_ext.nx * grid_ext.ny) | ||
transformed_vector .= fft(reshape(vector, grid_ext.nx, grid_ext.ny))[:] .* normalization_factor | ||
|
||
end | ||
|
||
# Normalized inverse two-dimensional discrete Fourier transform normalized by sqrt(n1_bar). | ||
# Operates on the 2d data stored as a matrix. | ||
# From Dietrich & Newsam 96 in text following equation 12 | ||
function normalized_inverse_2d_fft!(transformed_array::AbstractMatrix{T}, array::AbstractMatrix{S}, grid_ext) where {T,S} | ||
|
||
normalization_factor = sqrt(grid_ext.nx * grid_ext.ny) | ||
transformed_array .= ifft(array) .* normalization_factor | ||
|
||
end | ||
|
||
# Normalized inverse two-dimensional discrete Fourier transform normalized by sqrt(n1_bar). | ||
# Operates on the 2d data stored as a vector. | ||
# From Dietrich & Newsam 96 in text following equation 12 | ||
function normalized_inverse_2d_fft!(transformed_vector::AbstractVector{T}, vector::AbstractVector{S}, grid_ext) where {T,S} | ||
function normalized_2d_fft!(transformed_vector::AbstractVector{<:Complex}, vector::AbstractVector{S}, fft_plan::FFTW.FFTWPlan, fft_plan!::FFTW.FFTWPlan, grid_ext, f=identity) where {S} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's correct, but all these methods need to have the same signature in terms of number of arguments. |
||
|
||
normalization_factor = sqrt(grid_ext.nx * grid_ext.ny) | ||
transformed_vector .= ifft(reshape(vector, grid_ext.nx, grid_ext.ny))[:] .* normalization_factor | ||
normalization_factor = f(sqrt(grid_ext.nx * grid_ext.ny)) | ||
tmp_array = complex(reshape(vector, grid_ext.nx, grid_ext.ny)) | ||
mul!(tmp_array, f(fft_plan!), tmp_array) | ||
transformed_vector .= @view(tmp_array[:]) ./ normalization_factor | ||
|
||
end | ||
|
||
# Decomposition of R11, equation 12 of Dietrich and Newsam | ||
function WΛWH_decomposition!(transformed_array::AbstractMatrix{T}, array::AbstractMatrix{T}, | ||
offline_matrices::OfflineMatrices, grid::Grid, grid_ext::Grid) where T | ||
function WΛWH_decomposition!(transformed_array::AbstractMatrix{T}, | ||
array::AbstractMatrix{T}, | ||
offline_matrices::OfflineMatrices, | ||
grid::Grid, | ||
grid_ext::Grid, | ||
fft_plan::FFTW.FFTWPlan, | ||
fft_plan!::FFTW.FFTWPlan, | ||
) where T | ||
|
||
@assert size(array) == (grid.nx, grid.ny) | ||
|
||
extended_array = zeros(ComplexF64, grid_ext.nx, grid_ext.ny) | ||
|
||
extended_array[1:grid.nx, 1:grid.ny] .= array | ||
|
||
normalized_2d_fft!(extended_array, extended_array, grid_ext) | ||
normalized_2d_fft!(extended_array, extended_array, fft_plan, fft_plan!, grid_ext) | ||
|
||
# Here we do an element-wise multiplication of the extended_array with the vector Lambda. This is identical to | ||
# Diagonal(Lambda) * extended_array[:], but avoids flattening and reshaping extended_array. | ||
normalized_inverse_2d_fft!(extended_array, reshape(offline_matrices.Lambda, grid_ext.nx, grid_ext.ny).*extended_array, grid_ext) | ||
normalized_2d_fft!(extended_array, reshape(offline_matrices.Lambda, grid_ext.nx, grid_ext.ny).*extended_array, fft_plan, fft_plan!, grid_ext, inv) | ||
|
||
transformed_array .= real.(@view(extended_array[1:grid.nx, 1:grid.ny])) | ||
|
||
|
@@ -211,7 +208,15 @@ function get_values_at_stations(field::AbstractMatrix{T}, stations) where T | |
end | ||
|
||
# Allocate and compute matrices that do not depend on time-dependent variables (height and observations). | ||
function init_offline_matrices(grid::Grid, grid_ext::Grid, stations::NamedTuple, noise_params::NamedTuple, obs_noise_std::T, F::Type) where T | ||
function init_offline_matrices(grid::Grid, | ||
grid_ext::Grid, | ||
stations::NamedTuple, | ||
noise_params::NamedTuple, | ||
obs_noise_std::T, | ||
fft_plan::FFTW.FFTWPlan, | ||
fft_plan!::FFTW.FFTWPlan, | ||
F::Type, | ||
) where T | ||
|
||
n1 = grid.nx * grid.ny # number of elements in original grid | ||
n1_bar = grid_ext.nx * grid_ext.ny # number of elements in extended grid | ||
|
@@ -241,12 +246,12 @@ function init_offline_matrices(grid::Grid, grid_ext::Grid, stations::NamedTuple, | |
matrices.R12_invR22 .= matrices.R12 * matrices.R22_inv | ||
|
||
fourier_coeffs = Vector{C}(undef, n1_bar) | ||
normalized_inverse_2d_fft!(fourier_coeffs, matrices.rho_bar, grid_ext) | ||
normalized_2d_fft!(fourier_coeffs, matrices.rho_bar, fft_plan, fft_plan!, grid_ext, inv) | ||
matrices.Lambda .= sqrt(n1_bar) .* real.(fourier_coeffs) | ||
|
||
WHbar_R12 = Matrix{C}(undef, n1_bar, stations.nst) | ||
for i in 1:stations.nst | ||
normalized_2d_fft!(@view(WHbar_R12[:,i]), @view(matrices.R21_bar[i,:]), grid_ext) | ||
normalized_2d_fft!(@view(WHbar_R12[:,i]), @view(matrices.R21_bar[i,:]), fft_plan, fft_plan!, grid_ext) | ||
end | ||
KH = Diagonal(matrices.Lambda)^(-1/2)*WHbar_R12 | ||
matrices.K .= KH' | ||
|
@@ -285,10 +290,18 @@ function init_online_matrices(grid::Grid, grid_ext::Grid, stations::NamedTuple, | |
end | ||
|
||
# Calculate the mean for the optimal proposal of the height | ||
function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{T,3}, | ||
offline_matrices::OfflineMatrices, observations::AbstractVector{T}, | ||
stations::NamedTuple, grid::Grid, grid_ext::Grid, | ||
filter_params, obs_noise_std::T) where T | ||
function calculate_mean_height!(mean::AbstractArray{T,3}, | ||
height::AbstractArray{T,3}, | ||
offline_matrices::OfflineMatrices, | ||
observations::AbstractVector{T}, | ||
stations::NamedTuple, | ||
grid::Grid, | ||
grid_ext::Grid, | ||
fft_plan::FFTW.FFTWPlan, | ||
fft_plan!::FFTW.FFTWPlan, | ||
filter_params, | ||
obs_noise_std::T, | ||
) where T | ||
|
||
# The arguments for the WΛWH decompositions are matrices that only have nonzero values | ||
# at the indices of the stations. Store them as sparse arrays to save space. | ||
|
@@ -299,8 +312,8 @@ function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{ | |
|
||
# Compute WΛWH decompositions, results are dense matrices, store them in buffers | ||
# These correspond to mu21 and mu22 in Alex's code | ||
WΛWH_decomposition!(offline_matrices.buf1, mu21, offline_matrices, grid, grid_ext) | ||
WΛWH_decomposition!(offline_matrices.buf2, mu22, offline_matrices, grid, grid_ext) | ||
WΛWH_decomposition!(offline_matrices.buf1, mu21, offline_matrices, grid, grid_ext, fft_plan, fft_plan!) | ||
WΛWH_decomposition!(offline_matrices.buf2, mu22, offline_matrices, grid, grid_ext, fft_plan, fft_plan!) | ||
|
||
# Compute the difference of the decomposition results, store in offline_matrices.buf1 | ||
# This corresponds to mu2 in Alex's code. | ||
|
@@ -317,7 +330,7 @@ function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{ | |
|
||
# Compute decomposition of height values at stations times the inverse covariance matrix | ||
# The argument corresponds to mu10 and the outcome to mu11 in Alex's code | ||
WΛWH_decomposition!(offline_matrices.buf1, mu10_sparse, offline_matrices, grid, grid_ext) | ||
WΛWH_decomposition!(offline_matrices.buf1, mu10_sparse, offline_matrices, grid, grid_ext, fft_plan, fft_plan!) | ||
|
||
# Compute the mean for the ith particle using mu2 and mu11 | ||
# Skip storing the temporary mu1 in Alex's code | ||
|
@@ -328,13 +341,22 @@ function calculate_mean_height!(mean::AbstractArray{T,3}, height::AbstractArray{ | |
end | ||
|
||
function sample_height_proposal!(height::AbstractArray{T,3}, | ||
offline_matrices::OfflineMatrices, online_matrices::OnlineMatrices, | ||
observations::AbstractVector{T}, stations::NamedTuple, grid::Grid, | ||
grid_ext::Grid, filter_params, rng::Random.AbstractRNG, obs_noise_std::T) where T | ||
offline_matrices::OfflineMatrices, | ||
online_matrices::OnlineMatrices, | ||
observations::AbstractVector{T}, | ||
stations::NamedTuple, | ||
grid::Grid, | ||
grid_ext::Grid, | ||
fft_plan::FFTW.FFTWPlan, | ||
fft_plan!::FFTW.FFTWPlan, | ||
filter_params, | ||
rng::Random.AbstractRNG, | ||
obs_noise_std::T, | ||
) where T | ||
|
||
@assert iseven(filter_params.nprt) "Number of particles must be even" | ||
|
||
calculate_mean_height!(online_matrices.mean, height, offline_matrices, observations, stations, grid, grid_ext, filter_params, obs_noise_std) | ||
calculate_mean_height!(online_matrices.mean, height, offline_matrices, observations, stations, grid, grid_ext, fft_plan, fft_plan!, filter_params, obs_noise_std) | ||
|
||
i_n1 = LinearIndices((grid.nx, grid.ny)) | ||
i_n1_bar = LinearIndices((grid_ext.nx, grid_ext.ny)) | ||
|
@@ -345,7 +367,7 @@ function sample_height_proposal!(height::AbstractArray{T,3}, | |
e2 = complex.(randn(rng, stations.nst), randn(rng, stations.nst)) | ||
|
||
# This gives the vector z1_bar | ||
normalized_inverse_2d_fft!(online_matrices.z1_bar, Diagonal(offline_matrices.Lambda)^(1/2) * e1, grid_ext) | ||
normalized_2d_fft!(online_matrices.z1_bar, Diagonal(offline_matrices.Lambda)^(1/2) * e1, fft_plan, fft_plan!, grid_ext, inv) | ||
|
||
# This is the vector z2 | ||
online_matrices.z2 .= offline_matrices.K * e1 .+ offline_matrices.L * e2 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -366,12 +366,16 @@ function init_filter(filter_params::FilterParameters, model_data, nprt_per_rank: | |
|
||
model_noise_params = get_model_noise_params(model_data) | ||
obs_noise_std = get_obs_noise_std(model_data) | ||
# Precompute two FFT plans, one in-place and the other out-of-place | ||
C = complex(T) | ||
tmp_array = Matrix{C}(undef, grid_ext.nx, grid_ext.ny) | ||
fft_plan, fft_plan! = FFTW.plan_fft(tmp_array), FFTW.plan_fft!(tmp_array) | ||
|
||
offline_matrices = init_offline_matrices(grid, grid_ext, stations, model_noise_params, obs_noise_std, T) | ||
offline_matrices = init_offline_matrices(grid, grid_ext, stations, model_noise_params, obs_noise_std, fft_plan, fft_plan!, T) | ||
Comment on lines
+372
to
+374
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually what I did at the beginning, but I didn't like it too much because it was opaque what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😁 Not a big deal either way |
||
online_matrices = init_online_matrices(grid, grid_ext, stations, filter_params, T) | ||
rng = get_rng(model_data) | ||
|
||
return (; filter_data..., offline_matrices, online_matrices, stations, grid, grid_ext, rng, obs_noise_std) | ||
return (; filter_data..., offline_matrices, online_matrices, stations, grid, grid_ext, rng, obs_noise_std, fft_plan, fft_plan!) | ||
end | ||
|
||
function update_particle_proposal!(model_data, filter_data, filter_params, truth_observations, nprt_per_rank, filter_type::BootstrapFilter) | ||
|
@@ -397,6 +401,8 @@ function update_particle_proposal!(model_data, filter_data, filter_params, truth | |
filter_data.stations, | ||
filter_data.grid, | ||
filter_data.grid_ext, | ||
filter_data.fft_plan, | ||
filter_data.fft_plan!, | ||
filter_params, | ||
filter_data.rng[threadid()], | ||
filter_data.obs_noise_std) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the pointers for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the answer to #159 (comment)