
Commit

refactor: remove unused DiscreteEntropyBottleneck
YodaEmbedding committed Jun 18, 2024
1 parent cafcff1 commit bb17598
Showing 4 changed files with 1 addition and 328 deletions.
216 changes: 0 additions & 216 deletions compressai/entropy_models/entropy_models.py
@@ -41,9 +41,7 @@
from torch import Tensor

from compressai._CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
from compressai.layers.hist import DiscreteIndexing
from compressai.ops import LowerBound
from compressai.ops.bound_ops import RangeBound


class _EntropyCoder:
@@ -571,220 +569,6 @@ def decompress(self, strings, size):
return super().decompress(strings, indexes, medians.dtype, medians)


class SlightlyMoreGeneralizedEntropyBottleneck(EntropyBottleneck):
"""A slightly more generalized ``EntropyBottleneck``.
The only differences:
- ``continuous`` argument for ``_logits_cumulative``
"""

def _logits_cumulative(
self, inputs: Tensor, stop_gradient: bool = False, continuous=None
):
if continuous is None:
continuous = self.continuous_default
if continuous:
return self._logits_cumulative_continuous(inputs, stop_gradient)
else:
return self._logits_cumulative_discrete(inputs, stop_gradient)

def _logits_cumulative_continuous(self, *args, **kwargs):
return EntropyBottleneck._logits_cumulative(self, *args, **kwargs)

@torch.jit.unused
def _likelihood(
self, inputs: Tensor, stop_gradient: bool = False, continuous=None
) -> Tuple[Tensor, Tensor, Tensor]:
half = float(0.5)
lower = self._logits_cumulative(
inputs - half, stop_gradient=stop_gradient, continuous=continuous
)
upper = self._logits_cumulative(
inputs + half, stop_gradient=stop_gradient, continuous=continuous
)
likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
if self.use_likelihood_bound:
likelihood = self.likelihood_lower_bound(likelihood)
return likelihood, lower, upper

def forward(
self, x: Tensor, training: Optional[bool] = None, continuous=None
) -> Tuple[Tensor, Tensor]:
if training is None:
training = self.training

if not torch.jit.is_scripting():
# x from B x C x ... to C x B x ...
perm = np.arange(len(x.shape))
perm[0], perm[1] = perm[1], perm[0]
# Compute inverse permutation
inv_perm = np.arange(len(x.shape))[np.argsort(perm)]
else:
raise NotImplementedError()
# TorchScript in 2D for static inference
# Convert to (channels, ... , batch) format
# perm = (1, 2, 3, 0)
# inv_perm = (3, 0, 1, 2)

x = x.permute(*perm).contiguous()
shape = x.size()
values = x.reshape(x.size(0), 1, -1)

# Add noise or quantize

outputs = self.quantize(
values, "noise" if training else "dequantize", self._get_medians()
)

if not torch.jit.is_scripting():
likelihood, _, _ = self._likelihood(outputs, continuous=continuous)
# NOTE: This has been moved to the _likelihood function.
# if self.use_likelihood_bound:
# likelihood = self.likelihood_lower_bound(likelihood)
else:
raise NotImplementedError()
# TorchScript not yet supported
# likelihood = torch.zeros_like(outputs)

# Convert back to input tensor shape
outputs = outputs.reshape(shape)
outputs = outputs.permute(*inv_perm).contiguous()

likelihood = likelihood.reshape(shape)
likelihood = likelihood.permute(*inv_perm).contiguous()

return outputs, likelihood


class DiscreteEntropyBottleneck(SlightlyMoreGeneralizedEntropyBottleneck):
continuous_default: bool = False

# Possible distribution modeling methods to consider:
# - logits cumulative table (non-decreasing)
# - cdf table (bounded [0, 1], non-negativity, non-decreasing)
# - pdf table (bounded [0, 1], non-negativity)
#
# Methods for enforcement:
# - bounded: sigmoid, tanh
# - non-decreasing: cumsum, lower bound y[i] by y[i-1] via detached max
# - non-negativity: softplus (soft relu), sigmoid
#
# Ballé 2018 enforces (almost?) non-decreasing logits via tanh,
# and non-negativity via softplus of the H matrix, I believe.

def __init__(
self,
channels: int,
*args: Any,
tail_mass: float = 1e-9,
init_scale: float = 10,
filters: Tuple[int, ...] = (3, 3, 3, 3),
num_symbols: int = 255,
sample_rate: int = 4,
init_scale_gamma: float = 4,
**kwargs: Any,
):
super().__init__(
channels,
*args,
tail_mass=tail_mass,
init_scale=init_scale,
filters=filters,
**kwargs,
)

# Initialize channels with a variety of init_scale.
t = torch.linspace(1, 1 / channels, channels)
init_scale = (init_scale * t**init_scale_gamma).clip(min=0.5)
self.quantiles.data = torch.stack(
[-init_scale, torch.zeros_like(init_scale), init_scale], dim=-1
).unsqueeze(1)

assert num_symbols % 2 == 1
num_samples = sample_rate * num_symbols + 1
symbols_radius = (num_symbols - 1) // 2
self.sample_rate = sample_rate

q_dist = -inv_sigmoid(Tensor([tail_mass / 2])).item()
table = (
(symbols_radius * q_dist / init_scale)[:, None]
* torch.linspace(-1, 1, num_samples)[None, :]
).clip(min=-q_dist * 2, max=q_dist * 2)
self.logits_cumulative_table = nn.Parameter(self._deparametrize_table(table))

# Recommendation: set bounds conservatively to 80% of range.
self.logits_cumulative_table_bound = RangeBound(
min=int(num_samples * 0.10),
max=int(num_samples * 0.90),
)

self.register_buffer(
"logits_cumulative_table_offset",
torch.full((channels,), -sample_rate * (symbols_radius + 0.5)),
)

self.discrete_indexing = DiscreteIndexing()

self._update_quantiles()

self.register_load_state_dict_post_hook(self.load_state_dict_post_hook)

def _parametrize_table(self, f):
# NOTE: Due to implementation details, this is not entirely monotone!
return f.cumsum(axis=-1)

def _deparametrize_table(self, f):
return f.diff(axis=-1, prepend=torch.zeros_like(f[..., :1]))

def _logits_cumulative_discrete(
self, inputs: Tensor, stop_gradient: bool, eps=1e-2
) -> Tensor:
# f.shape == (C, B)
f = self.logits_cumulative_table
if stop_gradient:
f = f.detach()
# Enforce non-decreasing monotonicity:
f = self._parametrize_table(f)
if stop_gradient:
# Ensure target values exist somewhere within table boundaries:
b_min = math.ceil(self.logits_cumulative_table_bound.bound_min)
b_max = math.floor(self.logits_cumulative_table_bound.bound_max)
f[..., :b_min] = f[..., :b_min].clip(max=self.target[0] - eps)
f[..., b_max:] = f[..., b_max:].clip(min=self.target[-1] + eps)

# inputs.shape == (C, 1, -1)
q = self.quantiles.squeeze(1)
assert inputs.ndim == 3
x = self.sample_rate * (inputs.squeeze(1) - q[:, 1, None])
x = x - self.logits_cumulative_table_offset.unsqueeze(1)
x = self.logits_cumulative_table_bound(x)

logits = self.discrete_indexing(f, x)
logits = logits.unsqueeze(1)
return logits

def load_state_dict_post_hook(self, module, incompatible_keys):
missing_keys, _ = incompatible_keys
if any(key.endswith("logits_cumulative_table") for key in missing_keys):
print("Initializing logits_cumulative_table via continuous model")
self.init_logits_cumulative_table_via_continuous()

@torch.no_grad()
def init_logits_cumulative_table_via_continuous(self):
"""Initialize logits_cumulative_table using the pretrained continuous model."""
_, num_samples = self.logits_cumulative_table.shape
device = self.logits_cumulative_table.device
x = (
torch.arange(num_samples, device=device)[None, :]
+ self.logits_cumulative_table_offset[:, None]
- self.quantiles[:, :, 1]
) / self.sample_rate
x = x.unsqueeze(1)
f = self._logits_cumulative(x, stop_gradient=True, continuous=True).squeeze(1)
self.logits_cumulative_table[:] = self._deparametrize_table(f)


class GaussianConditional(EntropyModel):
r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
90 changes: 0 additions & 90 deletions compressai/layers/hist.py
@@ -1,10 +1,7 @@
import math

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

@@ -154,93 +151,6 @@ def forward(self, x, **kwargs):
# )


# From https://github.com/pytorch/pytorch/blob/main/tools/autograd/derivatives.yaml:
#
# - name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor # noqa: E501
# self: gather_backward(grad, self, dim, index, sparse_grad)
# index: non_differentiable
# result: auto_linear
#
# - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
# self: grad
# index: non_differentiable
# src: grad.gather(dim, index)
# result: scatter_add(self_t, dim, index, src_t)
#
class DiscreteIndexingFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, f, x, diff_kernel, grad_f_multiplier=1.0):
assert x.shape[:-1] == f.shape[:-1]
assert x.min() >= 0
assert x.max() <= f.shape[-1] - 1
grad_f_multiplier = f.new_tensor(grad_f_multiplier)
ctx.save_for_backward(f, x, diff_kernel, grad_f_multiplier)
return DiscreteIndexingFunction._lerp(f, x)

@staticmethod
def backward(ctx, grad_output):
f, x, diff_kernel, grad_f_multiplier = ctx.saved_tensors
df_dx = DiscreteIndexingFunction._estimate_derivative(f, diff_kernel)
df_dx_at_x = DiscreteIndexingFunction._lerp(df_dx, x)
grad_x = df_dx_at_x * grad_output
grad_f = DiscreteIndexingFunction._dout_df(f, x, grad_output)
grad_f = grad_f * grad_f_multiplier
return grad_f, grad_x, None, None

@staticmethod
def _lerp(f, x):
x1 = x.floor().long()
x2 = x1 + 1
y1 = f.gather(dim=-1, index=x1)
y2 = f.gather(dim=-1, index=x2)
dx = x - x1
return y1 * (1 - dx) + y2 * dx

@staticmethod
def _estimate_derivative(f, diff_kernel):
# Pad f, then estimate derivative via finite difference kernel.
*other_dims, num_bins = f.shape
pad_width = diff_kernel.shape[-1] // 2
f = f.reshape(math.prod(other_dims), 1, num_bins)
f = F.pad(f, pad=(pad_width, pad_width), mode="replicate")
df_dx = F.conv1d(f, weight=diff_kernel).reshape(*other_dims, num_bins)
return df_dx

@staticmethod
def _dout_df(f, x, grad_output):
# Nothing fancy; just manually compute the standard derivative.
x1 = x.floor().long()
x2 = x1 + 1
dx = x - x1
grad_y1 = grad_output * (1 - dx)
grad_y2 = grad_output * dx
grad_f = torch.zeros_like(f)
grad_f.scatter_add_(dim=-1, index=x1, src=grad_y1)
grad_f.scatter_add_(dim=-1, index=x2, src=grad_y2)
return grad_f


class DiscreteIndexing(nn.Module):
def __init__(self, grad_f_multiplier=1.0):
super().__init__()
self.register_buffer("diff_kernel", self._get_diff_kernel())
self.grad_f_multiplier = grad_f_multiplier

def forward(self, f, x):
return DiscreteIndexingFunction.apply(
f, x, self.diff_kernel, self.grad_f_multiplier
)

def _get_diff_kernel(self):
smoothing_kernel = torch.tensor([[[0.25, 0.5, 0.25]]])
difference_kernel = torch.tensor([[[-0.5, 0, 0.5]]])
return F.conv1d(
smoothing_kernel,
difference_kernel.flip(-1),
padding=difference_kernel.shape[-1] - 1,
)


# def cumulative_histogram(...):
# pass
# # WARN: Not the below... that's dc/dx...
14 changes: 1 addition & 13 deletions compressai/models/adaptive.py
@@ -3,11 +3,7 @@
import torch
import torch.nn as nn

from compressai.entropy_models.entropy_models import (
DiscreteEntropyBottleneck,
EntropyBottleneck,
pdf_layout,
)
from compressai.entropy_models.entropy_models import EntropyBottleneck, pdf_layout
from compressai.latent_codecs import LatentCodec
from compressai.layers import UniformHistogram
from compressai.ops import RangeBound
@@ -416,14 +412,6 @@ def decompress(self, strings, shape, pdf_x_default, **kwargs):
return {"x_hat": x_hat}


@register_model("bmshj2018-factorized-pdf-discrete-eb")
class AdaptiveFactorizedPrior(AdaptiveFactorizedPrior):
def __init__(self, N: int, M: int, **kwargs):
FactorizedPrior.__init__(self, N=N, M=M)
self.entropy_bottleneck = DiscreteEntropyBottleneck(M)
AdaptiveMixin.__init__(self, **kwargs)


@register_model("bmshj2018-hyperprior-pdf")
class AdaptiveScaleHyperprior(ScaleHyperprior, AdaptiveMixin):
def __init__(self, N: int, M: int, **kwargs):
9 changes: 0 additions & 9 deletions compressai/models/google.py
@@ -35,7 +35,6 @@

from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.entropy_models.entropy_models import DiscreteEntropyBottleneck
from compressai.layers import GDN, MaskedConv2d
from compressai.registry import register_model

@@ -164,14 +163,6 @@ def decompress(self, strings, shape):
return {"x_hat": x_hat}


@register_model("bmshj2018-factorized-discrete-eb")
class FactorizedPriorDiscreteEB(FactorizedPrior):
def __init__(self, N, M, **kwargs):
super().__init__(N=N, M=M, **kwargs)

self.entropy_bottleneck = DiscreteEntropyBottleneck(M)


@register_model("bmshj2018-factorized-relu")
class FactorizedPriorReLU(FactorizedPrior):
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
