# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import warnings
from typing import Any, Callable, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from compressai.ops import LowerBound
from .entropy_models import (
_EntropyCoder,
_forward,
default_entropy_coder,
pmf_to_quantized_cdf,
)
class EntropyModelVbr(nn.Module):
    r"""Entropy model base class with support for a variable quantization step.

    The ``*_variable`` methods accept an optional scalar quantization step
    size ``qs``; when ``qs`` is None they behave like unit-step rounding.

    Args:
        likelihood_bound (float): minimum likelihood bound; a value <= 0
            disables the bound
        entropy_coder (str, optional): set the entropy coder to use, use default
            one if None
        entropy_coder_precision (int): set the entropy coder precision
    """

    def __init__(
        self,
        likelihood_bound: float = 1e-9,
        entropy_coder: Optional[str] = None,
        entropy_coder_precision: int = 16,
    ):
        super().__init__()

        if entropy_coder is None:
            entropy_coder = default_entropy_coder()
        self.entropy_coder = _EntropyCoder(entropy_coder)
        self.entropy_coder_precision = int(entropy_coder_precision)

        self.use_likelihood_bound = likelihood_bound > 0
        if self.use_likelihood_bound:
            self.likelihood_lower_bound = LowerBound(likelihood_bound)

        # to be filled on update()
        self.register_buffer("_offset", torch.IntTensor())
        self.register_buffer("_quantized_cdf", torch.IntTensor())
        self.register_buffer("_cdf_length", torch.IntTensor())

    def __getstate__(self):
        # Store the entropy coder by name so the module can be pickled;
        # __setstate__ rebuilds the coder object from that name.
        attributes = self.__dict__.copy()
        attributes["entropy_coder"] = self.entropy_coder.name
        return attributes

    def __setstate__(self, state):
        self.__dict__ = state
        self.entropy_coder = _EntropyCoder(self.__dict__.pop("entropy_coder"))

    @property
    def offset(self):
        return self._offset

    @property
    def quantized_cdf(self):
        return self._quantized_cdf

    @property
    def cdf_length(self):
        return self._cdf_length

    # See: https://github.com/python/mypy/issues/8795
    forward: Callable[..., Any] = _forward

    def quantize(
        self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
    ) -> Tensor:
        """Quantize ``inputs`` with a unit step.

        Args:
            inputs: tensor to quantize
            mode: ``"noise"`` adds uniform noise in [-0.5, 0.5) (``means`` is
                ignored); ``"dequantize"`` rounds ``inputs - means`` and adds
                ``means`` back; ``"symbols"`` returns the rounded
                ``inputs - means`` as int32
            means: optional tensor subtracted before rounding

        Raises:
            ValueError: if ``mode`` is not a supported mode
        """
        if mode not in ("noise", "dequantize", "symbols"):
            raise ValueError(f'Invalid quantization mode: "{mode}"')

        if mode == "noise":
            half = float(0.5)
            noise = torch.empty_like(inputs).uniform_(-half, half)
            inputs = inputs + noise
            return inputs

        outputs = inputs.clone()
        if means is not None:
            outputs -= means

        outputs = torch.round(outputs)

        if mode == "dequantize":
            if means is not None:
                outputs += means
            return outputs

        assert mode == "symbols", mode
        outputs = outputs.int()
        return outputs

    def quantize_variable(  # noqa: C901
        self,
        inputs: Tensor,
        mode: str,
        means: Optional[Tensor] = None,
        qs: Optional[Tensor] = None,
    ) -> Tensor:
        """Quantize ``inputs`` with an optional scalar step size ``qs``.

        Args:
            inputs: tensor to quantize
            mode: one of

                - ``"noise"``: add uniform noise in [-0.5, 0.5), scaled by
                  ``qs`` when given (``means`` is ignored);
                - ``"ste"``: round ``inputs - means`` to the ``qs`` grid with
                  a straight-through gradient, then add ``means`` back;
                - ``"dequantize"``: round ``inputs - means`` to the ``qs``
                  grid, then add ``means`` back;
                - ``"symbols"``: return ``round((inputs - means) / qs)`` as
                  int32 (``qs`` treated as 1 when None).
            means: optional tensor subtracted before quantization
            qs: optional 0-dim tensor holding the quantization step

        Raises:
            ValueError: if ``mode`` is not a supported mode
        """
        if mode not in ("noise", "ste", "dequantize", "symbols"):
            raise ValueError(f'Invalid quantization mode: "{mode}"')
        if qs is not None:
            assert qs.shape == torch.Size([])

        if mode == "noise":
            half = float(0.5)
            noise = torch.empty_like(inputs).uniform_(-half, half)
            if qs is not None:
                noise = noise * qs
            return inputs + noise

        outputs = inputs.clone()
        if means is not None:
            outputs -= means

        if mode == "ste":
            if qs is None:
                rounded = torch.round(outputs)
            else:
                rounded = torch.round(outputs / qs) * qs
            # Forward value equals `rounded`; the gradient passes through
            # unchanged because the rounding residual is detached.
            outputs_ste = rounded - outputs.detach() + outputs
            if means is not None:
                outputs_ste += means
            return outputs_ste

        if mode == "dequantize":
            if qs is None:
                outputs = torch.round(outputs)
            else:
                outputs = torch.round(outputs / qs) * qs
            if means is not None:
                outputs += means
            return outputs

        assert mode == "symbols", mode
        # Round (rather than truncate) so the qs=None path produces the same
        # symbols as quantize(..., "symbols").
        if qs is None:
            outputs = torch.round(outputs).int()
        else:
            outputs = torch.round(outputs / qs).int()
        # Note: outputs must be multiplied by qs and means must be added
        # before they are fed to g_s() to reconstruct an image
        return outputs

    def _quantize(
        self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
    ) -> Tensor:
        warnings.warn("_quantize is deprecated. Use quantize instead.", stacklevel=2)
        return self.quantize(inputs, mode, means)

    @staticmethod
    def dequantize(
        inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float
    ) -> Tensor:
        """Map quantized symbols back to values.

        If ``means`` is given, ``inputs`` is cast with ``type_as(means)`` and
        ``means`` is added (``dtype`` is not used). Otherwise ``inputs`` is
        cast to ``dtype``. ``inputs`` is never modified in place.
        """
        if means is not None:
            # Out-of-place add: type_as() returns `inputs` itself when the
            # types already match, so `+=` would mutate the caller's tensor.
            return inputs.type_as(means) + means
        return inputs.type(dtype)

    @staticmethod
    def dequantize_variable(
        inputs: Tensor,
        means: Optional[Tensor] = None,
        dtype: torch.dtype = torch.float,
        qs: Optional[Tensor] = None,
    ) -> Tensor:
        """Map quantized symbols back to values, scaling by ``qs`` if given.

        Computes ``inputs * qs + means``; a missing ``qs`` acts as 1 and a
        missing ``means`` as 0. The cast follows :meth:`dequantize`
        (``type_as(means)`` when ``means`` is given, else ``dtype``).
        ``inputs`` is never modified in place.
        """
        if means is not None:
            outputs = inputs.type_as(means)
            if qs is not None:
                outputs = outputs * qs
            # Out-of-place add (see dequantize()).
            return outputs + means

        outputs = inputs.type(dtype)
        if qs is not None:
            outputs = outputs * qs
        return outputs

    @classmethod
    def _dequantize(cls, inputs: Tensor, means: Optional[Tensor] = None) -> Tensor:
        warnings.warn(
            "_dequantize is deprecated. Use dequantize instead.", stacklevel=2
        )
        return cls.dequantize(inputs, means)

    def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length):
        """Quantize one PMF per row into an int32 CDF table.

        Row ``i`` holds the quantized CDF of ``pmf[i, :pmf_length[i]]`` with
        ``tail_mass[i]`` appended as an extra symbol. Rows are zero-padded
        to ``max_length + 2`` columns.
        """
        cdf = torch.zeros(
            (len(pmf_length), max_length + 2), dtype=torch.int32, device=pmf.device
        )
        for i, p in enumerate(pmf):
            prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
            _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
            cdf[i, : _cdf.size(0)] = _cdf
        return cdf

    def _check_cdf_size(self):
        if self._quantized_cdf.numel() == 0:
            raise ValueError("Uninitialized CDFs. Run update() first")

        if len(self._quantized_cdf.size()) != 2:
            raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}")

    def _check_offsets_size(self):
        if self._offset.numel() == 0:
            raise ValueError("Uninitialized offsets. Run update() first")

        if len(self._offset.size()) != 1:
            raise ValueError(f"Invalid offsets size {self._offset.size()}")

    def _check_cdf_length(self):
        if self._cdf_length.numel() == 0:
            raise ValueError("Uninitialized CDF lengths. Run update() first")

        if len(self._cdf_length.size()) != 1:
            raise ValueError(f"Invalid CDF length size {self._cdf_length.size()}")

    def compress(self, inputs, indexes, means=None, qs=None):
        """
        Compress input tensors to char strings.

        Args:
            inputs (torch.Tensor): input tensors
            indexes (torch.IntTensor): tensors CDF indexes
            means (torch.Tensor, optional): optional tensor means
            qs (torch.Tensor, optional): optional quantization step size

        Returns:
            list: one encoded string per element along dimension 0.

        Raises:
            ValueError: on invalid ``inputs``/``indexes`` sizes or when the
                CDF tables have not been built by ``update()``.
        """
        # Validate before doing any quantization work.
        if len(inputs.size()) < 2:
            raise ValueError(
                "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
            )

        if inputs.size() != indexes.size():
            raise ValueError("`inputs` and `indexes` should have the same size.")

        self._check_cdf_size()
        self._check_cdf_length()
        self._check_offsets_size()

        if qs is None:
            symbols = self.quantize(inputs, "symbols", means)
        else:
            symbols = self.quantize_variable(inputs, "symbols", means=means, qs=qs)

        # The CDF tables are shared by every batch element: convert them once.
        cdfs = self._quantized_cdf.tolist()
        cdf_lengths = self._cdf_length.reshape(-1).int().tolist()
        offsets = self._offset.reshape(-1).int().tolist()

        strings = []
        for i in range(symbols.size(0)):
            rv = self.entropy_coder.encode_with_indexes(
                symbols[i].reshape(-1).int().tolist(),
                indexes[i].reshape(-1).int().tolist(),
                cdfs,
                cdf_lengths,
                offsets,
            )
            strings.append(rv)
        return strings

    def decompress(
        self,
        strings: str,
        indexes: torch.IntTensor,
        dtype: torch.dtype = torch.float,
        means: Optional[torch.Tensor] = None,
        qs: Optional[Tensor] = None,
    ):
        """
        Decompress char strings to tensors.

        Args:
            strings (list or tuple): compressed tensors, one per element of
                dimension 0 of ``indexes``
            indexes (torch.IntTensor): tensors CDF indexes
            dtype (torch.dtype): type of dequantized output (used only when
                ``means`` is None)
            means (torch.Tensor, optional): optional tensor means
            qs (torch.Tensor, optional): optional quantization step size

        Raises:
            ValueError: on invalid ``strings``/``indexes``/``means`` or when
                the CDF tables have not been built by ``update()``.
        """
        if not isinstance(strings, (tuple, list)):
            raise ValueError("Invalid `strings` parameter type.")

        if not len(strings) == indexes.size(0):
            raise ValueError("Invalid strings or indexes parameters")

        if len(indexes.size()) < 2:
            raise ValueError(
                "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
            )

        self._check_cdf_size()
        self._check_cdf_length()
        self._check_offsets_size()

        if means is not None:
            if means.size()[:2] != indexes.size()[:2]:
                raise ValueError("Invalid means or indexes parameters")
            if means.size() != indexes.size():
                # Trailing dimensions of `means` must be broadcastable (size 1).
                for i in range(2, len(indexes.size())):
                    if means.size(i) != 1:
                        raise ValueError("Invalid means parameters")

        cdf = self._quantized_cdf
        outputs = cdf.new_empty(indexes.size())

        # The CDF tables are shared by every batch element: convert them once.
        cdfs = cdf.tolist()
        cdf_lengths = self._cdf_length.reshape(-1).int().tolist()
        offsets = self._offset.reshape(-1).int().tolist()

        for i, s in enumerate(strings):
            values = self.entropy_coder.decode_with_indexes(
                s,
                indexes[i].reshape(-1).int().tolist(),
                cdfs,
                cdf_lengths,
                offsets,
            )
            outputs[i] = torch.tensor(
                values, device=outputs.device, dtype=outputs.dtype
            ).reshape(outputs[i].size())

        if qs is None:
            outputs = self.dequantize(outputs, means, dtype)
        else:
            outputs = self.dequantize_variable(outputs, means=means, dtype=dtype, qs=qs)
        return outputs
[docs]
class EntropyBottleneckVbr(EntropyModelVbr):
    r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh,
    S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
    hyperprior" <https://arxiv.org/abs/1802.01436>`_.

    This is a re-implementation of the entropy bottleneck layer in
    *tensorflow/compression*. See the original paper and the `tensorflow
    documentation
    <https://github.com/tensorflow/compression/blob/v1.3/docs/entropy_bottleneck.md>`__
    for an introduction.

    This variant also accepts a scalar quantization step ``qs`` in
    :meth:`forward`, :meth:`update_variable`, :meth:`compress` and
    :meth:`decompress`, which scales the quantization bins.

    Args:
        channels (int): number of channels (dimension 1 of the input)
        tail_mass (float): probability mass targeted outside the learned
            lower/upper quantiles (see :meth:`loss`)
        init_scale (float): initial half-width of the quantile range
        filters (tuple of int): hidden widths of the per-channel cumulative
            density network
    """

    _offset: Tensor

    def __init__(
        self,
        channels: int,
        *args: Any,
        tail_mass: float = 1e-9,
        init_scale: float = 10,
        filters: Tuple[int, ...] = (3, 3, 3, 3),
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)

        self.channels = int(channels)
        self.filters = tuple(int(f) for f in filters)
        self.init_scale = float(init_scale)
        self.tail_mass = float(tail_mass)

        # Create parameters
        filters = (1,) + self.filters + (1,)
        scale = self.init_scale ** (1 / (len(self.filters) + 1))
        channels = self.channels

        for i in range(len(self.filters) + 1):
            # log(expm1(y)) inverts softplus, so softplus(matrix) (as used in
            # _logits_cumulative) starts at 1 / scale / filters[i + 1].
            init = np.log(np.expm1(1 / scale / filters[i + 1]))
            matrix = torch.Tensor(channels, filters[i + 1], filters[i])
            matrix.data.fill_(init)
            self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix))

            bias = torch.Tensor(channels, filters[i + 1], 1)
            nn.init.uniform_(bias, -0.5, 0.5)
            self.register_parameter(f"_bias{i:d}", nn.Parameter(bias))

            if i < len(self.filters):
                factor = torch.Tensor(channels, filters[i + 1], 1)
                nn.init.zeros_(factor)
                self.register_parameter(f"_factor{i:d}", nn.Parameter(factor))

        # Per-channel (lower, median, upper) quantiles, shape (C, 1, 3).
        self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
        init = torch.Tensor([-self.init_scale, 0, self.init_scale])
        self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)

        # logit(1 - tail_mass / 2); loss() pulls the cumulative logits at the
        # three quantiles towards (-target, 0, target).
        target = np.log(2 / self.tail_mass - 1)
        self.register_buffer("target", torch.Tensor([-target, 0, target]))

    def _get_medians(self) -> Tensor:
        """Return the median quantiles, shape (channels, 1, 1)."""
        medians = self.quantiles[:, :, 1:2]
        return medians

    def update(self, force: bool = False) -> bool:
        """Build the unit-step CDF tables used by compress()/decompress().

        Args:
            force: rebuild the tables even if they already exist

        Returns:
            True if the tables were (re)built, False if they already existed
            and ``force`` is False.
        """
        # Check if we need to update the bottleneck parameters, the offsets are
        # only computed and stored when the conditional model is update()'d.
        if self._offset.numel() > 0 and not force:
            return False

        medians = self.quantiles[:, 0, 1]

        # Whole-bin distances from the median to the lower/upper quantiles.
        minima = medians - self.quantiles[:, 0, 0]
        minima = torch.ceil(minima).int()
        minima = torch.clamp(minima, min=0)

        maxima = self.quantiles[:, 0, 2] - medians
        maxima = torch.ceil(maxima).int()
        maxima = torch.clamp(maxima, min=0)

        self._offset = -minima

        pmf_start = medians - minima
        pmf_length = maxima + minima + 1

        max_length = pmf_length.max().item()
        device = pmf_start.device
        samples = torch.arange(max_length, device=device)
        # (C, 1, max_length) grid of bin centres per channel.
        samples = samples[None, :] + pmf_start[:, None, None]

        pmf, lower, upper = self._likelihood(samples, stop_gradient=True)
        pmf = pmf[:, 0, :]
        # Mass below the first bin plus mass above the last bin.
        tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])

        quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
        self._quantized_cdf = quantized_cdf
        self._cdf_length = pmf_length + 2
        return True

    def update_variable(self, force: bool = False, qs=1.0) -> bool:
        """Build CDF tables for bins of width ``qs``.

        Same as :meth:`update`, but the bin grid around each median is
        spaced by ``qs``. The tables are built only once unless ``force`` is
        set, so pass ``force=True`` when switching to a different ``qs``.

        Returns:
            True if the tables were (re)built, False otherwise.
        """
        # Check if we need to update the bottleneck parameters, the offsets are
        # only computed and stored when the conditional model is update()'d.
        if self._offset.numel() > 0 and not force:
            return False

        medians = self.quantiles[:, 0, 1]

        # Distances to the lower/upper quantiles, in whole bins of width qs.
        minima = (medians - self.quantiles[:, 0, 0]) / qs
        minima = torch.ceil(minima).int()
        minima = torch.clamp(minima, min=0)

        maxima = (self.quantiles[:, 0, 2] - medians) / qs
        maxima = torch.ceil(maxima).int()
        maxima = torch.clamp(maxima, min=0)

        self._offset = -minima

        pmf_start = medians - minima * qs
        pmf_length = maxima + minima + 1

        max_length = pmf_length.max().item()
        device = pmf_start.device
        samples = torch.arange(max_length, device=device) * qs
        # (C, 1, max_length) grid of bin centres spaced by qs.
        samples = samples[None, :] + pmf_start[:, None, None]

        pmf, lower, upper = self._likelihood_variable(
            samples, stop_gradient=True, qs=qs
        )
        pmf = pmf[:, 0, :]
        # Mass below the first bin plus mass above the last bin.
        tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])

        quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
        self._quantized_cdf = quantized_cdf
        self._cdf_length = pmf_length + 2
        return True

    def loss(self) -> Tensor:
        """L1 distance between the cumulative logits at the quantiles and
        ``self.target``. Only the quantiles receive gradients, because the
        density-model parameters are detached."""
        logits = self._logits_cumulative(self.quantiles, stop_gradient=True)
        loss = torch.abs(logits - self.target).sum()
        return loss

    def _logits_cumulative(self, inputs: Tensor, stop_gradient: bool) -> Tensor:
        """Evaluate the per-channel cumulative density network on ``inputs``.

        Each layer computes ``softplus(matrix) @ x + bias``; every layer but
        the last also adds ``tanh(factor) * tanh(x)``. With
        ``stop_gradient`` the parameters are detached.
        """
        # TorchScript not yet working (nn.Module indexing not supported)
        logits = inputs
        for i in range(len(self.filters) + 1):
            matrix = getattr(self, f"_matrix{i:d}")
            if stop_gradient:
                matrix = matrix.detach()
            logits = torch.matmul(F.softplus(matrix), logits)

            bias = getattr(self, f"_bias{i:d}")
            if stop_gradient:
                bias = bias.detach()
            logits += bias

            if i < len(self.filters):
                factor = getattr(self, f"_factor{i:d}")
                if stop_gradient:
                    factor = factor.detach()
                logits += torch.tanh(factor) * torch.tanh(logits)
        return logits

    @torch.jit.unused
    def _likelihood(
        self, inputs: Tensor, stop_gradient: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Probability of the unit-width bin centred at each input.

        Returns ``(likelihood, lower, upper)``, where ``lower``/``upper`` are
        the cumulative logits at ``inputs -/+ 0.5``.
        """
        half = float(0.5)
        lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
        upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
        likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
        return likelihood, lower, upper

    @torch.jit.unused
    def _likelihood_variable(
        self, inputs: Tensor, stop_gradient: bool = False, qs: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Like :meth:`_likelihood`, but for bins of width ``qs`` (1 if None)."""
        half = float(0.5)
        if qs is None:
            v0 = inputs - half
            v1 = inputs + half
        else:
            v0 = inputs - half * qs
            v1 = inputs + half * qs
        lower = self._logits_cumulative(v0, stop_gradient=stop_gradient)
        upper = self._logits_cumulative(v1, stop_gradient=stop_gradient)
        likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
        return likelihood, lower, upper

    def forward(
        self,
        x: Tensor,
        training: Optional[bool] = None,
        qs: Optional[Tensor] = None,
        ste: Optional[bool] = False,
    ) -> Tuple[Tensor, Tensor]:
        """Quantize ``x`` and compute the likelihood of the result.

        Args:
            x: input tensor, B x C x ...
            training: defaults to ``self.training``; when true, uniform noise
                is added instead of rounding (except in STE mode)
            qs: optional scalar quantization step size
            ste: only used when ``qs`` is given; any value other than
                ``False`` selects straight-through rounding

        Returns:
            ``(outputs, likelihood)``, both with the same shape as ``x``.
        """
        if training is None:
            training = self.training

        if not torch.jit.is_scripting():
            # x from B x C x ... to C x B x ...
            perm = np.arange(len(x.shape))
            perm[0], perm[1] = perm[1], perm[0]
            # Compute inverse permutation
            inv_perm = np.arange(len(x.shape))[np.argsort(perm)]
        else:
            # TorchScript (static 2D inference) is not supported yet.
            raise NotImplementedError()

        x = x.permute(*perm).contiguous()
        shape = x.size()
        values = x.reshape(x.size(0), 1, -1)

        # Add noise or quantize
        medians = self._get_medians()
        if qs is None:
            outputs = self.quantize(
                values, "noise" if training else "dequantize", medians
            )
        elif ste is False:
            outputs = self.quantize_variable(
                values, "noise" if training else "dequantize", medians, qs
            )
        else:
            outputs = self.quantize_variable(values, "ste", medians, qs)

        if not torch.jit.is_scripting():
            if qs is None:
                likelihood, _, _ = self._likelihood(outputs)
            else:
                # `qs` must be passed by keyword: the second positional
                # parameter of _likelihood_variable is `stop_gradient`.
                likelihood, _, _ = self._likelihood_variable(outputs, qs=qs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            # TorchScript not yet supported
            raise NotImplementedError()

        # Convert back to input tensor shape
        outputs = outputs.reshape(shape)
        outputs = outputs.permute(*inv_perm).contiguous()

        likelihood = likelihood.reshape(shape)
        likelihood = likelihood.permute(*inv_perm).contiguous()

        return outputs, likelihood

    @staticmethod
    def _build_indexes(size):
        """Int32 tensor of shape ``size`` whose value is the channel index
        (dimension 1) at every position."""
        dims = len(size)
        N = size[0]
        C = size[1]

        view_dims = np.ones((dims,), dtype=np.int64)
        view_dims[1] = -1
        indexes = torch.arange(C).view(*view_dims)
        indexes = indexes.int()

        return indexes.repeat(N, 1, *size[2:])

    @staticmethod
    def _extend_ndims(tensor, n):
        """Flatten ``tensor`` and append ``n`` singleton dimensions."""
        return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1)

    def compress(self, x, qs=None):
        """Compress ``x`` (B x C x ...) into one string per batch element.

        Args:
            x: tensor to compress; dimension 1 indexes channels
            qs: optional quantization step size. NOTE(review): presumably
                the CDF tables should come from update_variable() with the
                same ``qs``; confirm with callers.
        """
        indexes = self._build_indexes(x.size())
        medians = self._get_medians().detach()
        spatial_dims = len(x.size()) - 2
        medians = self._extend_ndims(medians, spatial_dims)
        medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
        return super().compress(x, indexes, medians, qs)

    def decompress(self, strings, size, qs=None):
        """Decompress strings produced by :meth:`compress`.

        Args:
            strings: one compressed string per batch element
            size: sizes of the dimensions after the channel dimension
            qs: optional quantization step size (same as used in compress())
        """
        output_size = (len(strings), self._quantized_cdf.size(0), *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
        return super().decompress(strings, indexes, medians.dtype, medians, qs)