
Issues with Blog Post Estimating Estrogen #22

oxinabox opened this issue Nov 21, 2022 · 0 comments

@oxinabox (Owner)

@sethaxen told me the following on Slack; I'm recording it here so it doesn't get eaten by history before I act on it:


Great start! The plots from the final 3 sampling results (conditioned on c3, c8, and c1 or any subset of these) looked fishy, which made me think you were getting divergences, which happen when HMC encounters regions of high curvature where it can't reliably sample. So let's do a few diagnostic checks. First, let's sample multiple chains, as this allows more reliable convergence diagnostics:

julia> chain=sample(model, NUTS(), MCMCThreads(), 1_000, 4)
...
Chains MCMC chain (1000×23×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 25.09 seconds
Compute duration  = 99.61 seconds
parameters        = c_max, t_max, halflife, err, c2, c4, c6, c10, c12, c16, c24
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol    Float64   Float64    Float64   Float64     Float64   Float64       Float64 

       c_max   108.6282    1.9417     0.0307    0.0731    683.8163    1.0067        6.8649
       t_max     2.1786    0.0673     0.0011    0.0019   1278.5006    1.0022       12.8349
    halflife     6.8477    0.8309     0.0131    0.0409    412.5109    1.0126        4.1412
         err     1.1108    1.0635     0.0168    0.0827     48.9063    1.0935        0.4910
          c2    99.7609    3.3208     0.0525    0.0976   1299.0791    1.0029       13.0415
          c4    90.2151    2.0221     0.0320    0.0432   2331.7233    1.0008       23.4083
          c6    73.6348    1.9702     0.0312    0.0505   1622.1679    1.0020       16.2850
         c10    49.0056    2.5075     0.0396    0.0819    913.6478    1.0080        9.1722
         c12    39.9438    2.7244     0.0431    0.0925    823.6327    1.0085        8.2685
         c16    26.6574    2.6566     0.0420    0.0985    699.1499    1.0102        7.0188
         c24    11.8861    2.7374     0.0433    0.1261    485.6644    1.0110        4.8756

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5% 
      Symbol    Float64    Float64    Float64    Float64    Float64 

       c_max   104.3847   108.2221   108.8569   109.1745   112.0284
       t_max     2.0441     2.1585     2.1771     2.1943     2.3212
    halflife     6.0577     6.6845     6.7670     6.9133     7.8893
         err     0.2000     0.3568     0.7444     1.5416     3.8401
          c2    91.7489    98.9999   100.1175   100.9388   105.9018
          c4    85.8345    89.7698    90.3332    90.7347    94.1405
          c6    69.5013    73.0164    73.6187    74.1465    77.9393
         c10    44.0264    48.1991    49.0064    49.6336    54.9456
         c12    34.5891    39.2227    39.7270    40.6457    45.4985
         c16    21.5447    25.9203    26.3998    27.3007    32.2124
         c24     8.0231    11.1646    11.5504    12.3348    17.1209

julia> mean(chain[:numerical_error])
0.183

The columns to check here are ess and rhat. The first estimates how many truly independent draws would give an estimate of a parameter's mean with the same standard error as these non-independent draws; ~49 effective draws for err is far too low and indicates something is wrong. rhat is a convergence diagnostic that should be below 1.01 for every parameter, but several parameters exceed that threshold. In the final check, we see that 18% of transitions failed due to numerical error (usually divergences). So there are geometric issues preventing some regions of the posterior from being sampled. We can at least say MCMC didn't work well, and I wouldn't do much downstream analysis with these results except to try to figure out why sampling failed. Often sampling problems indicate problems with the model.
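The three rules of thumb above (ESS too low, rhat above 1.01, any divergent transitions) can be collected into a small checker. This is a hypothetical helper, not from the original thread; the thresholds are the ones stated above, and the ESS cutoff of 400 is an assumed rule of thumb:

```julia
# Hypothetical helper: flag the usual NUTS warning signs given
# per-parameter ESS and rhat vectors plus the per-transition
# numerical_error indicator (1.0 = transition failed).
using Statistics

function flag_sampling_problems(ess::AbstractVector, rhat::AbstractVector,
                                numerical_error::AbstractVector;
                                min_ess = 400, max_rhat = 1.01)
    problems = String[]
    any(ess .< min_ess) && push!(problems, "low ESS")
    any(rhat .> max_rhat) && push!(problems, "rhat above $(max_rhat)")
    div_frac = mean(numerical_error)
    div_frac > 0 &&
        push!(problems, "$(round(100div_frac; digits=1))% divergent transitions")
    return problems
end

# Using two of the numbers reported in the summary above (c_max and err):
flag_sampling_problems([683.8, 48.9], [1.0067, 1.0935],
                       vcat(ones(183), zeros(817)))
# → ["low ESS", "rhat above 1.01", "18.3% divergent transitions"]
```

All three flags fire on this chain, matching the diagnosis in the text.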
Sometimes we can increase the target acceptance rate (adapt delta) to a very large value and re-run sampling. This causes HMC to adapt a smaller step size and handle high curvature better:

julia> chain2=sample(model, NUTS(0.999), MCMCThreads(), 1_000, 4);

julia> mean(chain2[:numerical_error])
0.0025

Most of the divergences went away, but at such a high adapt delta we should see none at all, so I think it's worth looking into this further to see whether the model can be improved.
One thing we can check is whether divergences cluster in parameter space, which we can use ArviZ for:

julia> using ArviZ

julia> idata = convert_to_inference_data(chain)
InferenceData with groups:
  > posterior
  > sample_stats

julia> plot_pair(idata; var_names=[:c_max, :t_max, :halflife, :err], divergences=true)

This plot (attached) shows that divergent transitions occur when err is low. In fact, you may have a funnel geometry, which tends to pose problems for MCMC methods.
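The funnel pathology can be illustrated without Turing. In a funnel, a scale parameter (here playing the role of err) controls the width of another parameter, so when the scale is small the other parameter is squeezed into a narrow, highly curved region. A minimal sketch of the geometry and of the standard non-centered remedy (sample a standard normal and rescale it deterministically), using only the standard library:

```julia
# Toy funnel (not the model from the post): log_sigma plays the role
# of a scale parameter like err; x is squeezed when log_sigma is small.
using Random, Statistics

rng = MersenneTwister(1)
n = 10_000
log_sigma = randn(rng, n)                  # scale parameter, like err

# Centered parameterization: x ~ Normal(0, exp(log_sigma)).
# The width of x depends on log_sigma, producing the funnel.
x_centered = exp.(log_sigma) .* randn(rng, n)

# Non-centered parameterization: sample x_raw ~ Normal(0, 1) and
# rescale deterministically. The *sampled* variables (x_raw, log_sigma)
# are now independent, so HMC sees no funnel, yet x = exp(log_sigma) *
# x_raw has the same distribution as x_centered.
x_raw = randn(rng, n)
x = exp.(log_sigma) .* x_raw

cor(abs.(x_centered), log_sigma)   # clearly positive: width tracks scale
cor(x_raw, log_sigma)              # near zero: independent, easy for HMC
```

Whether a reparameterization like this applies to the dose model depends on how err enters the likelihood, which isn't shown here.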
I wasn't able to put more time into this, but if you come back to the model, I'd suggest simulating dose curves and observations from the prior, fixing err (sigma) to low and high values, to see whether it becomes clear why low err could be problematic.
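That prior-predictive check could look something like the sketch below. The actual model code isn't quoted in this issue, so the curve shape is an assumed stand-in consistent with the parameter names: a linear rise to c_max at t_max, then exponential decay with the given halflife. The parameter values are taken from the posterior means above:

```julia
# Assumed stand-in for the dose curve (the real model isn't shown here):
# linear rise to c_max at t_max, then exponential decay with `halflife`.
using Random, Statistics

dose_curve(t; c_max, t_max, halflife) =
    t <= t_max ? c_max * t / t_max :
                 c_max * 2.0^(-(t - t_max) / halflife)

# Observations at the measurement times, with Gaussian noise of sd `err`.
simulate_observations(rng, ts; c_max, t_max, halflife, err) =
    [dose_curve(t; c_max, t_max, halflife) + err * randn(rng) for t in ts]

rng = MersenneTwister(2)
ts = [2, 4, 6, 10, 12, 16, 24]
params = (c_max = 108.6, t_max = 2.18, halflife = 6.85)

low_err  = simulate_observations(rng, ts; params..., err = 0.2)  # ~2.5% prior quantile
high_err = simulate_observations(rng, ts; params..., err = 3.8)  # ~97.5% quantile

# With err near zero, the observations pin the curve almost exactly, so
# the posterior for err concentrates near 0 — the neck of the funnel
# visible in the pair plot.
```

Comparing how tightly the low-err and high-err simulations hug the curve makes it concrete why the posterior geometry degenerates as err shrinks.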

[Attachment: pair plots from ArviZ showing divergent transitions]
