diff --git a/FrEIA/modules/all_in_one_block.py b/FrEIA/modules/all_in_one_block.py
index 2f07cbe..07bb1e6 100644
--- a/FrEIA/modules/all_in_one_block.py
+++ b/FrEIA/modules/all_in_one_block.py
@@ -102,10 +102,13 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
         self.splits = [split_len1, split_len2]

         try:
-            self.permute_function = {0: F.linear,
-                                     1: F.conv1d,
-                                     2: F.conv2d,
-                                     3: F.conv3d}[self.input_rank]
+            if permute_soft or learned_householder_permutation:
+                self.permute_function = {0: F.linear,
+                                         1: F.conv1d,
+                                         2: F.conv2d,
+                                         3: F.conv3d}[self.input_rank]
+            else:
+                self.permute_function = lambda x, p: x[:, p]
         except KeyError:
             raise ValueError(f"Data is {1 + self.input_rank}D. Must be 1D-4D.")

@@ -143,9 +146,7 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
         if permute_soft:
             w = special_ortho_group.rvs(channels)
         else:
-            w = np.zeros((channels, channels))
-            for i, j in enumerate(np.random.permutation(channels)):
-                w[i, j] = 1.
+            w_index = torch.randperm(channels, requires_grad=False)

         if self.householder:
             # instead of just the permutation matrix w, the learned housholder
@@ -154,12 +155,15 @@ class or callable ``f``, called as ``f(channels_in, channels_out)`` and
             self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
             self.w_perm = None
             self.w_perm_inv = None
-            self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
-        else:
-            self.w_perm = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)),
+            self.w_0 = nn.Parameter(torch.from_numpy(w).float(), requires_grad=False)
+        elif permute_soft:
+            self.w_perm = nn.Parameter(torch.from_numpy(w).float().view(channels, channels, *([1] * self.input_rank)).contiguous(),
                                        requires_grad=False)
-            self.w_perm_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)),
+            self.w_perm_inv = nn.Parameter(torch.from_numpy(w.T).float().view(channels, channels, *([1] * self.input_rank)).contiguous(),
                                            requires_grad=False)
+        else:
+            self.w_perm = nn.Parameter(w_index, requires_grad=False)
+            self.w_perm_inv = nn.Parameter(torch.argsort(w_index), requires_grad=False)

         if subnet_constructor is None:
             raise ValueError("Please supply a callable subnet_constructor "
@@ -222,7 +226,7 @@ def _affine(self, x, a, rev=False):
         a *= 0.1
         ch = x.shape[1]

-        sub_jac = self.clamp * torch.tanh(a[:, :ch])
+        sub_jac = self.clamp * torch.tanh(a[:, :ch]/self.clamp)
         if self.GIN:
             sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True)

@@ -235,6 +239,9 @@ def _affine(self, x, a, rev=False):

     def forward(self, x, c=[], rev=False, jac=True):
         '''See base class docstring'''
+        if tuple(x[0].shape[1:]) != self.dims_in[0]:
+            raise RuntimeError(f"Expected input of shape {self.dims_in[0]}, "
+                               f"got {tuple(x[0].shape[1:])}.")
         if self.householder:
             self.w_perm = self._construct_householder_permutation()
         if rev or self.reverse_pre_permute:
diff --git a/FrEIA/modules/invertible_resnet.py b/FrEIA/modules/invertible_resnet.py
index efac11f..1759120 100644
--- a/FrEIA/modules/invertible_resnet.py
+++ b/FrEIA/modules/invertible_resnet.py
@@ -30,9 +30,9 @@ def __init__(self, dims_in, dims_c=None, init_data: torch.Tensor = None):

         self.register_buffer("is_initialized", torch.tensor(False))

-        dim = next(iter(dims_in))[0]
-        self.log_scale = nn.Parameter(torch.empty(1, dim))
-        self.loc = nn.Parameter(torch.empty(1, dim))
+        dims = next(iter(dims_in))
+        self.log_scale = nn.Parameter(torch.empty(1, *dims))
+        self.loc = nn.Parameter(torch.empty(1, *dims))

         if init_data is not None:
             self.initialize(init_data)
diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py
index 5c6607f..ad20009 100644
--- a/FrEIA/modules/splines/binned.py
+++ b/FrEIA/modules/splines/binned.py
@@ -64,7 +64,7 @@ class BinnedSplineBase(InvertibleModule):
     def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[str, int] = None,
                  min_bin_sizes: Tuple[float] = (0.1, 0.1),
                  default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0),
-                 identity_tails: bool = False) -> None:
+                 identity_tails: bool = False, domain_clamping: float = None) -> None:
         """
         Args:
             bins: number of bins to use
@@ -75,6 +75,8 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[
             default_domain: tuple of (left, right, bottom, top) default spline domain values
                 these values will be used as the starting domain (when the network outputs zero)
             identity_tails: whether to use identity tails for the spline
+            domain_clamping: if a float is given, soft-clamp the total spline width
+                and height to (-domain_clamping, domain_clamping)
         """
         if dims_c is None:
             dims_c = []
@@ -98,6 +100,8 @@ def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[
         self.register_buffer("identity_tails", torch.tensor(identity_tails, dtype=torch.bool))
         self.register_buffer("default_width", torch.as_tensor(default_domain[1] - default_domain[0], dtype=torch.float32))

+        self.domain_clamping = domain_clamping
+
         # The default parameters are
         #    parameter                                 constraints        count
         # 1. the leftmost bin edge                     -                  1
@@ -131,6 +135,15 @@ def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str

         return dict(zip(keys, values))

+    def clamp_domain(self, domain: torch.Tensor) -> torch.Tensor:
+        """
+        Softly clamp the domain size to (-domain_clamping, domain_clamping)
+        """
+        if self.domain_clamping is None:
+            return domain
+        else:
+            return self.domain_clamping * torch.tanh(domain / self.domain_clamping)
+
     def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Constrain Parameters to meet certain conditions (e.g. positivity)
@@ -143,6 +156,7 @@ def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str,
             total_width = parameters["total_width"]
             shift = np.log(np.e - 1)
             total_width = self.default_width * F.softplus(total_width + shift)
+            total_width = self.clamp_domain(total_width)

             parameters["left"] = -total_width / 2
             parameters["bottom"] = -total_width / 2
@@ -161,7 +175,16 @@ def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str,

         parameters["widths"] = self.min_bin_sizes[0] + F.softplus(parameters["widths"] + xshift)
         parameters["heights"] = self.min_bin_sizes[1] + F.softplus(parameters["heights"] + yshift)
-
+
+        domain_width = torch.sum(parameters["widths"], dim=-1, keepdim=True)
+        domain_height = torch.sum(parameters["heights"], dim=-1, keepdim=True)
+        width_resize = self.clamp_domain(domain_width) / domain_width
+        height_resize = self.clamp_domain(domain_height) / domain_height
+
+        parameters["widths"] = parameters["widths"] * width_resize
+        parameters["heights"] = parameters["heights"] * height_resize
+        parameters["left"] = parameters["left"] * width_resize
+        parameters["bottom"] = parameters["bottom"] * height_resize

         return parameters
diff --git a/FrEIA/modules/splines/rational_quadratic.py b/FrEIA/modules/splines/rational_quadratic.py
index 6656fb8..e8673d0 100644
--- a/FrEIA/modules/splines/rational_quadratic.py
+++ b/FrEIA/modules/splines/rational_quadratic.py
@@ -164,7 +164,8 @@ def rational_quadratic_spline(x: torch.Tensor,

         # Eq 29 in the appendix of the paper
         discriminant = b ** 2 - 4 * a * c
-        assert torch.all(discriminant >= 0)
+        if not torch.all(discriminant >= 0):
+            raise RuntimeError(f"Discriminant must be non-negative, but the minimum value is {torch.min(discriminant)}")

         xi = 2 * c / (-b - torch.sqrt(discriminant))
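
Note on the hard-permutation change in all_in_one_block.py above: indexing with torch.randperm and inverting with torch.argsort is intended to match the previous one-hot permutation matrix applied via F.linear. A minimal sketch of that equivalence, using hypothetical local names (x, w_index, w) that are not part of the patch:

    import torch
    import torch.nn.functional as F

    channels = 8
    x = torch.randn(4, channels)            # (batch, channels) input

    w_index = torch.randperm(channels)      # forward permutation (index form)
    w_index_inv = torch.argsort(w_index)    # its inverse

    # indexing forward and then backward recovers the input exactly
    assert torch.equal(x[:, w_index][:, w_index_inv], x)

    # a one-hot matrix with w[i, w_index[i]] = 1 reproduces the old F.linear path
    w = torch.zeros(channels, channels)
    w[torch.arange(channels), w_index] = 1.0
    assert torch.equal(F.linear(x, w), x[:, w_index])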
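On the sub_jac change in _affine: dividing the subnet output by self.clamp before the tanh keeps the soft clamp close to the identity for small activations while preserving the same +/- clamp bound. A small standalone illustration; clamp = 2.0 is an arbitrary example value, not a default from the patch:

    import torch

    clamp = 2.0
    a = torch.linspace(-0.5, 0.5, 5)
    old = clamp * torch.tanh(a)            # slope `clamp` at zero, bounded by +/- clamp
    new = clamp * torch.tanh(a / clamp)    # slope 1 at zero, same +/- clamp bound
    print(old)
    print(new)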
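On clamp_domain in binned.py: the soft clamp is a scaled tanh, so the spline domain grows roughly linearly while it is small and saturates at domain_clamping. A standalone sketch of that behavior; the value 6.0 is an arbitrary example, not a default from the patch:

    import torch

    domain_clamping = 6.0
    total_width = torch.tensor([0.5, 3.0, 30.0])
    clamped = domain_clamping * torch.tanh(total_width / domain_clamping)
    # small widths pass through almost unchanged; large widths saturate below 6.0
    print(clamped)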