Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce timers with the @timeit_debug macro #50

Merged
merged 4 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
Expand Down
121 changes: 69 additions & 52 deletions src/TDAC.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TDAC

using Random, Distributions, Statistics, Distributed, Base.Threads, YAML, GaussianRandomFields, HDF5
using TimerOutputs

export tdac, main

Expand Down Expand Up @@ -430,92 +431,108 @@ end

function tdac(params::tdac_params)

timer = TimerOutput()
if params.enable_timers
TimerOutputs.enable_debug_timings(TDAC)
end

if(params.verbose)
write_grid(params)
write_params(params)
@timeit_debug timer "IO" write_grid(params)
@timeit_debug timer "IO" write_params(params)
end

states, observations, stations, weights = init_tdac(params)

background_grf = init_gaussian_random_field_generator(params)

rng = Random.MersenneTwister(params.random_seed)

#TODO: Put all llw2d setup in one function
# Set up tsunami model
#TODO: Put these in a data structure
gg, hh, hm, hn, fm, fn, fe = LLW2d.setup(params.nx, params.ny, params.bathymetry_setup)

# obtain initial tsunami height
eta = reshape(@view(states.truth[1:params.dim_grid]), params.nx, params.ny)
LLW2d.initheight!(eta, hh, params.dx, params.dy, params.source_size)
@timeit_debug timer "Initialization" begin

states, observations, stations, weights = init_tdac(params)

background_grf = init_gaussian_random_field_generator(params)

# set station positions
LLW2d.set_stations!(stations.ist,
stations.jst,
params.station_separation,
params.station_boundary,
params.station_dx,
params.station_dy,
params.dx,
params.dy)
rng = Random.MersenneTwister(params.random_seed)

#TODO: Put all llw2d setup in one function
# Set up tsunami model
#TODO: Put these in a data structure
gg, hh, hm, hn, fm, fn, fe = LLW2d.setup(params.nx, params.ny, params.bathymetry_setup)

# obtain initial tsunami height
eta = reshape(@view(states.truth[1:params.dim_grid]), params.nx, params.ny)
LLW2d.initheight!(eta, hh, params.dx, params.dy, params.source_size)

# set station positions
LLW2d.set_stations!(stations.ist,
stations.jst,
params.station_separation,
params.station_boundary,
params.station_dx,
params.station_dy,
params.dx,
params.dy)

# Initialize all particles to the true initial state + noise
states.particles .= states.truth
for ip in 1:params.nprt
add_random_field!(@view(states.particles[:,ip]), background_grf, rng, params)
# Initialize all particles to the true initial state + noise
states.particles .= states.truth
for ip in 1:params.nprt
add_random_field!(@view(states.particles[:,ip]), background_grf, rng, params)
end

cov_obs = get_obs_covariance(stations.ist, stations.jst, params)

end

# Write initial state
if params.verbose
write_snapshot(states, 0, params)
end

cov_obs = get_obs_covariance(stations.ist, stations.jst, params)
@timeit_debug timer "IO" write_snapshot(states, 0, params)
end

for it in 1:params.n_time_step

# integrate true synthetic wavefield
tsunami_update!(states.truth, hm, hn, fn, fm, fe, gg, params)
@timeit_debug timer "True State Update" tsunami_update!(states.truth, hm, hn, fn, fm, fe, gg, params)

# Forecast: Update tsunami forecast and get observations from it
# Parallelised with threads.
Threads.@threads for ip in 1:params.nprt
# Parallelised with threads.

tsunami_update!(@view(states.particles[:,ip]), hm, hn, fn, fm, fe, gg, params)
@timeit_debug timer "Particle State Update" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't try, but I think you can put @timeit_debug timer "Particle State Update" right before the for loop in the next line? That line becomes a bit long, but at least you don't need a whole new begin-end block here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you can't

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I'd put the timer on the line inside the loop, but I think there was some issue with threads

Threads.@threads for ip in 1:params.nprt

tsunami_update!(@view(states.particles[:,ip]), hm, hn, fn, fm, fe, gg, params)

end
end

# Get observation from true synthetic wavefield
get_obs!(observations.truth, states.truth, stations.ist, stations.jst, params)
@timeit_debug timer "Observations" get_obs!(observations.truth, states.truth, stations.ist, stations.jst, params)

# Add process noise, get observations, add observation noise (to particles)
for ip in 1:params.nprt
add_random_field!(@view(states.particles[:,ip]), background_grf, rng, params)
get_obs!(@view(observations.model[:,ip]),
@view(states.particles[:,ip]),
stations.ist,
stations.jst,
params)
add_noise!(@view(observations.model[:,ip]), rng, params)
@timeit_debug timer "Process Noise" add_random_field!(@view(states.particles[:,ip]), background_grf, rng, params)
@timeit_debug timer "Observations" get_obs!(@view(observations.model[:,ip]),
@view(states.particles[:,ip]),
stations.ist,
stations.jst,
params)
@timeit_debug timer "Observation Noise" add_noise!(@view(observations.model[:,ip]), rng, params)
end

# Weigh and resample particles
get_weights!(weights, observations.truth, observations.model, cov_obs)
resample!(states.resampled, states.particles, weights)
states.particles .= states.resampled
@timeit_debug timer "Weights" get_weights!(weights, observations.truth, observations.model, cov_obs)
@timeit_debug timer "Resample" resample!(states.resampled, states.particles, weights)
@timeit_debug timer "State Copy" states.particles .= states.resampled

# Calculate statistical values
Statistics.mean!(states.avg, states.particles)
states.var .= @view(Statistics.var(states.particles; dims=2)[:])
@timeit_debug timer "Particle Stats" Statistics.mean!(states.avg, states.particles)
@timeit_debug timer "Particle Stats" states.var .= @view(Statistics.var(states.particles; dims=2)[:])

# Write output
if params.verbose
write_snapshot(states, it, params)
@timeit_debug timer "IO" write_snapshot(states, it, params)
end

end

if params.enable_timers
h5write(params.output_filename, "timer/rank0", string(timer))
end

return states.truth, states.avg, states.var
end

Expand Down
1 change: 1 addition & 0 deletions src/params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Base.@kwdef struct tdac_params{T<:AbstractFloat}
obs_noise_amplitude::T = 1.0

random_seed::Int = 12345
enable_timers::Bool = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth documenting it or you prefer to keep this field out of casual users' reach?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right, forgot to update the docstring

end

end