Skip to content

Commit

Permalink
Add Bernoulli-gamma likelihood
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jun 10, 2024
1 parent d01e487 commit 51da293
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 20 deletions.
6 changes: 3 additions & 3 deletions neuralprocesses/architectures/agnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def construct_agnp(*args, nps=nps, num_heads=8, **kw_args):
width (int, optional): Widths of all intermediate MLPs. Defaults to 512.
nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank"`.
Defaults to `"lowrank"`.
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
num_basis_functions (int, optional): Number of basis functions for the
low-rank likelihood. Defaults to 512.
dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of
`"het"`, `"dense"`, or `"spikes-beta"`. Defaults to `"het"`.
`"het"` or `"dense"`. Defaults to `"het"`.
transform (str or tuple[float, float]): Bijection applied to the
output of the model. This can help deal with positive or bounded data.
Must be either `"positive"`, `"exp"`, `"softplus"`, or
Expand Down
4 changes: 2 additions & 2 deletions neuralprocesses/architectures/climate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def construct_climate_convgnp_mlp(
to 128.
lr_deg (float, optional): Resolution of the low-resolution grid. Defaults to
0.75.
likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank".
Defaults to `"lowrank"`.
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
dtype (dtype, optional): Data type.
"""
mlp_width = 128
Expand Down
2 changes: 1 addition & 1 deletion neuralprocesses/architectures/convgnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def construct_convgnp(
margin (float, optional): Margin of the internal discretisation. Defaults to
0.1.
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
or `"spikes-beta"`. Defaults to `"lowrank"`.
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
conv_arch (str, optional): Convolutional architecture to use. Must be one of
`"unet[-res][-sep]"` or `"conv[-res][-sep]"`. Defaults to `"unet"`.
unet_channels (tuple[int], optional): Channels of every layer of the UNet.
Expand Down
6 changes: 3 additions & 3 deletions neuralprocesses/architectures/gnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ def construct_gnp(
width (int, optional): Widths of all intermediate MLPs. Defaults to 512.
nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank"`.
Defaults to `"lowrank"`.
likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`,
`"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
num_basis_functions (int, optional): Number of basis functions for the
low-rank likelihood. Defaults to 512.
dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0.
lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of
`"het"`, `"dense"`, or `"spikes-beta"`. Defaults to `"het"`.
`"het"` or `"dense"`. Defaults to `"het"`.
transform (str or tuple[float, float]): Bijection applied to the
output of the model. This can help deal with positive or bounded data.
Must be either `"positive"`, `"exp"`, `"softplus"`, or
Expand Down
8 changes: 6 additions & 2 deletions neuralprocesses/architectures/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype):
Args:
nps (module): Appropriate backend-specific module.
spec (str, optional): Specification. Must be one of `"het"`, `"lowrank"`,
`"dense"`, or `"spikes-beta"`. Defaults to `"lowrank"`. Must be given as
a keyword argument.
`"dense"`, `"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`.
Must be given as a keyword argument.
dim_y (int): Dimensionality of the outputs. Must be given as a keyword argument.
num_basis_functions (int): Number of basis functions for the low-rank
likelihood. Must be given as a keyword argument.
Expand Down Expand Up @@ -52,6 +52,10 @@ def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype):
num_channels = (2 + 3) * dim_y # Alpha, beta, and three log-probabilities
selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y, dim_y)
lik = nps.SpikesBetaLikelihood()
elif spec == "bernoulli-gamma":
num_channels = (2 + 2) * dim_y # Shape, scale, and two log-probabilities
selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y)
lik = nps.BernoulliGammaLikelihood()

else:
raise ValueError(f'Incorrect likelihood specification "{spec}".')
Expand Down
1 change: 1 addition & 0 deletions neuralprocesses/dist/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .beta import *
from .dirac import *
from .dist import *
from .gamma import *
from .geom import *
from .normal import *
from .spikeslab import *
Expand Down
97 changes: 95 additions & 2 deletions neuralprocesses/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import _dispatch
from .aggregate import Aggregate, AggregateInput
from .datadims import data_dims
from .dist import Beta, Dirac, MultiOutputNormal, SpikesSlab
from .dist import Beta, Dirac, Gamma, MultiOutputNormal, SpikesSlab
from .parallel import Parallel
from .util import register_module, split, split_dimension

Expand All @@ -16,6 +16,8 @@
"HeterogeneousGaussianLikelihood",
"LowRankGaussianLikelihood",
"DenseGaussianLikelihood",
"SpikesBetaLikelihood",
"BernoulliGammaLikelihood",
]


Expand Down Expand Up @@ -359,7 +361,7 @@ def _dense_var(coder: DenseGaussianLikelihood, xz, z: B.Numeric):

@register_module
class SpikesBetaLikelihood(AbstractLikelihood):
"""Gaussian likelihood with heterogeneous noise.
"""Mixture of a beta distribution, a Dirac delta at zero, and a Dirac delta at one.
Args:
epsilon (float, optional): Tolerance for equality checking. Defaults to `1e-6`.
Expand Down Expand Up @@ -451,3 +453,94 @@ def _spikesbeta(coder: SpikesBetaLikelihood, xz, z: B.Numeric):
logps = z_logps

return alpha, beta, logp0, logp1, logps, d + 1


@register_module
class BernoulliGammaLikelihood(AbstractLikelihood):
    """Mixture of a gamma distribution and a Dirac delta at zero.

    Args:
        epsilon (float, optional): Tolerance for equality checking. Defaults to `1e-6`.

    Attributes:
        epsilon (float): Tolerance for equality checking.
    """

    @_dispatch
    def __init__(self, epsilon: float = 1e-6):
        self.epsilon = epsilon

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"BernoulliGammaLikelihood(epsilon={self.epsilon!r})"


@_dispatch
def code(
    coder: BernoulliGammaLikelihood,
    xz,
    z,
    x,
    *,
    dtype_lik=None,
    **kw_args,
):
    """Produce a Bernoulli-gamma mixture distribution from the encoding `z`.

    Args:
        coder (:class:`.BernoulliGammaLikelihood`): Coder.
        xz (input): Current inputs.
        z (tensor): Raw parameters extracted from the encoding.
        x (input): Target inputs.
        dtype_lik (dtype, optional): Data type to cast the parameters to.

    Returns:
        tuple[input, :class:`..dist.SpikesSlab`]: Inputs and distribution.
    """
    shape, scale, logp0, logps, d = _bernoulligamma(coder, xz, z)

    # If requested, cast every parameter to the likelihood's data type.
    if dtype_lik:
        shape, scale, logp0, logps = (
            B.cast(dtype_lik, p) for p in (shape, scale, logp0, logps)
        )

    # The only spike is the Dirac delta at zero.
    with B.on_device(z):
        spikes = B.stack(B.zero(dtype_lik or B.dtype(z)))

    slab = Gamma(shape, scale, d)
    log_weights = B.stack(logp0, logps, axis=-1)
    return xz, SpikesSlab(spikes, slab, log_weights, d, epsilon=coder.epsilon)


@_dispatch
def _bernoulligamma(
    coder: BernoulliGammaLikelihood,
    xz: AggregateInput,
    z: Aggregate,
):
    """Compute the parameters for every output and aggregate them into one
    big distribution."""
    per_output = [_bernoulligamma(coder, xzi, zi) for (xzi, _), zi in zip(xz, z)]
    # Transpose the list of per-output tuples into one aggregate per parameter.
    k, scale, logp0, logps, d = (Aggregate(*vals) for vals in zip(*per_output))
    return k, scale, logp0, logps, d


@_dispatch
def _bernoulligamma(coder: BernoulliGammaLikelihood, xz, z: B.Numeric):
    """Split the channels of `z` into gamma shape, gamma scale, and the two
    mixture log-probabilities."""
    d = data_dims(xz)
    channel_axis = -d - 1
    # Four parameter groups share the channel dimension equally.
    dim_y = B.shape(z, channel_axis) // 4

    z_k, z_scale, logp0, logps = split(z, (dim_y,) * 4, channel_axis)

    # Map the raw shape and scale to strictly positive values.
    k = B.softplus(z_k) + 1e-3
    scale = B.softplus(z_scale) + 1e-3

    return k, scale, logp0, logps, d + 1
14 changes: 7 additions & 7 deletions tests/test_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def product_kw_args(config, **kw_args):
},
dim_x=[1, 2],
dim_y=[1, 2],
likelihood=["het", "lowrank", "spikes-beta"],
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
)
# NP:
+ product_kw_args(
Expand All @@ -79,7 +79,7 @@ def product_kw_args(config, **kw_args):
},
dim_x=[1, 2],
dim_y=[1, 2],
likelihood=["het", "lowrank", "spikes-beta"],
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
lv_likelihood=["het", "dense"],
)
# ACNP:
Expand All @@ -94,7 +94,7 @@ def product_kw_args(config, **kw_args):
},
dim_x=[1, 2],
dim_y=[1, 2],
likelihood=["het", "lowrank", "spikes-beta"],
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
)
# ANP:
+ product_kw_args(
Expand All @@ -108,7 +108,7 @@ def product_kw_args(config, **kw_args):
},
dim_x=[1, 2],
dim_y=[1, 2],
likelihood=["het", "lowrank", "spikes-beta"],
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
lv_likelihood=["het", "dense"],
)
# ConvCNP and ConvGNP:
Expand All @@ -122,7 +122,7 @@ def product_kw_args(config, **kw_args):
},
dim_x=[1, 2],
dim_y=[1, 2],
likelihood=["het", "lowrank", "spikes-beta"],
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
encoder_scales_learnable=[True, False],
decoder_scale_learnable=[True, False],
)
Expand All @@ -138,7 +138,7 @@ def product_kw_args(config, **kw_args):
},
dim_x=[1, 2],
dim_y=[1, 2],
likelihood=["het", "lowrank", "spikes-beta"],
likelihood=["het", "lowrank", "spikes-beta", "bernoulli-gamma"],
lv_likelihood=["het", "lowrank"],
)
)
Expand Down Expand Up @@ -219,7 +219,7 @@ def construct_model():

def sample():
if "likelihood" in config:
binary = config["likelihood"] == "spikes-beta"
binary = config["likelihood"] in {"spikes-beta", "bernoulli-gamma"}
else:
binary = False
return generate_data(
Expand Down

0 comments on commit 51da293

Please sign in to comment.