diff --git a/docs/src/index.md b/docs/src/index.md index 0bf9913d..7465928f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -68,7 +68,9 @@ be used by [`run_particle_filter`](@ref): ParticleDA.get_particles ParticleDA.get_truth ParticleDA.update_truth! -ParticleDA.update_particles! +ParticleDA.update_particle_dynamics! +ParticleDA.update_particle_noise! +ParticleDA.get_particle_observations! ParticleDA.write_snapshot ``` diff --git a/src/ParticleDA.jl b/src/ParticleDA.jl index d47ef919..19a5c9a6 100644 --- a/src/ParticleDA.jl +++ b/src/ParticleDA.jl @@ -38,14 +38,32 @@ above signature, specifying the type of `model_data`. function update_truth! end """ - ParticleDA.update_particles!(model_data, nprt_per_rank::Int) -> particles_observations + ParticleDA.update_particle_dynamics!(model_data, nprt_per_rank::Int) -Update the particles using the dynamic of the model and return the vector of the +Update the particles using the dynamic of the model. `nprt_per_rank` is the +number of particles per each MPI rank. This method is intended to be extended +by the user with the above signature, specifying the type of `model_data`. +""" +function update_particle_dynamics! end + +""" + ParticleDA.update_particle_noise!(model_data, nprt_per_rank::Int) + +Update the particles using the noise of the model and return the vector of the particles. `nprt_per_rank` is the number of particles per each MPI rank. This method is intended to be extended by the user with the above signature, specifying the type of `model_data`. """ -function update_particles! end +function update_particle_noise! end + +""" + ParticleDA.get_particle_observations!(model_data, nprt_per_rank::Int) -> particles_observations + +Return the vector of the particles observations. `nprt_per_rank` is the number +of particles per each MPI rank. This method is intended to be extended by the +user with the above signature, specifying the type of `model_data`. +""" +function get_particle_observations! end """ ParticleDA.write_snapshot(output_filename, model_data, avg_arr, var_arr, weights, it) @@ -325,9 +343,11 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params # Forecast: Update tsunami forecast and get observations from it # Parallelised with threads. - @timeit_debug timer "Particle State Update and Process Noise" model_observations = update_particles!(model_data, nprt_per_rank) + @timeit_debug timer "Particle Dynamics" update_particle_dynamics!(model_data, nprt_per_rank); + @timeit_debug timer "Particle Noise" update_particle_noise!(model_data, nprt_per_rank) + @timeit_debug timer "Particle Observations" model_observations = get_particle_observations!(model_data, nprt_per_rank) - @timeit_debug timer "Weights" get_log_weights!(@view(filter_data.weights[1:nprt_per_rank]), + @timeit_debug timer "Particle Weights" get_log_weights!(@view(filter_data.weights[1:nprt_per_rank]), truth_observations, model_observations, filter_params.weight_std) diff --git a/test/model/model.jl b/test/model/model.jl index e838de8a..e994a44b 100644 --- a/test/model/model.jl +++ b/test/model/model.jl @@ -455,12 +455,15 @@ function ParticleDA.update_truth!(d::ModelData, _) return d.observations.truth end -function ParticleDA.update_particles!(d::ModelData, nprt_per_rank) +function ParticleDA.update_particle_dynamics!(d::ModelData, nprt_per_rank) + # Update dynamics Threads.@threads for ip in 1:nprt_per_rank tsunami_update!(@view(d.field_buffer[:, :, 1, threadid()]), @view(d.field_buffer[:, :, 2, threadid()]), @view(d.states.particles[:, :, :, ip]), d.model_matrices, d.model_params) - end +end + +function ParticleDA.update_particle_noise!(d::ModelData, nprt_per_rank) # Add process noise add_random_field!(d.states.particles, d.field_buffer, @@ -468,7 +471,9 @@ function ParticleDA.update_particles!(d::ModelData, nprt_per_rank) d.rng, d.model_params.n_state_var, nprt_per_rank) +end +function ParticleDA.get_particle_observations!(d::ModelData, nprt_per_rank) # get observations for ip in 1:nprt_per_rank get_obs!(@view(d.observations.model[:,ip]),