Skip to content

Commit

Permalink
Merge pull request #142 from Team-RADDISH/tk/init_filter
Browse files Browse the repository at this point in the history
Move filter initialisations into a function and group arrays into a struct.
  • Loading branch information
giordano committed Jan 14, 2021
2 parents 36f4f73 + 8d0b81f commit 5dd2bed
Showing 1 changed file with 60 additions and 38 deletions.
98 changes: 60 additions & 38 deletions src/ParticleDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,41 @@ function copy_states!(particles::AbstractArray{T,4},

end

# Initialize arrays used by the filter
function init_filter(filter_params::FilterParameters, model_data, nprt_per_rank::Int, T::Type)

if MPI.Comm_rank(MPI.COMM_WORLD) == filter_params.master_rank
weights = Vector{T}(undef, filter_params.nprt)
else
weights = Vector{T}(undef, nprt_per_rank)
end

resampling_indices = Vector{Int}(undef, filter_params.nprt)

# TODO: these variables should be set in a better way
nx, ny, n_state_var = model_data.model_params.nx, model_data.model_params.ny, model_data.model_params.n_state_var

statistics = Array{SummaryStat{T}, 3}(undef, nx, ny, n_state_var)
avg_arr = Array{T,3}(undef, nx, ny, n_state_var)
var_arr = Array{T,3}(undef, nx, ny, n_state_var)

# Memory buffer used during copy of the states
copy_buffer = Array{T,4}(undef, nx, ny, n_state_var, nprt_per_rank)

return FilterData(weights, resampling_indices, statistics, avg_arr, var_arr, copy_buffer)
end

struct FilterData{T, S, U, V, X}

weights::T
resampling_indices::U
statistics::S
avg_arr::V
var_arr::V
copy_buffer::X

end

struct BootstrapFilter end

function run_particle_filter(init, filter_params::FilterParameters, model_params_dict::Dict, ::Type{BootstrapFilter})
Expand Down Expand Up @@ -251,38 +286,21 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params
@timeit_debug timer "Model initialization" model_data = init(model_params_dict, nprt_per_rank, my_rank)

# TODO: put the body of this block in a function
@timeit_debug timer "Filter initialization" begin
# TODO: ideally this will be an argument of the function, to choose a
# different datatype.
T = Float64

if MPI.Comm_rank(MPI.COMM_WORLD) == filter_params.master_rank
weights = Vector{T}(undef, filter_params.nprt)
else
weights = Vector{T}(undef, nprt_per_rank)
end

resampling_indices = Vector{Int}(undef, filter_params.nprt)

# TODO: these variables should be set in a better way
nx, ny, n_state_var = model_data.model_params.nx, model_data.model_params.ny, model_data.model_params.n_state_var

statistics = Array{SummaryStat{T}, 3}(undef, nx, ny, n_state_var)
avg_arr = Array{T,3}(undef, nx, ny, n_state_var)
var_arr = Array{T,3}(undef, nx, ny, n_state_var)

# Memory buffer used during copy of the states
copy_buffer = Array{T}(undef, nx, ny, n_state_var, nprt_per_rank)
end
@timeit_debug timer "Filter initialization" filter_data = init_filter(filter_params, model_data, nprt_per_rank, Float64)

@timeit_debug timer "get_particles" particles = get_particles(model_data)
@timeit_debug timer "Mean and Var" get_mean_and_var!(statistics, particles, filter_params.master_rank)
@timeit_debug timer "Mean and Var" get_mean_and_var!(filter_data.statistics, particles, filter_params.master_rank)

# Write initial state (time = 0) + metadata
if(filter_params.verbose && my_rank == filter_params.master_rank)
@timeit_debug timer "IO" begin
unpack_statistics!(avg_arr, var_arr, statistics)
write_snapshot(filter_params.output_filename, model_data, avg_arr, var_arr, weights, 0)
unpack_statistics!(filter_data.avg_arr, filter_data.var_arr, filter_data.statistics)
write_snapshot(filter_params.output_filename,
model_data,
filter_data.avg_arr,
filter_data.var_arr,
filter_data.weights,
0)
end
end

Expand All @@ -296,7 +314,7 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params

@timeit_debug timer "Particle State Update and Process Noise" model_observations = update_particles!(model_data, nprt_per_rank)

@timeit_debug timer "Weights" get_log_weights!(@view(weights[1:nprt_per_rank]),
@timeit_debug timer "Weights" get_log_weights!(@view(filter_data.weights[1:nprt_per_rank]),
truth_observations,
model_observations,
filter_params.weight_std)
Expand All @@ -308,33 +326,37 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params
# for their chunk of state.
if my_rank == filter_params.master_rank
@timeit_debug timer "MPI Gather" MPI.Gather!(MPI.IN_PLACE,
UBuffer(weights, nprt_per_rank),
UBuffer(filter_data.weights, nprt_per_rank),
filter_params.master_rank,
MPI.COMM_WORLD)
@timeit_debug timer "Weights" normalized_exp!(weights)
@timeit_debug timer "Resample" resample!(resampling_indices, weights)
@timeit_debug timer "Weights" normalized_exp!(filter_data.weights)
@timeit_debug timer "Resample" resample!(filter_data.resampling_indices, filter_data.weights)

else
@timeit_debug timer "MPI Gather" MPI.Gather!(weights,
@timeit_debug timer "MPI Gather" MPI.Gather!(filter_data.weights,
nothing,
filter_params.master_rank,
MPI.COMM_WORLD)
end

# Broadcast resampled particle indices to all ranks
MPI.Bcast!(resampling_indices, filter_params.master_rank, MPI.COMM_WORLD)
MPI.Bcast!(filter_data.resampling_indices, filter_params.master_rank, MPI.COMM_WORLD)

@timeit_debug timer "get_particles" particles = get_particles(model_data)
@timeit_debug timer "State Copy" copy_states!(particles, copy_buffer, resampling_indices, my_rank, nprt_per_rank)
@timeit_debug timer "State Copy" copy_states!(particles,
filter_data.copy_buffer,
filter_data.resampling_indices,
my_rank,
nprt_per_rank)

@timeit_debug timer "get_particles" particles = get_particles(model_data)
@timeit_debug timer "Mean and Var" get_mean_and_var!(statistics, particles, filter_params.master_rank)
@timeit_debug timer "Mean and Var" get_mean_and_var!(filter_data.statistics, particles, filter_params.master_rank)

if my_rank == filter_params.master_rank && filter_params.verbose

@timeit_debug timer "IO" begin
unpack_statistics!(avg_arr, var_arr, statistics)
write_snapshot(filter_params.output_filename, model_data, avg_arr, var_arr, weights, it)
unpack_statistics!(filter_data.avg_arr, filter_data.var_arr, filter_data.statistics)
write_snapshot(filter_params.output_filename, model_data, filter_data.avg_arr, filter_data.var_arr, filter_data.weights, it)
end

end
Expand Down Expand Up @@ -370,9 +392,9 @@ function run_particle_filter(init, filter_params::FilterParameters, model_params
end
end

unpack_statistics!(avg_arr, var_arr, statistics)
unpack_statistics!(filter_data.avg_arr, filter_data.var_arr, filter_data.statistics)

return get_truth(model_data), avg_arr, var_arr
return get_truth(model_data), filter_data.avg_arr, filter_data.var_arr
end

# Initialise params struct with user-defined dict of values.
Expand Down

0 comments on commit 5dd2bed

Please sign in to comment.