Skip to content

Commit

Permalink
Allow compile_kwargs in sample_smc (#7702)
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski authored Mar 3, 2025
1 parent 3ccff92 commit cc90212
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
21 changes: 15 additions & 6 deletions pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(
model=None,
random_seed=None,
threshold=0.5,
compile_kwargs: dict | None = None,
):
"""
Initialize the SMC_kernel class.
Expand All @@ -154,6 +155,8 @@ def __init__(
Determines the change of beta from stage to stage, i.e.indirectly the number of stages,
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
It should be between 0 and 1.
compile_kwargs: dict, optional
Keyword arguments passed to pytensor.function
Attributes
----------
Expand All @@ -172,8 +175,8 @@ def __init__(
self.model = modelcontext(model)
self.variables = self.model.value_vars

self.var_info = {}
self.tempered_posterior = None
self.var_info: dict[str, tuple] = {}
self.tempered_posterior: np.ndarray
self.prior_logp = None
self.likelihood_logp = None
self.tempered_posterior_logp = None
Expand All @@ -184,6 +187,7 @@ def __init__(
self.iteration = 0
self.resampling_indexes = None
self.weights = np.ones(self.draws) / self.draws
self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {}

def initialize_population(self) -> dict[str, np.ndarray]:
"""Create an initial population from the prior distribution."""
Expand Down Expand Up @@ -239,10 +243,10 @@ def _initialize_kernel(self):
shared = make_shared_replacements(initial_point, self.variables, self.model)

self.prior_logp_func = _logp_forw(
initial_point, [self.model.varlogp], self.variables, shared
initial_point, [self.model.varlogp], self.variables, shared, self.compile_kwargs
)
self.likelihood_logp_func = _logp_forw(
initial_point, [self.model.datalogp], self.variables, shared
initial_point, [self.model.datalogp], self.variables, shared, self.compile_kwargs
)

priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior]
Expand Down Expand Up @@ -606,7 +610,7 @@ def systematic_resampling(weights, rng):
return new_indices


def _logp_forw(point, out_vars, in_vars, shared):
def _logp_forw(point, out_vars, in_vars, shared, compile_kwargs=None):
"""Compile PyTensor function of the model and the input and output variables.
Parameters
Expand All @@ -617,7 +621,12 @@ def _logp_forw(point, out_vars, in_vars, shared):
Containing Distribution for the input variables
shared : list
Containing TensorVariable for depended shared data
compile_kwargs: dict, optional
Additional keyword arguments passed to pytensor.function
"""
if compile_kwargs is None:
compile_kwargs = {}

# Replace integer inputs with rounded float inputs
if any(var.dtype in discrete_types for var in in_vars):
replace_int_input = {}
Expand All @@ -636,6 +645,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
out_list, inarray0 = join_nonshared_inputs(
point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
)
f = compile([inarray0], out_list[0])
f = compile([inarray0], out_list[0], **compile_kwargs)
f.trust_input = True
return f
18 changes: 14 additions & 4 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def sample_smc(
return_inferencedata=True,
idata_kwargs=None,
progressbar=True,
compile_kwargs: dict | None = None,
**kernel_kwargs,
) -> InferenceData | MultiTrace:
r"""
Expand Down Expand Up @@ -95,17 +96,21 @@ def sample_smc(
Keyword arguments for :func:`pymc.to_inference_data`.
progressbar : bool, optional, default True
Whether or not to display a progress bar in the command line.
compile_kwargs: dict, optional
Keyword arguments to pass to pytensor.function
**kernel_kwargs : dict, optional
Keyword arguments passed to the SMC_kernel. The default IMH kernel takes the following keywords:
threshold : float, default 0.5
Determines the change of beta from stage to stage, i.e. indirectly the number of stages,
the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
It should be between 0 and 1.
correlation_threshold : float, default 0.01
The lower the value the higher the number of MCMC steps computed automatically.
Defaults to 0.01. It should be between 0 and 1.
Keyword arguments for other kernels should be checked in the respective docstrings.
correlation_threshold : float, default 0.01
The lower the value the higher the number of MCMC steps computed automatically.
Defaults to 0.01. It should be between 0 and 1.
Additional keyword arguments for other kernels should be checked in the respective docstrings.
Notes
-----
Expand Down Expand Up @@ -160,6 +165,11 @@ def sample_smc(
else:
cores = min(chains, cores)

if compile_kwargs is None:
compile_kwargs = {}

kernel_kwargs["compile_kwargs"] = compile_kwargs

random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)

model = modelcontext(model)
Expand Down

0 comments on commit cc90212

Please sign in to comment.