From bb175985905b50ad74cb257d2ae3020396e4adbe Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Mon, 17 Jun 2024 20:19:55 -0700
Subject: [PATCH] refactor: remove unused DiscreteEntropyBottleneck

---
 compressai/entropy_models/entropy_models.py | 216 --------------------
 compressai/layers/hist.py                   |  90 --------
 compressai/models/adaptive.py               |  14 +-
 compressai/models/google.py                 |   9 -
 4 files changed, 1 insertion(+), 328 deletions(-)

diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py
index 4feff22..e3172ba 100644
--- a/compressai/entropy_models/entropy_models.py
+++ b/compressai/entropy_models/entropy_models.py
@@ -41,9 +41,7 @@
 from torch import Tensor
 
 from compressai._CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
-from compressai.layers.hist import DiscreteIndexing
 from compressai.ops import LowerBound
-from compressai.ops.bound_ops import RangeBound
 
 
 class _EntropyCoder:
@@ -571,220 +569,6 @@ def decompress(self, strings, size):
         return super().decompress(strings, indexes, medians.dtype, medians)
 
 
-class SlightlyMoreGeneralizedEntropyBottleneck(EntropyBottleneck):
-    """A slightly more generalized ``EntropyBottleneck``.
-
-    The only differences:
-
-    - ``continuous`` argument for ``_logits_cumulative``
-    """
-
-    def _logits_cumulative(
-        self, inputs: Tensor, stop_gradient: bool = False, continuous=None
-    ):
-        if continuous is None:
-            continuous = self.continuous_default
-        if continuous:
-            return self._logits_cumulative_continuous(inputs, stop_gradient)
-        else:
-            return self._logits_cumulative_discrete(inputs, stop_gradient)
-
-    def _logits_cumulative_continuous(self, *args, **kwargs):
-        return EntropyBottleneck._logits_cumulative(self, *args, **kwargs)
-
-    @torch.jit.unused
-    def _likelihood(
-        self, inputs: Tensor, stop_gradient: bool = False, continuous=None
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        half = float(0.5)
-        lower = self._logits_cumulative(
-            inputs - half, stop_gradient=stop_gradient, continuous=continuous
-        )
-        upper = self._logits_cumulative(
-            inputs + half, stop_gradient=stop_gradient, continuous=continuous
-        )
-        likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
-        if self.use_likelihood_bound:
-            likelihood = self.likelihood_lower_bound(likelihood)
-        return likelihood, lower, upper
-
-    def forward(
-        self, x: Tensor, training: Optional[bool] = None, continuous=None
-    ) -> Tuple[Tensor, Tensor]:
-        if training is None:
-            training = self.training
-
-        if not torch.jit.is_scripting():
-            # x from B x C x ... to C x B x ...
-            perm = np.arange(len(x.shape))
-            perm[0], perm[1] = perm[1], perm[0]
-            # Compute inverse permutation
-            inv_perm = np.arange(len(x.shape))[np.argsort(perm)]
-        else:
-            raise NotImplementedError()
-            # TorchScript in 2D for static inference
-            # Convert to (channels, ... , batch) format
-            # perm = (1, 2, 3, 0)
-            # inv_perm = (3, 0, 1, 2)
-
-        x = x.permute(*perm).contiguous()
-        shape = x.size()
-        values = x.reshape(x.size(0), 1, -1)
-
-        # Add noise or quantize
-
-        outputs = self.quantize(
-            values, "noise" if training else "dequantize", self._get_medians()
-        )
-
-        if not torch.jit.is_scripting():
-            likelihood, _, _ = self._likelihood(outputs, continuous=continuous)
-            # NOTE: This has been moved to the _likelihood function.
-            # if self.use_likelihood_bound:
-            #     likelihood = self.likelihood_lower_bound(likelihood)
-        else:
-            raise NotImplementedError()
-            # TorchScript not yet supported
-            # likelihood = torch.zeros_like(outputs)
-
-        # Convert back to input tensor shape
-        outputs = outputs.reshape(shape)
-        outputs = outputs.permute(*inv_perm).contiguous()
-
-        likelihood = likelihood.reshape(shape)
-        likelihood = likelihood.permute(*inv_perm).contiguous()
-
-        return outputs, likelihood
-
-
-class DiscreteEntropyBottleneck(SlightlyMoreGeneralizedEntropyBottleneck):
-    continuous_default: bool = False
-
-    # Possible distribution modeling methods to consider:
-    # - logits cumulative table (non-decreasing)
-    # - cdf table (bounded [0, 1], non-negativity, non-decreasing)
-    # - pdf table (bounded [0, 1], non-negativity)
-    #
-    # Methods for enforcement:
-    # - bounded: sigmoid, tanh
-    # - non-decreasing: cumsum, lower bound y[i] by y[i-1] via detached max
-    # - non-negativity: softplus (soft relu), sigmoid
-    #
-    # Ballé 2018 enforces (almost?) non-decreasing logits via tanh,
-    # and non-negativity via softplus of the H matrix, I believe.
-
-    def __init__(
-        self,
-        channels: int,
-        *args: Any,
-        tail_mass: float = 1e-9,
-        init_scale: float = 10,
-        filters: Tuple[int, ...] = (3, 3, 3, 3),
-        num_symbols: int = 255,
-        sample_rate: int = 4,
-        init_scale_gamma: float = 4,
-        **kwargs: Any,
-    ):
-        super().__init__(
-            channels,
-            *args,
-            tail_mass=tail_mass,
-            init_scale=init_scale,
-            filters=filters,
-            **kwargs,
-        )
-
-        # Initialize channels with a variety of init_scale.
-        t = torch.linspace(1, 1 / channels, channels)
-        init_scale = (init_scale * t**init_scale_gamma).clip(min=0.5)
-        self.quantiles.data = torch.stack(
-            [-init_scale, torch.zeros_like(init_scale), init_scale], dim=-1
-        ).unsqueeze(1)
-
-        assert num_symbols % 2 == 1
-        num_samples = sample_rate * num_symbols + 1
-        symbols_radius = (num_symbols - 1) // 2
-        self.sample_rate = sample_rate
-
-        q_dist = -inv_sigmoid(Tensor([tail_mass / 2])).item()
-        table = (
-            (symbols_radius * q_dist / init_scale)[:, None]
-            * torch.linspace(-1, 1, num_samples)[None, :]
-        ).clip(min=-q_dist * 2, max=q_dist * 2)
-        self.logits_cumulative_table = nn.Parameter(self._deparametrize_table(table))
-
-        # Recommendation: set bounds conservatively to 80% of range.
-        self.logits_cumulative_table_bound = RangeBound(
-            min=int(num_samples * 0.10),
-            max=int(num_samples * 0.90),
-        )
-
-        self.register_buffer(
-            "logits_cumulative_table_offset",
-            torch.full((channels,), -sample_rate * (symbols_radius + 0.5)),
-        )
-
-        self.discrete_indexing = DiscreteIndexing()
-
-        self._update_quantiles()
-
-        self.register_load_state_dict_post_hook(self.load_state_dict_post_hook)
-
-    def _parametrize_table(self, f):
-        # NOTE: Due to implementation details, this is not entirely monotone!
-        return f.cumsum(axis=-1)
-
-    def _deparametrize_table(self, f):
-        return f.diff(axis=-1, prepend=torch.zeros_like(f[..., :1]))
-
-    def _logits_cumulative_discrete(
-        self, inputs: Tensor, stop_gradient: bool, eps=1e-2
-    ) -> Tensor:
-        # f.shape == (C, B)
-        f = self.logits_cumulative_table
-        if stop_gradient:
-            f = f.detach()
-        # Enforce non-decreasing monotonicity:
-        f = self._parametrize_table(f)
-        if stop_gradient:
-            # Ensure target values exist somewhere within table boundaries:
-            b_min = math.ceil(self.logits_cumulative_table_bound.bound_min)
-            b_max = math.floor(self.logits_cumulative_table_bound.bound_max)
-            f[..., :b_min] = f[..., :b_min].clip(max=self.target[0] - eps)
-            f[..., b_max:] = f[..., b_max:].clip(min=self.target[-1] + eps)
-
-        # inputs.shape == (C, 1, -1)
-        q = self.quantiles.squeeze(1)
-        assert inputs.ndim == 3
-        x = self.sample_rate * (inputs.squeeze(1) - q[:, 1, None])
-        x = x - self.logits_cumulative_table_offset.unsqueeze(1)
-        x = self.logits_cumulative_table_bound(x)
-
-        logits = self.discrete_indexing(f, x)
-        logits = logits.unsqueeze(1)
-        return logits
-
-    def load_state_dict_post_hook(self, module, incompatible_keys):
-        missing_keys, _ = incompatible_keys
-        if any(key.endswith("logits_cumulative_table") for key in missing_keys):
-            print("Initializing logits_cumulative_table via continuous model")
-            self.init_logits_cumulative_table_via_continuous()
-
-    @torch.no_grad()
-    def init_logits_cumulative_table_via_continuous(self):
-        """Initialize logits_cumulative_table using the pretrained continuous model."""
-        _, num_samples = self.logits_cumulative_table.shape
-        device = self.logits_cumulative_table.device
-        x = (
-            torch.arange(num_samples, device=device)[None, :]
-            + self.logits_cumulative_table_offset[:, None]
-            - self.quantiles[:, :, 1]
-        ) / self.sample_rate
-        x = x.unsqueeze(1)
-        f = self._logits_cumulative(x, stop_gradient=True, continuous=True).squeeze(1)
-        self.logits_cumulative_table[:] = self._deparametrize_table(f)
-
-
 class GaussianConditional(EntropyModel):
     r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
     S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
diff --git a/compressai/layers/hist.py b/compressai/layers/hist.py
index 3a0a214..e95e8cf 100644
--- a/compressai/layers/hist.py
+++ b/compressai/layers/hist.py
@@ -1,10 +1,7 @@
-import math
-
 from typing import Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from torch import Tensor
 
@@ -154,93 +151,6 @@ def forward(self, x, **kwargs):
 # )
 
 
-# From https://github.com/pytorch/pytorch/blob/main/tools/autograd/derivatives.yaml:
-#
-# - name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor  # noqa: E501
-#   self: gather_backward(grad, self, dim, index, sparse_grad)
-#   index: non_differentiable
-#   result: auto_linear
-#
-# - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
-#   self: grad
-#   index: non_differentiable
-#   src: grad.gather(dim, index)
-#   result: scatter_add(self_t, dim, index, src_t)
-#
-class DiscreteIndexingFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, f, x, diff_kernel, grad_f_multiplier=1.0):
-        assert x.shape[:-1] == f.shape[:-1]
-        assert x.min() >= 0
-        assert x.max() <= f.shape[-1] - 1
-        grad_f_multiplier = f.new_tensor(grad_f_multiplier)
-        ctx.save_for_backward(f, x, diff_kernel, grad_f_multiplier)
-        return DiscreteIndexingFunction._lerp(f, x)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        f, x, diff_kernel, grad_f_multiplier = ctx.saved_tensors
-        df_dx = DiscreteIndexingFunction._estimate_derivative(f, diff_kernel)
-        df_dx_at_x = DiscreteIndexingFunction._lerp(df_dx, x)
-        grad_x = df_dx_at_x * grad_output
-        grad_f = DiscreteIndexingFunction._dout_df(f, x, grad_output)
-        grad_f = grad_f * grad_f_multiplier
-        return grad_f, grad_x, None, None
-
-    @staticmethod
-    def _lerp(f, x):
-        x1 = x.floor().long()
-        x2 = x1 + 1
-        y1 = f.gather(dim=-1, index=x1)
-        y2 = f.gather(dim=-1, index=x2)
-        dx = x - x1
-        return y1 * (1 - dx) + y2 * dx
-
-    @staticmethod
-    def _estimate_derivative(f, diff_kernel):
-        # Pad f, then estimate derivative via finite difference kernel.
-        *other_dims, num_bins = f.shape
-        pad_width = diff_kernel.shape[-1] // 2
-        f = f.reshape(math.prod(other_dims), 1, num_bins)
-        f = F.pad(f, pad=(pad_width, pad_width), mode="replicate")
-        df_dx = F.conv1d(f, weight=diff_kernel).reshape(*other_dims, num_bins)
-        return df_dx
-
-    @staticmethod
-    def _dout_df(f, x, grad_output):
-        # Nothing fancy; just manually compute the standard derivative.
-        x1 = x.floor().long()
-        x2 = x1 + 1
-        dx = x - x1
-        grad_y1 = grad_output * (1 - dx)
-        grad_y2 = grad_output * dx
-        grad_f = torch.zeros_like(f)
-        grad_f.scatter_add_(dim=-1, index=x1, src=grad_y1)
-        grad_f.scatter_add_(dim=-1, index=x2, src=grad_y2)
-        return grad_f
-
-
-class DiscreteIndexing(nn.Module):
-    def __init__(self, grad_f_multiplier=1.0):
-        super().__init__()
-        self.register_buffer("diff_kernel", self._get_diff_kernel())
-        self.grad_f_multiplier = grad_f_multiplier
-
-    def forward(self, f, x):
-        return DiscreteIndexingFunction.apply(
-            f, x, self.diff_kernel, self.grad_f_multiplier
-        )
-
-    def _get_diff_kernel(self):
-        smoothing_kernel = torch.tensor([[[0.25, 0.5, 0.25]]])
-        difference_kernel = torch.tensor([[[-0.5, 0, 0.5]]])
-        return F.conv1d(
-            smoothing_kernel,
-            difference_kernel.flip(-1),
-            padding=difference_kernel.shape[-1] - 1,
-        )
-
-
 # def cumulative_histogram(...):
 #     pass
 #     # WARN: Not the below... that's dc/dx...
diff --git a/compressai/models/adaptive.py b/compressai/models/adaptive.py
index 722366c..c7b6993 100644
--- a/compressai/models/adaptive.py
+++ b/compressai/models/adaptive.py
@@ -3,11 +3,7 @@
 import torch
 import torch.nn as nn
 
-from compressai.entropy_models.entropy_models import (
-    DiscreteEntropyBottleneck,
-    EntropyBottleneck,
-    pdf_layout,
-)
+from compressai.entropy_models.entropy_models import EntropyBottleneck, pdf_layout
 from compressai.latent_codecs import LatentCodec
 from compressai.layers import UniformHistogram
 from compressai.ops import RangeBound
@@ -416,14 +412,6 @@ def decompress(self, strings, shape, pdf_x_default, **kwargs):
         return {"x_hat": x_hat}
 
 
-@register_model("bmshj2018-factorized-pdf-discrete-eb")
-class AdaptiveFactorizedPrior(AdaptiveFactorizedPrior):
-    def __init__(self, N: int, M: int, **kwargs):
-        FactorizedPrior.__init__(self, N=N, M=M)
-        self.entropy_bottleneck = DiscreteEntropyBottleneck(M)
-        AdaptiveMixin.__init__(self, **kwargs)
-
-
 @register_model("bmshj2018-hyperprior-pdf")
 class AdaptiveScaleHyperprior(ScaleHyperprior, AdaptiveMixin):
     def __init__(self, N: int, M: int, **kwargs):
diff --git a/compressai/models/google.py b/compressai/models/google.py
index 557c96a..1a6d9d5 100644
--- a/compressai/models/google.py
+++ b/compressai/models/google.py
@@ -35,7 +35,6 @@
 
 from compressai.ans import BufferedRansEncoder, RansDecoder
 from compressai.entropy_models import EntropyBottleneck, GaussianConditional
-from compressai.entropy_models.entropy_models import DiscreteEntropyBottleneck
 from compressai.layers import GDN, MaskedConv2d
 from compressai.registry import register_model
 
@@ -164,14 +163,6 @@ def decompress(self, strings, shape):
         return {"x_hat": x_hat}
 
 
-@register_model("bmshj2018-factorized-discrete-eb")
-class FactorizedPriorDiscreteEB(FactorizedPrior):
-    def __init__(self, N, M, **kwargs):
-        super().__init__(N=N, M=M, **kwargs)
-
-        self.entropy_bottleneck = DiscreteEntropyBottleneck(M)
-
-
 @register_model("bmshj2018-factorized-relu")
 class FactorizedPriorReLU(FactorizedPrior):
     r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,