Request for Special Priors #91
Hi @Gaurav17Joshi, let me see if I understand correctly. You want to use a multivariate uniform distribution for A_i and sigma_i, and for t_i you want to use a scaled and shifted Beta(1, n-i)? Note that t_i generated this way are not forced to be sorted, and thus are not identifiable, so you will likely have degeneracies in your posterior. Also, the mean of t_i by the above definition would be 1/(1 + n - i), which goes from large to small as i increases, and seems opposite from what you want.

You can do all this without special priors. Note the casting of constants to

I'll assume some data for illustration:

```python
# Added imports (assumed; module paths may differ between jaxns versions)
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import random, vmap
from jaxns import Model, Prior, PriorModelGen
from jaxns.types import float_type

tfpd = tfp.distributions

# Your hyperparameters for priors; when using 64-bit it's important to cast things appropriately
n = 3
t_max = jnp.asarray(20., float_type)
t_min = jnp.asarray(0., float_type)
A_lower = jnp.asarray(0., float_type)
A_upper = jnp.asarray(1., float_type)
sigma_lower = jnp.asarray(0., float_type)
sigma_upper = jnp.asarray(1., float_type)

# Make fake data
X = jnp.linspace(t_min, t_max, 5)
Y = jnp.exp(jnp.sin(X))
```

Then define the prior model:

```python
def prior_model() -> PriorModelGen:
    A = yield Prior(tfpd.Uniform(low=A_lower * jnp.ones(n), high=A_upper * jnp.ones(n)), name='A')
    sigma = yield Prior(tfpd.Uniform(low=sigma_lower * jnp.ones(n), high=sigma_upper * jnp.ones(n)), name='sigma')
    t_array = []
    # Map Beta samples on [0, 1] onto [t_min, t_max]
    scale_bij = tfp.bijectors.Scale(scale=t_max - t_min)
    shift_bij = tfp.bijectors.Shift(shift=t_min)
    for i in range(n):
        underlying_beta = tfpd.Beta(
            concentration1=jnp.asarray(1., float_type),
            concentration0=jnp.asarray(n - i, float_type)
        )
        # Applying a bijector to a distribution gives a TransformedDistribution: t = t_min + (t_max - t_min) * u
        t = yield Prior(shift_bij(scale_bij(underlying_beta)), name=f"t{i}")
        t_array.append(t)
    t_array = jnp.stack(t_array)
    return A, sigma, t_array
```

Finally, the likelihood, and let's test the model:

```python
def log_likelihood(A, sigma, t_array):
    @vmap
    def eval_mean(x):
        # Sum of the n Gaussian components evaluated at a single point x
        dx = (t_array - x) / sigma
        components = A * jnp.exp(-0.5 * jnp.square(dx))
        return jnp.sum(components)

    m_X = eval_mean(X)
    # Do something with m_X
    return -jnp.sum(jnp.square(Y - m_X))


model = Model(prior_model=prior_model,
              log_likelihood=log_likelihood)

model.sanity_check(random.PRNGKey(0), S=100)

# Example prior sample
print(model.transform(model.sample_U(random.PRNGKey(42))))
# {'A': Array([0.26283673, 0.10945365, 0.46926031], dtype=float64), 'sigma': Array([0.40959993, 0.08672034, 0.48140902], dtype=float64), 't0': Array(15.24431331, dtype=float64), 't1': Array(19.5986375, dtype=float64), 't2': Array(0.76363376, dtype=float64)}
```

Does this meet your needs?
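To see the degeneracy concretely, here is a quick NumPy check of how often independently drawn t_i come out in increasing order. It only mimics the prior above outside of jaxns (the variable names are illustrative), so treat it as a sketch: any fraction below 1 means the labels t0, t1, t2 can swap between posterior modes, since the sum of components is unchanged when (A_i, t_i, sigma_i) triples are permuted.

```python
import numpy as np

# Independent (unconditioned) draws, as in the prior above:
# t_i = t_min + (t_max - t_min) * Beta(1, n - i), for i = 0, ..., n - 1
rng = np.random.default_rng(0)
n, t_min, t_max, size = 3, 0.0, 20.0, 100_000
times = np.stack(
    [t_min + (t_max - t_min) * rng.beta(1.0, n - i, size=size) for i in range(n)],
    axis=1,
)

# Fraction of prior draws that happen to come out strictly increasing.
frac_sorted = np.mean(np.all(np.diff(times, axis=1) > 0, axis=1))
print(frac_sorted)
```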
Hi @Joshuaalbert, thanks for the prompt reply. As for the second one, the Constrained Multivariate Beta Prior is indeed shifted and scaled, but it is also conditioned. The shift and scale factors are not the same for all of them; for
Avoiding degeneracy in the times is central to this project. As I said, the ForcedIdentifiability special prior does an excellent job of providing non-degenerate samples following
Ah, I see. I misread your original post, which is why I pointed out the degeneracy. But it's very easy to incorporate into the code if

```python
# Replaces the t_i loop inside prior_model
t_array = []
# The first time t_0 lives on [t_min, t_max]
scale_bij = tfp.bijectors.Scale(scale=t_max - t_min)
shift_bij = tfp.bijectors.Shift(shift=t_min)
for i in range(n):
    underlying_beta = tfpd.Beta(
        concentration1=jnp.asarray(1., float_type),
        concentration0=jnp.asarray(n - i, float_type)
    )
    t = yield Prior(shift_bij(scale_bij(underlying_beta)), name=f"t{i}")
    # Update the shift and scale here, so the next time lies in [t, t_max]
    scale_bij = tfp.bijectors.Scale(scale=t_max - t)
    shift_bij = tfp.bijectors.Shift(shift=t)
    t_array.append(t)
t_array = jnp.stack(t_array)
```
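One way to convince yourself that the updated loop yields ordered times is to mimic the same bijector chain in NumPy. The function `sample_sorted_times` is hypothetical and samples directly rather than going through jaxns. Because Beta(1, m) is the distribution of the minimum of m uniform draws on [0, 1], the chain reproduces the joint distribution of n sorted uniform draws on [t_min, t_max], so the k-th time (k = 1, ..., n) should have mean t_min + (t_max - t_min) k / (n + 1).

```python
import numpy as np

def sample_sorted_times(rng, n, t_min, t_max, size):
    """Draw `size` sets of n ordered times via the conditional Beta chain (hypothetical helper).

    t_0 = t_min + (t_max - t_min) * Beta(1, n), then
    t_i = t_{i-1} + (t_max - t_{i-1}) * Beta(1, n - i).
    """
    lower = np.full(size, t_min, dtype=float)
    times = np.empty((size, n))
    for i in range(n):
        u = rng.beta(1.0, n - i, size=size)  # Beta(1, n - i) on [0, 1]
        t = lower + (t_max - lower) * u      # scale and shift into [lower, t_max]
        times[:, i] = t
        lower = t                            # condition the next time on this one
    return times

rng = np.random.default_rng(0)
n, t_min, t_max = 3, 0.0, 20.0
times = sample_sorted_times(rng, n, t_min, t_max, size=200_000)

# Every draw is non-decreasing, so the labels t0 <= t1 <= t2 stay identifiable.
print(np.all(np.diff(times, axis=1) >= 0))
# Empirical means vs. the order-statistic means t_min + (t_max - t_min) * k / (n + 1)
print(times.mean(axis=0))
print(t_min + (t_max - t_min) * np.arange(1, n + 1) / (n + 1))
```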
Thanks, this is working well in my code. Just one more request: do you have any ideas on how one might write tests for such priors, and for the prior_model in general?
What aspect are you looking to test?
@Gaurav17Joshi I would love to help you further with testing priors, but could you please open a thread in the Discussions? I'll close this, given that we've addressed the initial issue.
Is your feature request related to a problem? Please describe.
Hi, I am working on a project with the stingray library, in which we are building a feature to model astronomical time series using Gaussian processes, with jaxns for Bayesian inference and evidence estimation.
For this, we want to use some multivariate priors. We want to create a mean function as the sum of multiple functions (say, Gaussians or exponentials).
$A_i, t_i, \sigma_i$ are its parameters.
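Assuming Gaussian components, the mean function built from these parameters would read as follows; this is the same sum that the `log_likelihood` in the reply evaluates (an exponential variant would swap out the kernel):

$$
m(x) = \sum_{i=1}^{n} A_i \exp\!\left(-\frac{(x - t_i)^2}{2\sigma_i^2}\right)
$$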
Here n is the number of Gaussians and
We want to build two kinds of multivariate distributions for these mean parameters for our jaxns prior_model.
Describe the solution you'd like
It would be very helpful if these two special priors could be included in the jaxns special priors section.
I would like to help build these features if you can share some references on how these special priors are implemented.
Describe alternatives you've considered
I tried to build these as tfpd joint distributions, but they did not work in jaxns, because tfpd joint distributions do not have a `_quantile` function.