Request for Special Priors #91

Closed
Gaurav17Joshi opened this issue Jun 18, 2023 · 6 comments
Labels
enhancement New feature or request

Comments

@Gaurav17Joshi

Is your feature request related to a problem? Please describe.
Hi, I am working on a project with the stingray library in which we are building a feature to evaluate astronomical time series using Gaussian processes, with jaxns for Bayesian inference and evidence sampling.

For this, we want to use some multivariate priors. We want to build a mean function as the sum of several functions (say Gaussians or exponentials).
[Image: screenshot of the mean function formula]
Here n is the number of Gaussians and $A_i, t_i, \sigma_i$ are their parameters.
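The screenshot is not reproduced here. Judging from the likelihood code posted later in this thread, the mean function is presumably a sum of n Gaussians of the following form (a reconstruction, not the original image):

$$m(t) = \sum_{i=0}^{n-1} A_i \exp\left(-\frac{(t - t_i)^2}{2\sigma_i^2}\right)$$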

We want to build two kinds of multivariate distributions for these mean parameters for our jaxns prior_model.

  1. Independent Multivariate Uniform Prior: for the parameters A and sigma, we have lower and upper bounds and want to sample n values of each uniformly.
  2. Constrained Multivariate Beta Prior: for the parameter $t_i$, we need a constrained prior that samples times such that $t_i < t_{i+1}$. E.g. for a 0 to 20 s lightcurve with n = 3, we have to sample the three max_times such that $0 < t_{0} < t_{1} < t_{2} < 20$. We have used the ForcedIdentifiability special prior from jaxns and it works well, but in the reference paper the authors use a constrained beta distribution. That is, $t_0$ follows a beta(alpha = 1, beta = n = 3) distribution from 0 to 20 s, $t_1$ follows a beta(1, n-1 = 2) from $t_0$ to 20 s, and $t_2$ follows a beta(1, n-2 = 1) from $t_1$ to 20 s, i.e. a conditional beta where alpha stays at 1 and beta decreases from n to 1. We want to implement that as well.
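The conditional beta construction in item 2 can be sketched in plain Python as follows. This is only an illustration using the standard library's `random.betavariate`; the function name `sample_ordered_times` is hypothetical and is not part of jaxns or stingray.

```python
import random


def sample_ordered_times(n, t_min, t_max, rng):
    """Draw t_0 < ... < t_{n-1} in (t_min, t_max) with conditional Beta(1, n - i) steps."""
    times = []
    lower = t_min  # left edge of the interval still available
    for i in range(n):
        # Fraction of the remaining interval [lower, t_max] consumed by t_i.
        frac = rng.betavariate(1.0, n - i)
        lower = lower + (t_max - lower) * frac
        times.append(lower)
    return times


rng = random.Random(0)
times = sample_ordered_times(3, 0.0, 20.0, rng)
print(times)
```

Each draw is shifted by the previous time and scaled by the interval that remains, so the result is increasing by construction.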

Describe the solution you'd like
It would be very helpful if these two special priors could be included in the jaxns special prior section.
I would like to help build these features if you can share some references on how these special priors are made.

Describe alternatives you've considered
I tried to build tfpd joint distributions for these, but they did not work in jaxns because tfpd joint distributions do not have a _quantile function.
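For context on why a quantile function matters: a quantile (inverse CDF) maps a uniform draw on the unit interval to a draw from the target distribution. For Beta(1, b) the CDF is $F(x) = 1 - (1 - x)^b$, so the quantile has a closed form. The sketch below uses it to map a point of the unit cube to ordered times. The helper names are hypothetical, and how such a transform would be wired into jaxns is not shown here.

```python
def beta1_quantile(u, b):
    """Inverse CDF of Beta(1, b), whose CDF is F(x) = 1 - (1 - x)**b."""
    return 1.0 - (1.0 - u) ** (1.0 / b)


def unit_cube_to_times(u, t_min, t_max):
    """Map u in [0, 1)^n to ordered times via conditional Beta(1, n - i) quantiles."""
    n = len(u)
    times = []
    lower = t_min
    for i, u_i in enumerate(u):
        # t_i lives on [lower, t_max]; its fractional position is Beta(1, n - i).
        lower = lower + (t_max - lower) * beta1_quantile(u_i, n - i)
        times.append(lower)
    return times


times = unit_cube_to_times([0.5, 0.5, 0.5], 0.0, 20.0)
print(times)
```

Because the map is deterministic and monotone in each coordinate, every point of the unit cube lands on a sorted sequence.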

@Gaurav17Joshi Gaurav17Joshi added the enhancement New feature or request label Jun 18, 2023
@Joshuaalbert
Owner

Joshuaalbert commented Jun 18, 2023

Hi @Gaurav17Joshi, let me check that I understand correctly. You want a multivariate uniform distribution for A_i and sigma_i, and for t_i a scaled and shifted Beta(1, n-i)? Note that t_i generated this way are not forced to be sorted, and so are not identifiable, so you will likely have degeneracies in your posterior. Also, the mean of the underlying Beta(1, n-i) is 1/(1 + n - i), which grows from 1/(n + 1) to 1/2 as i increases, so please check that this is the ordering you want.
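For reference, the unit-interval means of Beta(1, n - i) can be tabulated exactly with the standard formula alpha / (alpha + beta). This is a small illustrative sketch; scaling to [t_min, t_max] multiplies each mean by (t_max - t_min) and adds t_min.

```python
from fractions import Fraction


def beta_mean(alpha, beta):
    """Mean of a Beta(alpha, beta) distribution on the unit interval."""
    return Fraction(alpha, alpha + beta)


n = 3
means = [beta_mean(1, n - i) for i in range(n)]
for i, m in enumerate(means):
    print(f"t{i}: Beta(1, {n - i}) has mean {m}")
```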

You can do all this without special priors. Note the casting of constants with jnp.asarray(..., float_type); this is good practice if you'll be using 64-bit JAX, which I assume you are with GPs.

I'll assume some data for illustration.

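import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import random, vmap
# The jaxns import paths are assumed for the 2.x release current at the time; adjust for your version
from jaxns import Model, Prior, PriorModelGen
from jaxns.types import float_type

tfpd = tfp.distributions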

# Your hyperparameters for priors, when using 64-bit it's important to cast things appropriately
n = 3
t_max = jnp.asarray(20., float_type)
t_min = jnp.asarray(0., float_type)
A_lower = jnp.asarray(0., float_type)
A_upper = jnp.asarray(1., float_type)
sigma_lower = jnp.asarray(0., float_type)
sigma_upper = jnp.asarray(1., float_type)

# Make fake data
X = jnp.linspace(t_min, t_max, 5)
Y = jnp.exp(jnp.sin(X))

Then define the prior model,

def prior_model() -> PriorModelGen:
    A = yield Prior(tfpd.Uniform(low=A_lower * jnp.ones(n), high=A_upper * jnp.ones(n)), name='A')
    sigma = yield Prior(tfpd.Uniform(low=sigma_lower * jnp.ones(n), high=sigma_upper * jnp.ones(n)), name='sigma')
    t_array = []
    scale_bij = tfp.bijectors.Scale(scale=t_max - t_min)
    shift_bij = tfp.bijectors.Shift(shift=t_min)
    for i in range(n):
        underlying_beta = tfpd.Beta(
            concentration1=jnp.asarray(1., float_type),
            concentration0=jnp.asarray(n - i, float_type)
        )
        t = yield Prior(shift_bij(scale_bij(underlying_beta)), name=f"t{i}")
        t_array.append(t)
    t_array = jnp.stack(t_array)
    return A, sigma, t_array

Finally, define the likelihood and test the model.

def log_likelihood(A, sigma, t_array):
    @vmap
    def eval_mean(x):
        dx = (t_array - x) / sigma
        components = A * jnp.exp(-0.5 * jnp.square(dx))
        return jnp.sum(components)

    m_X = eval_mean(X)
    # Do something with m_X
    return -jnp.sum(jnp.square(Y - m_X))

model = Model(prior_model=prior_model,
              log_likelihood=log_likelihood)

model.sanity_check(random.PRNGKey(0), S=100)

# Example prior sample
print(model.transform(model.sample_U(random.PRNGKey(42))))
# {'A': Array([0.26283673, 0.10945365, 0.46926031], dtype=float64), 'sigma': Array([0.40959993, 0.08672034, 0.48140902], dtype=float64), 't0': Array(15.24431331, dtype=float64), 't1': Array(19.5986375, dtype=float64), 't2': Array(0.76363376, dtype=float64)}

Does this meet your need?

@Gaurav17Joshi
Author

Gaurav17Joshi commented Jun 19, 2023

Hi @Joshuaalbert, thanks for the prompt reply.
Your code resolves the first issue, for A and sigma. It fits my need and works in my code.

As for the second one, the Constrained Multivariate Beta Prior is indeed shifted and scaled, but it is also conditioned. The shift and scale are not the same for every component: for $t_i$, the shift is $t_{i-1}$ and the scale is $t_{max} - t_{i-1}$, rather than all components sharing one shift and scale.
This ensures that successive times are in increasing order, $t_{min} < t_0 < t_1 < \dots < t_{n-1} < t_{max}$.
I have written a tfpd joint distribution for n = 3 to make this clearer:

# Assumes tfd = tfp.distributions and tfb = tfp.bijectors
jointds = tfd.JointDistributionSequential([
    tfb.Shift(t_min)(tfb.Scale(t_max - t_min)(tfd.Beta(1, 3))),           # t_0
    lambda t_0: tfb.Shift(t_0)(tfb.Scale(t_max - t_0)(tfd.Beta(1, 2))),   # t_1
    lambda t_1: tfb.Shift(t_1)(tfb.Scale(t_max - t_1)(tfd.Beta(1, 1))),   # t_2
])

Avoiding degeneracy in the times is central to this project. As I said, the ForcedIdentifiability special prior does an excellent job of providing non-degenerate samples satisfying $t_{min} < t_0 < t_1 < \dots < t_{n-1} < t_{max}$, except that it gives us uniform samples. I just need a similar prior that samples them through a beta distribution.
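One point that may be worth checking before building a new prior: a standard result on uniform order statistics is that the minimum of n uniforms on [0, 1] follows Beta(1, n), and that, given the minimum, the remaining points are uniform on what is left of the interval. This suggests the conditional Beta(1, n - i) construction and sorted uniform draws share the same distribution. The sketch below is a rough Monte Carlo comparison of per-component means, not a proof. It assumes the uniform samples mentioned above are sorted independent uniforms, which is not verified against the jaxns source here; the helper names are hypothetical.

```python
import random


def conditional_beta_times(n, t_min, t_max, rng):
    """Ordered times from conditional Beta(1, n - i) steps."""
    times, lower = [], t_min
    for i in range(n):
        lower += (t_max - lower) * rng.betavariate(1.0, n - i)
        times.append(lower)
    return times


def sorted_uniform_times(n, t_min, t_max, rng):
    """Ordered times obtained by sorting n independent uniform draws."""
    return sorted(rng.uniform(t_min, t_max) for _ in range(n))


def component_means(draw, n, t_min, t_max, n_samples, rng):
    """Monte Carlo estimate of the mean of each ordered component."""
    totals = [0.0] * n
    for _ in range(n_samples):
        for k, t in enumerate(draw(n, t_min, t_max, rng)):
            totals[k] += t
    return [s / n_samples for s in totals]


rng = random.Random(1)
beta_means = component_means(conditional_beta_times, 3, 0.0, 20.0, 20000, rng)
unif_means = component_means(sorted_uniform_times, 3, 0.0, 20.0, 20000, rng)
print("conditional beta:", beta_means)
print("sorted uniform:  ", unif_means)
```

For n = 3 on [0, 20], the expected means of sorted uniforms are 5, 10 and 15, so both estimates should land close to those values. Comparing histograms of each component would be a stronger check than comparing means alone.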

An image for the constrained beta pdf:
[Image: screenshot of the constrained beta pdf]

@Joshuaalbert
Owner

Joshuaalbert commented Jun 19, 2023

Ah, I see. I misread your original post, which is why I pointed out the degeneracy. It's very easy to incorporate into the code if n is reasonably small.

    t_array = []
    scale_bij = tfp.bijectors.Scale(scale=t_max - t_min)
    shift_bij = tfp.bijectors.Shift(shift=t_min)
    for i in range(n):
        underlying_beta = tfpd.Beta(
            concentration1=jnp.asarray(1., float_type),
            concentration0=jnp.asarray(n - i, float_type)
        )
        t = yield Prior(shift_bij(scale_bij(underlying_beta)), name=f"t{i}")
        # Update the shift and scale here
        scale_bij = tfp.bijectors.Scale(scale=t_max - t)
        shift_bij = tfp.bijectors.Shift(shift=t)
        t_array.append(t)
    t_array = jnp.stack(t_array)

@Gaurav17Joshi
Author

Thanks, this is working well in my code. Just one more request: do you have any ideas on how one might write tests for such priors, and for the prior_model in general?

@Joshuaalbert
Owner

What aspect are you looking to test?

@Joshuaalbert
Owner

@Gaurav17Joshi I would love to help you further with testing priors, but can you please open a thread in the discussions? I'll close this given that we've addressed the initial issue.

https://github.com/Joshuaalbert/jaxns/discussions
