diff --git a/src/ParticleDA.jl b/src/ParticleDA.jl index 0caa9018..f0c09f4e 100644 --- a/src/ParticleDA.jl +++ b/src/ParticleDA.jl @@ -224,6 +224,41 @@ function copy_states!(particles::AbstractArray{T,4}, end +# Initialize arrays used by the filter +function init_filter(filter_params::FilterParameters, model_data, nprt_per_rank::Int, T::Type) + + if MPI.Comm_rank(MPI.COMM_WORLD) == filter_params.master_rank + weights = Vector{T}(undef, filter_params.nprt) + else + weights = Vector{T}(undef, nprt_per_rank) + end + + resampling_indices = Vector{Int}(undef, filter_params.nprt) + + # TODO: these variables should be set in a better way + nx, ny, n_state_var = model_data.model_params.nx, model_data.model_params.ny, model_data.model_params.n_state_var + + statistics = Array{SummaryStat{T}, 3}(undef, nx, ny, n_state_var) + avg_arr = Array{T,3}(undef, nx, ny, n_state_var) + var_arr = Array{T,3}(undef, nx, ny, n_state_var) + + # Memory buffer used during copy of the states + copy_buffer = Array{T,4}(undef, nx, ny, n_state_var, nprt_per_rank) + + return FilterData(weights, resampling_indices, statistics, avg_arr, var_arr, copy_buffer) +end + +struct FilterData{T, S, U, V, X} + + weights::T + resampling_indices::U + statistics::S + avg_arr::V + var_arr::V + copy_buffer::X + +end + struct BootstrapFilter end function run_particle_filter(init, filter_params::FilterParameters, model_params_dict::Dict, ::Type{BootstrapFilter}) @@ -251,38 +286,21 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params @timeit_debug timer "Model initialization" model_data = init(model_params_dict, nprt_per_rank, my_rank) # TODO: put the body of this block in a function - @timeit_debug timer "Filter initialization" begin - # TODO: ideally this will be an argument of the function, to choose a - # different datatype. - T = Float64 - - if MPI.Comm_rank(MPI.COMM_WORLD) == filter_params.master_rank - weights = Vector{T}(undef, filter_params.nprt) - else - weights = Vector{T}(undef, nprt_per_rank) - end - - resampling_indices = Vector{Int}(undef, filter_params.nprt) - - # TODO: these variables should be set in a better way - nx, ny, n_state_var = model_data.model_params.nx, model_data.model_params.ny, model_data.model_params.n_state_var - - statistics = Array{SummaryStat{T}, 3}(undef, nx, ny, n_state_var) - avg_arr = Array{T,3}(undef, nx, ny, n_state_var) - var_arr = Array{T,3}(undef, nx, ny, n_state_var) - - # Memory buffer used during copy of the states - copy_buffer = Array{T}(undef, nx, ny, n_state_var, nprt_per_rank) - end + @timeit_debug timer "Filter initialization" filter_data = init_filter(filter_params, model_data, nprt_per_rank, Float64) @timeit_debug timer "get_particles" particles = get_particles(model_data) - @timeit_debug timer "Mean and Var" get_mean_and_var!(statistics, particles, filter_params.master_rank) + @timeit_debug timer "Mean and Var" get_mean_and_var!(filter_data.statistics, particles, filter_params.master_rank) # Write initial state (time = 0) + metadata if(filter_params.verbose && my_rank == filter_params.master_rank) @timeit_debug timer "IO" begin - unpack_statistics!(avg_arr, var_arr, statistics) - write_snapshot(filter_params.output_filename, model_data, avg_arr, var_arr, weights, 0) + unpack_statistics!(filter_data.avg_arr, filter_data.var_arr, filter_data.statistics) + write_snapshot(filter_params.output_filename, + model_data, + filter_data.avg_arr, + filter_data.var_arr, + filter_data.weights, + 0) end end @@ -296,7 +314,7 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params @timeit_debug timer "Particle State Update and Process Noise" model_observations = update_particles!(model_data, nprt_per_rank) - @timeit_debug timer "Weights" get_log_weights!(@view(weights[1:nprt_per_rank]), + @timeit_debug timer "Weights" get_log_weights!(@view(filter_data.weights[1:nprt_per_rank]), truth_observations, model_observations, filter_params.weight_std) @@ -308,33 +326,37 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params # for their chunk of state. if my_rank == filter_params.master_rank @timeit_debug timer "MPI Gather" MPI.Gather!(MPI.IN_PLACE, - UBuffer(weights, nprt_per_rank), + UBuffer(filter_data.weights, nprt_per_rank), filter_params.master_rank, MPI.COMM_WORLD) - @timeit_debug timer "Weights" normalized_exp!(weights) - @timeit_debug timer "Resample" resample!(resampling_indices, weights) + @timeit_debug timer "Weights" normalized_exp!(filter_data.weights) + @timeit_debug timer "Resample" resample!(filter_data.resampling_indices, filter_data.weights) else - @timeit_debug timer "MPI Gather" MPI.Gather!(weights, + @timeit_debug timer "MPI Gather" MPI.Gather!(filter_data.weights, nothing, filter_params.master_rank, MPI.COMM_WORLD) end # Broadcast resampled particle indices to all ranks - MPI.Bcast!(resampling_indices, filter_params.master_rank, MPI.COMM_WORLD) + MPI.Bcast!(filter_data.resampling_indices, filter_params.master_rank, MPI.COMM_WORLD) @timeit_debug timer "get_particles" particles = get_particles(model_data) - @timeit_debug timer "State Copy" copy_states!(particles, copy_buffer, resampling_indices, my_rank, nprt_per_rank) + @timeit_debug timer "State Copy" copy_states!(particles, + filter_data.copy_buffer, + filter_data.resampling_indices, + my_rank, + nprt_per_rank) @timeit_debug timer "get_particles" particles = get_particles(model_data) - @timeit_debug timer "Mean and Var" get_mean_and_var!(statistics, particles, filter_params.master_rank) + @timeit_debug timer "Mean and Var" get_mean_and_var!(filter_data.statistics, particles, filter_params.master_rank) if my_rank == filter_params.master_rank && filter_params.verbose @timeit_debug timer "IO" begin - unpack_statistics!(avg_arr, var_arr, statistics) - write_snapshot(filter_params.output_filename, model_data, avg_arr, var_arr, weights, it) + unpack_statistics!(filter_data.avg_arr, filter_data.var_arr, filter_data.statistics) + write_snapshot(filter_params.output_filename, model_data, filter_data.avg_arr, filter_data.var_arr, filter_data.weights, it) end end @@ -370,9 +392,9 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params end end - unpack_statistics!(avg_arr, var_arr, statistics) + unpack_statistics!(filter_data.avg_arr, filter_data.var_arr, filter_data.statistics) - return get_truth(model_data), avg_arr, var_arr + return get_truth(model_data), filter_data.avg_arr, filter_data.var_arr end # Initialise params struct with user-defined dict of values.