Skip to content

Commit

Permalink
Merge pull request #13 from mj-will/sample-prior
Browse files Browse the repository at this point in the history
ENH: add option to start by sampling from the prior
  • Loading branch information
mj-will authored Mar 6, 2024
2 parents 7f10bd6 + 74289c8 commit e57355f
Showing 1 changed file with 47 additions and 35 deletions.
82 changes: 47 additions & 35 deletions src/nessai_torch/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
dims: int,
nlive: int = 1000,
tolerance: float = 0.1,
sample_prior_iterations: int = 2000,
outdir: Optional[str] = None,
save: bool = True,
parameter_labels: Optional[list[str]] = None,
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
self.seed = seed

self.reset_flow = int(reset_flow)
self.sample_prior_iterations = sample_prior_iterations
self.populate_count = 0
self.n_likelihood_calls = 0
self.sampling_time = 0
Expand Down Expand Up @@ -199,6 +201,46 @@ def insert_live_point(self, x, logl):
self.logl[index - 1] = logl
return index - 1

def update_proposal(self) -> None:
"""Update the proposal.
This includes training and populating the proposal.
"""
if self.proposal.has_pool:
if not self.proposal.populated:
if self.proposal.trainable:
with torch.enable_grad():
self.proposal.train(
self.live_points,
self.logl,
reset=self.should_reset,
)
proposal_acceptance = self.proposal.populate(
self.live_points,
self.logl,
)
self.history["proposal_acceptance"].append(
(self.iteration, proposal_acceptance)
)
self.proposal.compute_likelihoods()
if self.plot_pool:
plot_samples_1d(
self.live_points,
self.proposal.samples,
labels=["live points", "pool"],
parameter_labels=self.parameter_labels,
filename=os.path.join(
self.outdir, f"pool_it_{self.iteration}.png"
),
)
self.populate_count += 1
elif self.proposal.trainable:
with torch.enable_grad():
self.proposal.train(
self.live_points,
self.logl,
)

def step(self) -> None:
"""Perform one nested sampling step"""
self.logl_min = self.logl[0].clone()
Expand All @@ -207,41 +249,11 @@ def step(self) -> None:
self._logl_nested_samples.append(self.logl[0].detach().clone())
count = 0
while True:
if self.proposal.has_pool:
if not self.proposal.populated:
if self.proposal.trainable:
with torch.enable_grad():
self.proposal.train(
self.live_points,
self.logl,
reset=self.should_reset,
)
proposal_acceptance = self.proposal.populate(
self.live_points,
self.logl,
)
self.history["proposal_acceptance"].append(
(self.iteration, proposal_acceptance)
)
self.proposal.compute_likelihoods()
if self.plot_pool:
plot_samples_1d(
self.live_points,
self.proposal.samples,
labels=["live points", "pool"],
parameter_labels=self.parameter_labels,
filename=os.path.join(
self.outdir, f"pool_it_{self.iteration}.png"
),
)
self.populate_count += 1
elif self.proposal.trainable:
with torch.enable_grad():
self.proposal.train(
self.live_points,
self.logl,
)
x, logl = self.proposal.draw(self.live_points[0])
if self.iteration < self.sample_prior_iterations:
x, logl = self.prior_proposal.draw(self.live_points[0])
else:
self.update_proposal()
x, logl = self.proposal.draw(self.live_points[0])
if logl is None:
logl = self.log_likelihood_unit_hypercube(x)
count += 1
Expand Down

0 comments on commit e57355f

Please sign in to comment.