From cc902124ed08fe8c3007699451cc0c16f38520d2 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Date: Mon, 3 Mar 2025 23:34:09 +0800 Subject: [PATCH] Allow compile_kwargs in sample_smc (#7702) --- pymc/smc/kernels.py | 21 +++++++++++++++------ pymc/smc/sampling.py | 18 ++++++++++++++---- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index a5c86b5609..567b73f514 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -134,6 +134,7 @@ def __init__( model=None, random_seed=None, threshold=0.5, + compile_kwargs: dict | None = None, ): """ Initialize the SMC_kernel class. @@ -154,6 +155,8 @@ def __init__( Determines the change of beta from stage to stage, i.e.indirectly the number of stages, the higher the value of `threshold` the higher the number of stages. Defaults to 0.5. It should be between 0 and 1. + compile_kwargs: dict, optional + Keyword arguments passed to pytensor.function Attributes ---------- @@ -172,8 +175,8 @@ def __init__( self.model = modelcontext(model) self.variables = self.model.value_vars - self.var_info = {} - self.tempered_posterior = None + self.var_info: dict[str, tuple] = {} + self.tempered_posterior: np.ndarray self.prior_logp = None self.likelihood_logp = None self.tempered_posterior_logp = None @@ -184,6 +187,7 @@ def __init__( self.iteration = 0 self.resampling_indexes = None self.weights = np.ones(self.draws) / self.draws + self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {} def initialize_population(self) -> dict[str, np.ndarray]: """Create an initial population from the prior distribution.""" @@ -239,10 +243,10 @@ def _initialize_kernel(self): shared = make_shared_replacements(initial_point, self.variables, self.model) self.prior_logp_func = _logp_forw( - initial_point, [self.model.varlogp], self.variables, shared + initial_point, [self.model.varlogp], self.variables, shared, self.compile_kwargs ) 
self.likelihood_logp_func = _logp_forw( - initial_point, [self.model.datalogp], self.variables, shared + initial_point, [self.model.datalogp], self.variables, shared, self.compile_kwargs ) priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior] @@ -606,7 +610,7 @@ def systematic_resampling(weights, rng): return new_indices -def _logp_forw(point, out_vars, in_vars, shared): +def _logp_forw(point, out_vars, in_vars, shared, compile_kwargs=None): """Compile PyTensor function of the model and the input and output variables. Parameters @@ -617,7 +621,12 @@ def _logp_forw(point, out_vars, in_vars, shared): Containing Distribution for the input variables shared : list Containing TensorVariable for depended shared data + compile_kwargs: dict, optional + Additional keyword arguments passed to pytensor.function """ + if compile_kwargs is None: + compile_kwargs = {} + # Replace integer inputs with rounded float inputs if any(var.dtype in discrete_types for var in in_vars): replace_int_input = {} @@ -636,6 +645,6 @@ def _logp_forw(point, out_vars, in_vars, shared): out_list, inarray0 = join_nonshared_inputs( point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared ) - f = compile([inarray0], out_list[0]) + f = compile([inarray0], out_list[0], **compile_kwargs) f.trust_input = True return f diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index a4e8248814..f3176f464b 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -58,6 +58,7 @@ def sample_smc( return_inferencedata=True, idata_kwargs=None, progressbar=True, + compile_kwargs: dict | None = None, **kernel_kwargs, ) -> InferenceData | MultiTrace: r""" @@ -95,6 +96,9 @@ def sample_smc( Keyword arguments for :func:`pymc.to_inference_data`. progressbar : bool, optional, default True Whether or not to display a progress bar in the command line. 
+ compile_kwargs : dict, optional + Keyword arguments to pass to pytensor.function + **kernel_kwargs : dict, optional Keyword arguments passed to the SMC_kernel. The default IMH kernel takes the following keywords: @@ -102,10 +106,11 @@ def sample_smc( Determines the change of beta from stage to stage, i.e. indirectly the number of stages, the higher the value of `threshold` the higher the number of stages. Defaults to 0.5. It should be between 0 and 1. - correlation_threshold : float, default 0.01 - The lower the value the higher the number of MCMC steps computed automatically. - Defaults to 0.01. It should be between 0 and 1. - Keyword arguments for other kernels should be checked in the respective docstrings. + correlation_threshold : float, default 0.01 + The lower the value the higher the number of MCMC steps computed automatically. + Defaults to 0.01. It should be between 0 and 1. + + Additional keyword arguments for other kernels should be checked in the respective docstrings. Notes ----- @@ -160,6 +165,11 @@ def sample_smc( else: cores = min(chains, cores) + if compile_kwargs is None: + compile_kwargs = {} + + kernel_kwargs["compile_kwargs"] = compile_kwargs + random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains) model = modelcontext(model)