Implement optimal particle filter from Alex (#133)
* Add functions for Alex optimal filter

* Rewrite functions with rho_bar a vector and R21,R12 matrices

* sample_height_proposal passes preliminary tests

* sample_height_proposal passes preliminary tests

* Add validation against Alex code

* Add performance benchmarks

* Shift grid indices so that both station and grid indexing starts at 1. Shift indices when initialising StationVectors from st.st_ij of Sample_Optimal_Height_Proposal.jl

* Add get_log_weights function

* Add comments, fix bugs

Change i to iprt for correct behaviour in get_log_weights!.
Interpolate variables into function calls that are timed with `@btime`.

* Apply suggestions from code review

Add dots and in-place operations

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>
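The "dots and in-place operations" suggestion refers to Julia's fused broadcasting. A minimal sketch (illustrative values, not the package's code):

```julia
a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
c = similar(a)        # preallocate the destination once

# The dotted assignment fuses all element-wise operations into a single loop
# and writes into the existing array `c`, so no temporaries are allocated.
c .= a .* b .+ 1.0
```

Without the leading dot on `=`, each dotted sub-expression would still fuse, but the result would be a freshly allocated array.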

* matrices.K*KH is a real matrix product

* Use mul! to do multiplication in-place
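The `mul!` pattern looks like this (a sketch with toy matrices, not the filter's actual arrays):

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
C = similar(A)    # preallocate once, reuse on every call

# mul! writes A*B into C in place, avoiding the temporary that `C = A * B`
# would allocate on each iteration.
mul!(C, A, B)
```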

* Remove duplicate call

* Minimal allocations reduction

* Add get_stations and set_particles! to the interface

* Reorder data structures for compatibility. Remove unit tests (will move to runtests.jl)

* Integrate Optimal Filter to ParticleDA

- Added unit tests and validation tests
- Made changes to model interface to expose fields needed by optimal filter
- Added a OptimalFilter type for dispatching
- No implementation of run_particle_filter yet

* Apply suggestions from code review

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Draft implementation

* Changes suggested by review

Changing rand(rng,nd) to randn(rng) causes validation test to fail unless it's done on both sides. Unclear why.

* Implementation according to comment on 5 Feb in #136

* Add compat entry for FFTW

* Change to a more sensible test height field

* Make grid size nx (not nx+1) in both ParticleDA and OptimalFilter. Something goes wrong.

* Correct dx, dy in covariance calculations. Tests pass.

* Fix bugs

* Apply suggestions from code review

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Apply suggestions from code review

In two parts because I didn't realize GitHub doesn't automatically load a large diff even when there are comments.

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Define model grid as storing values at the nodes (instead of cell faces)

The only change required is defining dx as L_x/(nx-1) instead of L_x/nx.
In my opinion this is more consistent with the way we define the stations,
and it makes it easier (or perhaps even possible) to integrate with the optimal filter.
Also fixed a bug in write_stations that created an Int dataset in the HDF5 file for the (float) coordinate value.
Reference data and some tests had to be updated.
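The node-centred convention can be sketched as follows (illustrative only; `node_spacing` is a made-up name, not a ParticleDA function):

```julia
# Node-centred grid: nx nodes span [0, L_x], giving spacing L_x / (nx - 1).
# A cell-face (cell-centred) grid would instead use L_x / nx.
node_spacing(L_x, nx) = L_x / (nx - 1)

x = range(0.0, 4000.0; length = 5)   # nodes at 0, 1000, 2000, 3000, 4000
```

With nodes at both domain endpoints, station coordinates that lie on the boundary coincide exactly with grid points, which is what makes the station and filter indexing line up.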

* Pass grid parameters instead of using model_params

* Fix bugs

* Minor improvements

- Parametric types in init_filter
- Comments
- Indentations

* Fix arguments to init_*line_matrices

* Use sigma and lambda from model params instead of redefining them in filter.

* Update notebook with animation and parameters

* Point out bootstrap filter in example

* Add FFTW to benchmark environment

* Pass filter type to run_particle_filter and use it to dispatch (#155)

* Pass filter type to run_particle_filter and use it to dispatch appropriate functions

* Update to simpler syntax

* Update benchmarks and the benchmark environment

Co-authored-by: Mosè Giordano <mose@gnu.org>
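The dispatch mechanism behind this change can be sketched as below. `BootstrapFilter` and `OptimalFilter` are the real type names; the abstract supertype name and the one-argument signatures are simplifications for illustration.

```julia
abstract type FilterType end
struct BootstrapFilter <: FilterType end
struct OptimalFilter <: FilterType end

# Multiple dispatch selects the implementation from the filter-type argument,
# so callers pass e.g. BootstrapFilter() instead of a string or flag.
init_filter(::BootstrapFilter) = :bootstrap
init_filter(::OptimalFilter) = :optimal
```

The same trailing argument threads through `run_particle_filter`, letting each filter provide its own `init_filter` and update steps without branching on runtime flags.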

* Split get_grid_size into three functions (#157)

* Split get_grid_size into three functions that return NTuples
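The shape of the split is roughly as follows (hypothetical field and accessor names, shown only to illustrate one-accessor-per-property returning `NTuple`s):

```julia
struct Grid
    nx::Int; ny::Int            # number of nodes in each direction
    dx::Float64; dy::Float64    # node spacing
    x_length::Float64; y_length::Float64   # physical extent
end

get_grid_size(g::Grid) = (g.nx, g.ny)
get_grid_cell_size(g::Grid) = (g.dx, g.dy)
get_grid_domain_size(g::Grid) = (g.x_length, g.y_length)
```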

* Remove extra parenthesis

* Clarify function names

* Changes suggested by review

- Define Grid in OptimalFilter
- Annotate argument types
- Fix runtests.jl
- Fix docstrings

* Add integration test with optimal filter

* Add input file for the test

* Precompute FFTW plan (#159)

* Precompute FFTW plan

* Support in-place and out-of-place FFTs
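Precomputing a plan amortises FFTW's setup cost across repeated transforms of the same size and element type. A minimal out-of-place sketch (toy array, not the filter's data):

```julia
using FFTW

x = rand(64, 64)
p = plan_fft(x)   # plan once for this array size and element type
y = p * x         # applying the plan reuses it; cheap inside a hot loop
```

The in-place variant (`plan_fft!` on a complex array) avoids the output allocation as well, at the cost of overwriting the input.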

* Reduce duplication in `normalized_2d_fft!` methods

* Force A to be hermitian to avoid roundoff errors

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Throw error if matrix is not approximately Hermitian
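The check-then-symmetrise idea can be sketched like this (`force_hermitian` and its tolerance are assumptions for illustration, not the package's actual function):

```julia
using LinearAlgebra

# Accept a matrix only if its asymmetry looks like round-off noise, then
# symmetrise and wrap it so downstream code can rely on exact Hermitian-ness.
function force_hermitian(A; rtol = 1e-8)
    isapprox(A, A'; rtol = rtol) ||
        error("matrix is not approximately Hermitian")
    return Hermitian((A + A') / 2)
end
```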

* Clean up

* Pass RNG directly to the filter instead of the model

Also use the RNG in the `resample!` function.

* Reduce allocations in `sample_height_proposal!` (#162)

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>
Co-authored-by: Mosè Giordano <mose@gnu.org>
3 people committed Mar 17, 2021
1 parent 2ee15f0 commit 38fad3b
Showing 18 changed files with 1,801 additions and 113 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
@@ -28,7 +28,7 @@ jobs:
env = Dict(
"JULIA_NUM_THREADS" => "2",
),
)
),
)
- name: Print judgement
shell: julia --color=yes {0}
5 changes: 5 additions & 0 deletions Project.toml
@@ -5,9 +5,13 @@ version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
@@ -18,4 +22,5 @@ HDF5 = "0.14, 0.15"
MPI = "0.16"
TimerOutputs = "0.5"
YAML = "0.4"
FFTW = "1"
julia = "1.5"
43 changes: 23 additions & 20 deletions benchmark/Manifest.toml
@@ -2,15 +2,14 @@

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "8ed9de2f1b1a9b1dee48582ad477c6e67b83eb2c"
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.0"
version = "1.0.1"

[[AbstractTrees]]
deps = ["Markdown"]
git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45"
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.3"
version = "0.3.4"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -110,9 +109,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "0fc424e725eaec6ea3e9fa8df773bee18a1ab503"
git-tree-sha1 = "e64debe8cd174cc52d7dd617ebc5492c6f8b698c"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.24.14"
version = "0.24.15"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
@@ -144,9 +143,9 @@ version = "0.4.7"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "4705cc4e212c3c978c60b1b18118ec49b4d731fd"
git-tree-sha1 = "dd4ab4257c257532003eb9836eea07473fcc732e"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.5"
version = "0.11.6"

[[Future]]
deps = ["Random"]
@@ -338,25 +337,25 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"

[[PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "95a4038d1011dfdbde7cecd2ad0ac411e53ab1bc"
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "f82a0e71f222199de8e9eb9a09977bd0767d52a0"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.10.1"
version = "0.11.0"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714"
git-tree-sha1 = "223a825cccef2228f3fdbf2ecc7ca93363059073"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.15"
version = "1.0.16"

[[ParticleDA]]
deps = ["Distributions", "HDF5", "LinearAlgebra", "MPI", "Statistics", "TimerOutputs", "YAML"]
deps = ["Distributions", "FFTW", "Future", "HDF5", "LinearAlgebra", "MPI", "Random", "SparseArrays", "Statistics", "TimerOutputs", "YAML"]
path = ".."
uuid = "61cd1fb4-f4c4-4bc8-80c6-ea5639a6ca2e"
version = "0.1.0"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs"]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[PkgBenchmark]]
@@ -401,9 +400,9 @@ version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419"
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.2"
version = "1.1.3"

[[Rmath]]
deps = ["Random", "Rmath_jll"]
@@ -510,9 +509,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "3318281dd4121ecf9713ce1383b9ace7d7476fdd"
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.7"
version = "0.5.8"

[[URIs]]
git-tree-sha1 = "7855809b88d7b16e9b029afd17880930626f54a2"
@@ -556,3 +555,7 @@ version = "1.0.18+1"
[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
2 changes: 2 additions & 0 deletions benchmark/Project.toml
@@ -3,6 +3,7 @@ BenchmarkCI = "20533458-34a3-403d-a444-e18f38190b5b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
GaussianRandomFields = "e4b2fa32-6e09-5554-b718-106ed5adafe9"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
@@ -15,6 +16,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
[compat]
BenchmarkCI = "0.1"
BenchmarkTools = "0.5"
FFTW = "1"
GaussianRandomFields = "2.1.1"
MPI = "0.16"
PkgBenchmark = "0.2.10"
17 changes: 14 additions & 3 deletions benchmark/benchmarks.jl
@@ -1,6 +1,9 @@
using BenchmarkTools
using ParticleDA
using MPI
using Random
using Future
using Base.Threads

include(joinpath(joinpath(@__DIR__, "..", "test"), "model", "model.jl"))
using .Model
@@ -15,6 +18,7 @@ const my_size = MPI.Comm_size(MPI.COMM_WORLD)

SUITE["base"] = BenchmarkGroup()
SUITE["BootstrapFilter"] = BenchmarkGroup()
SUITE["OptimalFilter"] = BenchmarkGroup()

const params = Dict(
"filter" => Dict(
@@ -34,8 +38,12 @@ const params = Dict(
)

const nprt_per_rank = Int(params["filter"]["nprt"] / my_size)
const model_data = Model.init(params["model"]["llw2d"], nprt_per_rank, my_rank)
const bootstrap_filter_data = ParticleDA.init_filter(ParticleDA.get_params(ParticleDA.FilterParameters, params["filter"]), model_data, nprt_per_rank, Float64)
const rng = let
m = Random.default_rng()
[m; accumulate(Future.randjump, fill(big(10)^20, nthreads()-1), init=m)]
end
const model_data = Model.init(params["model"]["llw2d"], nprt_per_rank, my_rank, rng)
const bootstrap_filter_data = ParticleDA.init_filter(ParticleDA.get_params(ParticleDA.FilterParameters, params["filter"]), model_data, nprt_per_rank, rng, Float64, BootstrapFilter())
const filter_params = ParticleDA.get_params(ParticleDA.FilterParameters, params["filter"])

SUITE["base"]["get_particles"] = @benchmarkable ParticleDA.get_particles($(model_data))
Expand All @@ -48,5 +56,8 @@ SUITE["base"]["normalized_exp!"] = @benchmarkable ParticleDA.normalized_exp!(wei
SUITE["base"]["resample!"] = @benchmarkable ParticleDA.resample!(resampling_indices, weights) setup=(resampling_indices = Vector{Int}(undef, filter_params.nprt); weights = rand(filter_params.nprt))
# SUITE["base"]["copy_states!"] = @benchmarkable ParticleDA.copy_states!($(ParticleDA.get_particles(model_data)), $(bootstrap_filter_data.copy_buffer), $(bootstrap_filter_data.resampling_indices), $(my_rank), $(nprt_per_rank))

SUITE["BootstrapFilter"]["init_filter"] = @benchmarkable ParticleDA.init_filter($(filter_params), $(model_data), $(nprt_per_rank), Float64)
SUITE["BootstrapFilter"]["init_filter"] = @benchmarkable ParticleDA.init_filter($(filter_params), $(model_data), $(nprt_per_rank), $(rng), Float64, $(BootstrapFilter()))
SUITE["BootstrapFilter"]["run_particle_filter"] = @benchmarkable ParticleDA.run_particle_filter($(Model.init), $(params), $(BootstrapFilter())) seconds=30 setup=(cd(mktempdir()))

SUITE["OptimalFilter"]["init_filter"] = @benchmarkable ParticleDA.init_filter($(filter_params), $(model_data), $(nprt_per_rank), $(rng), Float64, $(OptimalFilter()))
SUITE["OptimalFilter"]["run_particle_filter"] = @benchmarkable ParticleDA.run_particle_filter($(Model.init), $(params), $(OptimalFilter())) seconds=30 setup=(cd(mktempdir()))
4 changes: 2 additions & 2 deletions docs/src/index.md
@@ -99,7 +99,7 @@ The particle filter parameters are saved in the following data structure:
ParticleDA.FilterParameters
```

## Example: tsunami modelling
## Example: tsunami modelling with the bootstrap filter

A full example of a model interfacing `ParticleDA` is available in
`test/model/model.jl`. This model represents a tsunami and is partly based on
@@ -124,7 +124,7 @@ Pkg.instantiate()
include(module_src)
using .Model
# Run the particle filter using the `init` file defined in the `Model` module
# Run the particle filter using the `init` file defined in the `Model` module.
run_particle_filter(Model.init, input_file, BootstrapFilter())
```

105 changes: 87 additions & 18 deletions extra/Plot_tdac_output.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -21,21 +21,11 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The following datasets found in file ../tdac.h5 : ['data_avg', 'data_syn', 'data_var', 'grid', 'params', 'stations', 'timer', 'weights']\n",
"The following timestamps found: ['t0', 't1', 't10', 't11', 't12', 't13', 't14', 't15', 't16', 't17', 't18', 't19', 't2', 't20', 't3', 't4', 't5', 't6', 't7', 't8', 't9']\n",
"The following fields found: ['height', 'vx', 'vy']\n"
]
}
],
"source": [
"filename = \"../tdac.h5\"\n",
"outputs": [],
"source": [
"filename = \"../particle_da.h5\"\n",
"fh = h5py.File(filename,'r')\n",
"print(\"The following datasets found in file\",filename,\":\",list(fh))\n",
"if \"data_syn\" in list(fh): print(\"The following timestamps found: \", list(fh[\"data_syn\"]))\n",
@@ -103,7 +93,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"plt.rcParams[\"figure.figsize\"] = (18,6)\n",
@@ -128,7 +120,7 @@
" a.scatter(x_st, y_st, color = 'r', marker = '*')\n",
" a.set_xlabel(f\"x [{y.units:~}]\")\n",
" a.set_ylabel(f\"y [{x.units:~}]\")\n",
" plt.colorbar(im,ax=a)"
" fig.colorbar(im,ax=a)"
]
},
{
@@ -178,6 +170,83 @@
"plt.xlabel('Time step')\n",
"plt.ylabel('Estimated Sample Size (1 / sum(weight^2))');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Animation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import animation\n",
"\n",
"plt.rcParams[\"figure.figsize\"] = (18,6)\n",
"\n",
"n_contours = 100\n",
"\n",
"fig, ax = plt.subplots(1,3)\n",
"ax[0].set_title(f\"True {field_desc.lower()} [{z_t.units:~}]\")\n",
"ax[1].set_title(f\"Assimilated {field_desc.lower()} [{z_avg.units:~}]\")\n",
"ax[2].set_title(f\"Std of assimilated {field_desc.lower()} [{z_std.units:~}]\")\n",
"\n",
"cb = []\n",
"\n",
"def animate(i):\n",
" \n",
" for a in ax:\n",
" for c in a.collections:\n",
" c.remove()\n",
" \n",
" timestamp = f\"t{i}\"\n",
" z_t = fh[\"data_syn\"][timestamp][field][()] * ureg(field_unit)\n",
" z_avg = fh[\"data_avg\"][timestamp][field][()] * ureg(field_unit)\n",
" z_var = fh[\"data_var\"][timestamp][field][()] * ureg(var_unit)\n",
" z_std = np.sqrt(z_var)\n",
" \n",
" zmax = max(np.max(z_t), np.max(z_avg)).magnitude\n",
" zmin = min(np.min(z_t), np.min(z_avg)).magnitude\n",
" levels = np.linspace(zmin, zmax, n_contours) \n",
" \n",
" i1 = ax[0].contourf(x,y,z_t,levels)\n",
" i2 = ax[1].contourf(x,y,z_avg,levels)\n",
" i3 = ax[2].contourf(x,y,z_std,n_contours)\n",
" \n",
" images = [i1, i2, i3]\n",
" \n",
" for a,im in zip(ax,images):\n",
" a.scatter(x_st, y_st, color = 'r', marker = '*')\n",
" a.set_xlabel(f\"x [{y.units:~}]\")\n",
" a.set_ylabel(f\"y [{x.units:~}]\")\n",
" #if len(cb) < 3:\n",
" # cb.append(plt.colorbar(im, ax=a))\n",
" #else:\n",
" # for c in cb:\n",
" # c.update_normal(im)\n",
" \n",
" return images\n",
" \n",
"anim = animation.FuncAnimation(fig, animate, frames=50)\n",
"anim.save(\"animation.mp4\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for key in fh[\"params\"].attrs:\n",
" try:\n",
" print(key, ':', fh[\"params\"].attrs[key])\n",
" except TypeError:\n",
" print(key, ':', 'N/A')"
]
}
],
"metadata": {
@@ -196,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.8.7"
}
},
"nbformat": 4,