Skip to content

Commit

Permalink
fix: support unicode character
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Apr 2, 2024
1 parent f213c52 commit 4036f28
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/elisa/infer/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,27 +221,24 @@ def run(self, rng_key, *args, **kwargs):

# Jaxns requires loglikelihood function to have explicit signatures.
local_dict = {}
loglik_fn_def = """def loglik_fn({}):\n
\tparams = {{{}}}\n
loglik_fn_def = """def loglik_fn(params):\n
\tparams = {k + '_base': v for k, v in params.items()}\n
\treturn log_density_(reparam_model, args, kwargs, params)[0]
""".format(
", ".join([f"{name.replace('.', '_')}_base" for name in param_names]),
", ".join([f"'{name}_base': {name.replace('.', '_')}_base" for name in param_names]),
)
"""
exec(loglik_fn_def, locals(), local_dict)
loglik_fn = local_dict["loglik_fn"]

# use NestedSampler with identity prior chain
def prior_model():
params = []
params = {}
for name in param_names:
shape = prototype_trace[name]["fn"].shape()
param = yield Prior(
tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)),
name=name + "_base",
)
params.append(param)
return tuple(params)
params[name] = param
return params

model = Model(prior_model=prior_model, log_likelihood=loglik_fn)

Expand Down

0 comments on commit 4036f28

Please sign in to comment.