From 4036f28e3a87aacccbe68bbe19ccd8b332f6585b Mon Sep 17 00:00:00 2001 From: xuewc Date: Tue, 2 Apr 2024 21:03:03 +0800 Subject: [PATCH] fix: support unicode character --- src/elisa/infer/nested_sampling.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/elisa/infer/nested_sampling.py b/src/elisa/infer/nested_sampling.py index 24b93bc5..b1f796b6 100644 --- a/src/elisa/infer/nested_sampling.py +++ b/src/elisa/infer/nested_sampling.py @@ -221,27 +221,24 @@ def run(self, rng_key, *args, **kwargs): # Jaxns requires loglikelihood function to have explicit signatures. local_dict = {} - loglik_fn_def = """def loglik_fn({}):\n - \tparams = {{{}}}\n + loglik_fn_def = """def loglik_fn(params):\n + \tparams = {k + '_base': v for k, v in params.items()}\n \treturn log_density_(reparam_model, args, kwargs, params)[0] - """.format( - ", ".join([f"{name.replace('.', '_')}_base" for name in param_names]), - ", ".join([f"'{name}_base': {name.replace('.', '_')}_base" for name in param_names]), - ) + """ exec(loglik_fn_def, locals(), local_dict) loglik_fn = local_dict["loglik_fn"] # use NestedSampler with identity prior chain def prior_model(): - params = [] + params = {} for name in param_names: shape = prototype_trace[name]["fn"].shape() param = yield Prior( tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)), name=name + "_base", ) - params.append(param) - return tuple(params) + params[name] = param + return params model = Model(prior_model=prior_model, log_likelihood=loglik_fn)