1d and 2d cnn embedding nets. #751

Merged: 3 commits, Nov 3, 2022
190 changes: 134 additions & 56 deletions sbi/neural_nets/embedding_nets.py
@@ -1,6 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from typing import List, Tuple, Union

import torch
from torch import Tensor, nn

@@ -44,81 +46,157 @@ def forward(self, x: Tensor) -> Tensor:
return self.net(x)
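
For orientation, a minimal usage sketch of FCEmbedding (the keyword names are taken from the linear_subnet construction added further down in this diff; the concrete values are illustrative, not from the PR):

import torch

fc = FCEmbedding(input_dim=72, output_dim=20, num_layers=2, num_hiddens=50)
print(fc(torch.randn(10, 72)).shape)  # torch.Size([10, 20])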


def calculate_filter_output_size(input_size, padding, dilation, kernel, stride) -> int:
"""Returns output size of a filter given filter arguments.

Uses formulas from https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.
"""

return int(
(int(input_size) + 2 * int(padding) - int(dilation) * (int(kernel) - 1) - 1)
/ int(stride)
+ 1
)
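
A quick sanity check, not part of this diff: the helper should agree with the size PyTorch itself produces for a concrete layer (the hyperparameters below are arbitrary):

import torch
from torch import nn

conv = nn.Conv1d(1, 1, kernel_size=5, stride=2, padding=1, dilation=1)
out = conv(torch.zeros(1, 1, 32))
# int((32 + 2*1 - 1*(5 - 1) - 1) / 2 + 1) = 15
assert out.shape[-1] == calculate_filter_output_size(32, 1, 1, 5, 2)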


def get_new_cnn_output_size(
input_shape: Tuple,
conv_layer: Union[nn.Conv1d, nn.Conv2d],
pool: Union[nn.MaxPool1d, nn.MaxPool2d],
) -> Union[Tuple[int], Tuple[int, int]]:
"""Returns new output size after applying a given convolution and pooling.

Args:
input_shape: spatial shape of the input (without the channel dimension).
conv_layer: the convolutional layer being applied.
pool: the pooling layer being applied.

Returns:
New output size after applying the convolution and pooling.

"""
assert isinstance(input_shape, Tuple), "input shape must be Tuple."
assert 0 < len(input_shape) < 3, "input shape must be 1 or 2d."
assert isinstance(conv_layer.padding, Tuple), "conv layer attributes must be Tuple."
assert isinstance(pool.padding, int), "pool layer attributes must be integers."

out_after_conv = [
calculate_filter_output_size(
input_shape[i],
conv_layer.padding[i],
conv_layer.dilation[i],
conv_layer.kernel_size[i],
conv_layer.stride[i],
)
for i in range(len(input_shape))
]
out_after_pool = [
calculate_filter_output_size(
out_after_conv[i],
pool.padding,
pool.dilation,
pool.kernel_size,
pool.stride,
)
for i in range(len(input_shape))
]
return tuple(out_after_pool)
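
A minimal usage sketch, again not from this diff (the layer hyperparameters are illustrative): chain one convolution and one pooling layer and query the resulting spatial size.

conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=1)
pool = nn.MaxPool2d(kernel_size=2)
# conv: (32 + 2 - 4 - 1) / 1 + 1 = 30; pool: (30 - 1 - 1) / 2 + 1 = 15
print(get_new_cnn_output_size((32, 32), conv, pool))  # (15, 15)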


class CNNEmbedding(nn.Module):
def __init__(
self,
input_dim: int,
input_shape: Tuple,
in_channels: int = 1,
out_channels_per_layer: List = [6, 12],
num_conv_layers: int = 2,
num_linear_layers: int = 2,
num_linear_units: int = 50,
output_dim: int = 20,
num_fully_connected: int = 2,
num_hiddens: int = 120,
out_channels_cnn_1: int = 10,
out_channels_cnn_2: int = 16,
kernel_size: int = 5,
pool_size=4,
pool_kernel_size: int = 2,
):
"""Multi-layer (C)NN
First two layers are convolutional, followed by fully connected layers.
Performing 1d convolution and max pooling with preset configs.
"""Convolutional embedding network.
Convolutional layers are followed by fully connected layers.

Automatically infers whether to apply 1D or 2D convolution depending on
input_shape.
Allows usage of multiple (color) channels by passing in_channels > 1.

Args:
input_dim: Dimensionality of input.
output_dim: Dimensionality of the output.
num_conv: Number of convolutional layers.
num_fully_connected: Number of fully connected layers, minimum of 2.
num_hiddens: Number of hidden dimensions in fully-connected layers.
out_channels_cnn_1: Number of output channels for the first convolutional
layer.
out_channels_cnn_2: Number of output channels for the second
convolutional layer.
input_shape: Dimensionality of input, e.g., (28,) for 1D, (28, 28) for 2D.
in_channels: Number of image channels, default 1.
out_channels_per_layer: Number of output channels for each convolutional
layer. Must have as many entries as num_conv_layers.
num_conv_layers: Number of convolutional layers.
num_linear_layers: Number of fully connected layers.
num_linear_units: Number of hidden units in fully-connected layers.
output_dim: Number of output units of the final layer.
kernel_size: Kernel size for the convolutional layers.
pool_kernel_size: Kernel size of the max pooling operation applied after
each convolutional layer.

Remark: The implementation of the convolutional layers has not been tested
rigorously. It works for the default configuration parameters, but badly
chosen parameters may cause shape conflicts.
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.num_hiddens = num_hiddens

# construct convolutional-pooling subnet
pool = nn.MaxPool1d(pool_size)
conv_layers = [
nn.Conv1d(1, out_channels_cnn_1, kernel_size, padding="same"),
nn.ReLU(),
pool,
nn.Conv1d(
out_channels_cnn_1, out_channels_cnn_2, kernel_size, padding="same"
),
nn.ReLU(),
pool,
]
self.conv_subnet = nn.Sequential(*conv_layers)
super(CNNEmbedding, self).__init__()

# construct fully connected layers
input_dim_fc = out_channels_cnn_2 * (int(input_dim / out_channels_cnn_2))
assert isinstance(
input_shape, Tuple
), "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])."
assert (
0 < len(input_shape) < 3
), """input_shape must be a Tuple of size 1 or 2, e.g.,
(width, [height]). The number of input channels is passed separately."""

self.fc_subnet = FCEmbedding(
input_dim=input_dim_fc,
use_2d_cnn = len(input_shape) == 2
conv_module = nn.Conv2d if use_2d_cnn else nn.Conv1d
pool_module = nn.MaxPool2d if use_2d_cnn else nn.MaxPool1d

assert (
len(out_channels_per_layer) == num_conv_layers
), "out_channels_per_layer needs as many entries as num_conv_layers."

# define input shape with channel
self.input_shape = (in_channels, *input_shape)

# Construct CNN feature extractor.
cnn_layers = []
cnn_output_size = input_shape
stride = 1
padding = 1
for ii in range(num_conv_layers):
# Define a convolutional layer (1D or 2D depending on input_shape).
conv_layer = conv_module(
in_channels=in_channels if ii == 0 else out_channels_per_layer[ii - 1],
out_channels=out_channels_per_layer[ii],
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
pool = pool_module(kernel_size=pool_kernel_size)
cnn_layers += [conv_layer, nn.ReLU(inplace=True), pool]
# Track the output size through each conv + pool block.
cnn_output_size = get_new_cnn_output_size(cnn_output_size, conv_layer, pool)

self.cnn_subnet = nn.Sequential(*cnn_layers)

# Construct linear post processing net.
self.linear_subnet = FCEmbedding(
input_dim=out_channels_per_layer[-1]
* torch.prod(torch.tensor(cnn_output_size)),
output_dim=output_dim,
num_layers=num_fully_connected,
num_hiddens=num_hiddens,
num_layers=num_linear_layers,
num_hiddens=num_linear_units,
)

# Defining the forward pass
def forward(self, x: Tensor) -> Tensor:
"""Network forward pass.
Args:
x: Input tensor (batch_size, in_channels, *input_shape); the channel
dimension may be omitted for single-channel data.
Returns:
Network output (batch_size, output_dim).
"""
x = self.conv_subnet(x.unsqueeze(1))
x = torch.flatten(x, 1) # flatten all dimensions except batch
embedding = self.fc_subnet(x)
batch_size = x.size(0)

return embedding
# reshape to account for single channel data.
x = self.cnn_subnet(x.view(batch_size, *self.input_shape))
# flatten for linear layers.
x = x.view(batch_size, -1)
x = self.linear_subnet(x)
return x
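
A usage sketch of the new interface, assuming the defaults in this diff (out_channels_per_layer=[6, 12], kernel_size=5, pool_kernel_size=2); the shapes in the comments are what those defaults produce:

import torch

# 1D data: a single channel of length 32, embedded into 20 features.
embedding_1d = CNNEmbedding(input_shape=(32,), in_channels=1, output_dim=20)
x = torch.randn(10, 32)  # (batch, width); forward restores the channel dim
print(embedding_1d(x).shape)  # torch.Size([10, 20])

# 2D data: three-channel 32x32 images.
embedding_2d = CNNEmbedding(input_shape=(32, 32), in_channels=3, output_dim=20)
x2d = torch.randn(10, 3, 32, 32)
print(embedding_2d(x2d).shape)  # torch.Size([10, 20])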


class PermutationInvariantEmbedding(nn.Module):
62 changes: 61 additions & 1 deletion tests/embedding_net_test.py
@@ -6,7 +6,11 @@

from sbi import utils as utils
from sbi.inference import SNLE, SNPE, SNRE
from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
from sbi.neural_nets.embedding_nets import (
CNNEmbedding,
FCEmbedding,
PermutationInvariantEmbedding,
)
from sbi.simulators.linear_gaussian import (
linear_gaussian,
true_posterior_linear_gaussian_mvn_prior,
@@ -178,3 +182,59 @@ def test_iid_inference(num_trials, num_dim, method):
check_c2st(samples, reference_samples, alg=method + " permuted")
else:
check_c2st(samples, reference_samples, alg=method)


@pytest.mark.parametrize(
"input_shape",
[
(32,),
(32, 32),
(32, 64),
],
)
@pytest.mark.parametrize("num_channels", (1, 2, 3))
def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
import torch
from torch.distributions import MultivariateNormal

estimator_provider = posterior_nn(
"mdn",
embedding_net=CNNEmbedding(
input_shape, in_channels=num_channels, output_dim=20
),
)

num_dim = input_shape[0]

def simulator2d(theta):
x = MultivariateNormal(
loc=theta, covariance_matrix=0.5 * torch.eye(num_dim)
).sample()
return x.unsqueeze(2).repeat(1, 1, input_shape[1])

def simulator1d(theta):
return torch.rand_like(theta) + theta

simulator = simulator1d if len(input_shape) == 1 else simulator2d
xo = torch.ones(1, num_channels, *input_shape).squeeze(1)

prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))

num_simulations = 1000
theta = prior.sample((num_simulations,))
x = simulator(theta)
if num_channels > 1:
x = x.unsqueeze(1).repeat(
1, num_channels, *[1 for _ in range(len(input_shape))]
)

trainer = SNPE(prior=prior, density_estimator=estimator_provider)
trainer.append_simulations(theta, x).train(max_num_epochs=2)
posterior = trainer.build_posterior().set_default_x(xo)

s = posterior.sample((10,))
posterior.potential(s)