Source code for compressai.entropy_models.entropy_models_vbr

# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import warnings

from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

from compressai.ops import LowerBound

from .entropy_models import (
    _EntropyCoder,
    _forward,
    default_entropy_coder,
    pmf_to_quantized_cdf,
)


class EntropyModelVbr(nn.Module):
    r"""Entropy model base class.

    Args:
        likelihood_bound (float): minimum likelihood bound
        entropy_coder (str, optional): entropy coder to use; falls back to the
            default coder if None
        entropy_coder_precision (int): set the entropy coder precision
    """

    def __init__(
        self,
        likelihood_bound: float = 1e-9,
        entropy_coder: Optional[str] = None,
        entropy_coder_precision: int = 16,
    ):
        super().__init__()

        if entropy_coder is None:
            entropy_coder = default_entropy_coder()
        self.entropy_coder = _EntropyCoder(entropy_coder)
        self.entropy_coder_precision = int(entropy_coder_precision)

        self.use_likelihood_bound = likelihood_bound > 0
        if self.use_likelihood_bound:
            self.likelihood_lower_bound = LowerBound(likelihood_bound)

        # to be filled on update()
        self.register_buffer("_offset", torch.IntTensor())
        self.register_buffer("_quantized_cdf", torch.IntTensor())
        self.register_buffer("_cdf_length", torch.IntTensor())

    def __getstate__(self):
        attributes = self.__dict__.copy()
        attributes["entropy_coder"] = self.entropy_coder.name
        return attributes

    def __setstate__(self, state):
        self.__dict__ = state
        self.entropy_coder = _EntropyCoder(self.__dict__.pop("entropy_coder"))

    @property
    def offset(self):
        return self._offset

    @property
    def quantized_cdf(self):
        return self._quantized_cdf

    @property
    def cdf_length(self):
        return self._cdf_length

    # See: https://github.com/python/mypy/issues/8795
    forward: Callable[..., Any] = _forward

    def quantize(
        self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
    ) -> Tensor:
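        """Quantize ``inputs`` according to ``mode``.

        ``"noise"`` adds uniform noise in ``[-0.5, 0.5)`` as a differentiable
        training proxy, ``"dequantize"`` rounds (around ``means`` when given)
        and maps back to the input domain, and ``"symbols"`` returns integer
        symbols for the entropy coder.

        Illustrative sketch (``model`` stands for any constructed instance)::

            >>> y = torch.tensor([[0.4, 1.6, -2.3]])
            >>> model.quantize(y, "symbols")
            tensor([[ 0,  2, -2]], dtype=torch.int32)
        """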
        if mode not in ("noise", "dequantize", "symbols"):
            raise ValueError(f'Invalid quantization mode: "{mode}"')

        if mode == "noise":
            half = float(0.5)
            noise = torch.empty_like(inputs).uniform_(-half, half)
            inputs = inputs + noise
            return inputs

        outputs = inputs.clone()
        if means is not None:
            outputs -= means

        outputs = torch.round(outputs)

        if mode == "dequantize":
            if means is not None:
                outputs += means
            return outputs

        assert mode == "symbols", mode
        outputs = outputs.int()
        return outputs

    def quantize_variable(  # noqa: C901
        self,
        inputs: Tensor,
        mode: str,
        means: Optional[Tensor] = None,
        qs: Optional[Tensor] = None,
    ) -> Tensor:
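        """Quantize ``inputs`` on a grid with scalar step size ``qs``.

        Behaves like :meth:`quantize` when ``qs`` is ``None``; otherwise the
        rounding grid has spacing ``qs``. The extra ``"ste"`` mode applies a
        straight-through estimator: rounded values in the forward pass,
        identity gradients in the backward pass.

        Illustrative sketch (``model`` stands for any constructed instance)::

            >>> y = torch.tensor([[0.4, 1.6, -2.3]])
            >>> model.quantize_variable(y, "dequantize", qs=torch.tensor(0.5))
            tensor([[ 0.5000,  1.5000, -2.5000]])
        """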
        if mode not in ("noise", "ste", "dequantize", "symbols"):
            raise ValueError(f'Invalid quantization mode: "{mode}"')
        if qs is not None:
            assert qs.shape == torch.Size([]), "qs must be a scalar (0-dim) tensor"

        if mode == "noise":
            half = float(0.5)
            noise = torch.empty_like(inputs).uniform_(-half, half)
            if qs is None:
                inputs = inputs + noise
            else:
                inputs = inputs + noise * qs
            return inputs

        outputs = inputs.clone()
        if means is not None:
            outputs -= means

        if mode == "ste":
            if qs is None:
                outputs_ste = torch.round(outputs) - outputs.detach() + outputs
            else:
                outputs_ste = (
                    torch.round(outputs / qs) * qs - outputs.detach() + outputs
                )
            if means is not None:
                outputs_ste += means
            return outputs_ste

        if mode == "dequantize":
            if qs is None:
                outputs = torch.round(outputs)
            else:
                outputs = torch.round(outputs / qs) * qs
            if means is not None:
                outputs += means
            return outputs

        assert mode == "symbols", mode
        if qs is None:
            outputs = outputs.int()
        else:
            outputs = torch.round(outputs / qs).int()
            # Note: outputs must be multiplied by qs and the mean added back
            # before they are fed to g_s() to reconstruct an image
        return outputs

    def _quantize(
        self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
    ) -> Tensor:
        warnings.warn("_quantize is deprecated. Use quantize instead.", stacklevel=2)
        return self.quantize(inputs, mode, means)

    @staticmethod
    def dequantize(
        inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float
    ) -> Tensor:
        if means is not None:
            outputs = inputs.type_as(means)
            outputs += means
        else:
            outputs = inputs.type(dtype)
        return outputs

    @staticmethod
    def dequantize_variable(
        inputs: Tensor,
        means: Optional[Tensor] = None,
        dtype: torch.dtype = torch.float,
        qs: Optional[Tensor] = None,
    ) -> Tensor:
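        """Map quantized values back to the input domain.

        With a quantization step ``qs``, the reconstruction is
        ``inputs * qs + means``; without ``qs`` this matches
        :meth:`dequantize`.

        Illustrative sketch::

            >>> symbols = torch.tensor([[1, 3, -5]])
            >>> EntropyModelVbr.dequantize_variable(symbols, qs=torch.tensor(0.5))
            tensor([[ 0.5000,  1.5000, -2.5000]])
        """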
        if means is not None:
            outputs = inputs.type_as(means)
            if qs is None:
                outputs += means
            else:
                outputs = outputs * qs + means
        else:
            if qs is None:
                outputs = inputs.type(dtype)
            else:
                outputs = inputs.type(dtype) * qs
        return outputs

    @classmethod
    def _dequantize(cls, inputs: Tensor, means: Optional[Tensor] = None) -> Tensor:
        warnings.warn("_dequantize. Use dequantize instead.", stacklevel=2)
        return cls.dequantize(inputs, means)

    def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length):
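        # Convert each channel's PMF (plus its tail mass) into a quantized CDF
        # row at the coder's precision; rows are zero-padded to a common width
        # of max_length + 2.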
        cdf = torch.zeros(
            (len(pmf_length), max_length + 2), dtype=torch.int32, device=pmf.device
        )
        for i, p in enumerate(pmf):
            prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
            _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
            cdf[i, : _cdf.size(0)] = _cdf
        return cdf

    def _check_cdf_size(self):
        if self._quantized_cdf.numel() == 0:
            raise ValueError("Uninitialized CDFs. Run update() first")

        if len(self._quantized_cdf.size()) != 2:
            raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}")

    def _check_offsets_size(self):
        if self._offset.numel() == 0:
            raise ValueError("Uninitialized offsets. Run update() first")

        if len(self._offset.size()) != 1:
            raise ValueError(f"Invalid offsets size {self._offset.size()}")

    def _check_cdf_length(self):
        if self._cdf_length.numel() == 0:
            raise ValueError("Uninitialized CDF lengths. Run update() first")

        if len(self._cdf_length.size()) != 1:
            raise ValueError(f"Invalid offsets size {self._cdf_length.size()}")

    def compress(self, inputs, indexes, means=None, qs=None):
        """
        Compress input tensors to char strings.

        Args:
            inputs (torch.Tensor): input tensors
            indexes (torch.IntTensor): CDF indexes for each input element
            means (torch.Tensor, optional): optional tensor means
            qs (torch.Tensor, optional): optional quantization step size
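
        Example (illustrative; assumes the CDF tables were already built via
        ``update()`` / ``update_variable()``)::

            >>> strings = model.compress(y, indexes, qs=torch.tensor(0.5))
            >>> len(strings) == y.size(0)
            True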
        """
        if qs is None:
            symbols = self.quantize(inputs, "symbols", means)
        else:
            symbols = self.quantize_variable(inputs, "symbols", means=means, qs=qs)

        if len(inputs.size()) < 2:
            raise ValueError(
                "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
            )

        if inputs.size() != indexes.size():
            raise ValueError("`inputs` and `indexes` should have the same size.")

        self._check_cdf_size()
        self._check_cdf_length()
        self._check_offsets_size()

        strings = []
        for i in range(symbols.size(0)):
            rv = self.entropy_coder.encode_with_indexes(
                symbols[i].reshape(-1).int().tolist(),
                indexes[i].reshape(-1).int().tolist(),
                self._quantized_cdf.tolist(),
                self._cdf_length.reshape(-1).int().tolist(),
                self._offset.reshape(-1).int().tolist(),
            )
            strings.append(rv)
        return strings

    def decompress(
        self,
        strings: list,
        indexes: torch.IntTensor,
        dtype: torch.dtype = torch.float,
        means: Optional[torch.Tensor] = None,
        qs=None,
    ):
        """
        Decompress char strings to tensors.

        Args:
            strings (list): compressed strings, one per batch element
            indexes (torch.IntTensor): CDF indexes for each output element
            dtype (torch.dtype): type of dequantized output
            means (torch.Tensor, optional): optional tensor means
            qs (torch.Tensor, optional): optional quantization step size
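
        Example (illustrative round-trip with :meth:`compress`)::

            >>> strings = model.compress(y, indexes, qs=qs)
            >>> y_hat = model.decompress(strings, indexes, qs=qs)
            >>> y_hat.shape == y.shape
            True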
        """

        if not isinstance(strings, (tuple, list)):
            raise ValueError("Invalid `strings` parameter type.")

        if len(strings) != indexes.size(0):
            raise ValueError("Invalid strings or indexes parameters")

        if len(indexes.size()) < 2:
            raise ValueError(
                "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
            )

        self._check_cdf_size()
        self._check_cdf_length()
        self._check_offsets_size()

        if means is not None:
            if means.size()[:2] != indexes.size()[:2]:
                raise ValueError("Invalid means or indexes parameters")
            if means.size() != indexes.size():
                for i in range(2, len(indexes.size())):
                    if means.size(i) != 1:
                        raise ValueError("Invalid means parameters")

        cdf = self._quantized_cdf
        outputs = cdf.new_empty(indexes.size())

        for i, s in enumerate(strings):
            values = self.entropy_coder.decode_with_indexes(
                s,
                indexes[i].reshape(-1).int().tolist(),
                cdf.tolist(),
                self._cdf_length.reshape(-1).int().tolist(),
                self._offset.reshape(-1).int().tolist(),
            )
            outputs[i] = torch.tensor(
                values, device=outputs.device, dtype=outputs.dtype
            ).reshape(outputs[i].size())
        if qs is None:
            outputs = self.dequantize(outputs, means, dtype)
        else:
            outputs = self.dequantize_variable(outputs, means=means, dtype=dtype, qs=qs)
        return outputs


class EntropyBottleneckVbr(EntropyModelVbr):
    r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh,
    S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
    hyperprior" <https://arxiv.org/abs/1802.01436>`_.

    This is a re-implementation of the entropy bottleneck layer in
    *tensorflow/compression*. See the original paper and the `tensorflow
    documentation
    <https://github.com/tensorflow/compression/blob/v1.3/docs/entropy_bottleneck.md>`__
    for an introduction.
    """

    _offset: Tensor

    def __init__(
        self,
        channels: int,
        *args: Any,
        tail_mass: float = 1e-9,
        init_scale: float = 10,
        filters: Tuple[int, ...] = (3, 3, 3, 3),
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)

        self.channels = int(channels)
        self.filters = tuple(int(f) for f in filters)
        self.init_scale = float(init_scale)
        self.tail_mass = float(tail_mass)

        # Create parameters
        filters = (1,) + self.filters + (1,)
        scale = self.init_scale ** (1 / (len(self.filters) + 1))
        channels = self.channels

        for i in range(len(self.filters) + 1):
            init = np.log(np.expm1(1 / scale / filters[i + 1]))
            matrix = torch.Tensor(channels, filters[i + 1], filters[i])
            matrix.data.fill_(init)
            self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix))

            bias = torch.Tensor(channels, filters[i + 1], 1)
            nn.init.uniform_(bias, -0.5, 0.5)
            self.register_parameter(f"_bias{i:d}", nn.Parameter(bias))

            if i < len(self.filters):
                factor = torch.Tensor(channels, filters[i + 1], 1)
                nn.init.zeros_(factor)
                self.register_parameter(f"_factor{i:d}", nn.Parameter(factor))

        self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
        init = torch.Tensor([-self.init_scale, 0, self.init_scale])
        self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)

        target = np.log(2 / self.tail_mass - 1)
        self.register_buffer("target", torch.Tensor([-target, 0, target]))

    def _get_medians(self) -> Tensor:
        medians = self.quantiles[:, :, 1:2]
        return medians

    def update(self, force: bool = False) -> bool:
        # Check if we need to update the bottleneck parameters; the offsets
        # are only computed and stored when the conditional model is
        # update()'d.
        if self._offset.numel() > 0 and not force:
            return False

        medians = self.quantiles[:, 0, 1]

        minima = medians - self.quantiles[:, 0, 0]
        minima = torch.ceil(minima).int()
        minima = torch.clamp(minima, min=0)

        maxima = self.quantiles[:, 0, 2] - medians
        maxima = torch.ceil(maxima).int()
        maxima = torch.clamp(maxima, min=0)

        self._offset = -minima

        pmf_start = medians - minima
        pmf_length = maxima + minima + 1

        max_length = pmf_length.max().item()
        device = pmf_start.device
        samples = torch.arange(max_length, device=device)
        samples = samples[None, :] + pmf_start[:, None, None]

        pmf, lower, upper = self._likelihood(samples, stop_gradient=True)
        pmf = pmf[:, 0, :]
        tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])

        quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
        self._quantized_cdf = quantized_cdf
        self._cdf_length = pmf_length + 2
        return True

    def update_variable(self, force: bool = False, qs=1.0) -> bool:
        # Same as update(), but the PMF is sampled on a grid with spacing qs
        # instead of 1.
        if self._offset.numel() > 0 and not force:
            return False

        medians = self.quantiles[:, 0, 1]

        minima = (medians - self.quantiles[:, 0, 0]) / qs
        minima = torch.ceil(minima).int()
        minima = torch.clamp(minima, min=0)

        maxima = (self.quantiles[:, 0, 2] - medians) / qs
        maxima = torch.ceil(maxima).int()
        maxima = torch.clamp(maxima, min=0)

        self._offset = -minima

        pmf_start = medians - minima * qs
        pmf_length = maxima + minima + 1

        max_length = pmf_length.max().item()
        device = pmf_start.device
        samples = torch.arange(max_length, device=device) * qs
        samples = samples[None, :] + pmf_start[:, None, None]

        pmf, lower, upper = self._likelihood_variable(
            samples, stop_gradient=True, qs=qs
        )
        pmf = pmf[:, 0, :]
        tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])

        quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
        self._quantized_cdf = quantized_cdf
        self._cdf_length = pmf_length + 2
        return True

    def loss(self) -> Tensor:
        logits = self._logits_cumulative(self.quantiles, stop_gradient=True)
        loss = torch.abs(logits - self.target).sum()
        return loss

    def _logits_cumulative(self, inputs: Tensor, stop_gradient: bool) -> Tensor:
        # TorchScript not yet working (nn.Module indexing not supported)
        logits = inputs
        for i in range(len(self.filters) + 1):
            matrix = getattr(self, f"_matrix{i:d}")
            if stop_gradient:
                matrix = matrix.detach()
            logits = torch.matmul(F.softplus(matrix), logits)

            bias = getattr(self, f"_bias{i:d}")
            if stop_gradient:
                bias = bias.detach()
            logits += bias

            if i < len(self.filters):
                factor = getattr(self, f"_factor{i:d}")
                if stop_gradient:
                    factor = factor.detach()
                logits += torch.tanh(factor) * torch.tanh(logits)
        return logits

    @torch.jit.unused
    def _likelihood(
        self, inputs: Tensor, stop_gradient: bool = False
    ) -> Tuple[Tensor, Tensor, Tensor]:
        half = float(0.5)
        lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
        upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
        likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
        return likelihood, lower, upper

    @torch.jit.unused
    def _likelihood_variable(
        self, inputs: Tensor, stop_gradient: bool = False, qs: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        half = float(0.5)
        if qs is None:
            v0 = inputs - half
            v1 = inputs + half
        else:
            v0 = inputs - half * qs
            v1 = inputs + half * qs
        lower = self._logits_cumulative(v0, stop_gradient=stop_gradient)
        upper = self._logits_cumulative(v1, stop_gradient=stop_gradient)
        likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
        return likelihood, lower, upper

    def forward(
        self,
        x: Tensor,
        training: Optional[bool] = None,
        qs: Optional[Tensor] = None,
        ste: Optional[bool] = False,
    ) -> Tuple[Tensor, Tensor]:
        if training is None:
            training = self.training

        if not torch.jit.is_scripting():
            # x from B x C x ... to C x B x ...
            perm = np.arange(len(x.shape))
            perm[0], perm[1] = perm[1], perm[0]
            # Compute inverse permutation
            inv_perm = np.arange(len(x.shape))[np.argsort(perm)]
        else:
            raise NotImplementedError()
            # TorchScript in 2D for static inference
            # Convert to (channels, ..., batch) format
            # perm = (1, 2, 3, 0)
            # inv_perm = (3, 0, 1, 2)

        x = x.permute(*perm).contiguous()
        shape = x.size()
        values = x.reshape(x.size(0), 1, -1)

        # Add noise or quantize
        if qs is None:
            outputs = self.quantize(
                values, "noise" if training else "dequantize", self._get_medians()
            )
        else:
            if ste is False:
                outputs = self.quantize_variable(
                    values,
                    "noise" if training else "dequantize",
                    self._get_medians(),
                    qs,
                )
            else:
                outputs = self.quantize_variable(values, "ste", self._get_medians(), qs)

        if not torch.jit.is_scripting():
            if qs is None:
                likelihood, _, _ = self._likelihood(outputs)
            else:
                # outputs come either from the STE path (rounded) or from the
                # noise/dequantize path; in both cases the likelihood is
                # evaluated on them with step size qs
                likelihood, _, _ = self._likelihood_variable(outputs, qs=qs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError()
            # TorchScript not yet supported
            # likelihood = torch.zeros_like(outputs)

        # Convert back to input tensor shape
        outputs = outputs.reshape(shape)
        outputs = outputs.permute(*inv_perm).contiguous()

        likelihood = likelihood.reshape(shape)
        likelihood = likelihood.permute(*inv_perm).contiguous()

        return outputs, likelihood

    @staticmethod
    def _build_indexes(size):
        dims = len(size)
        N = size[0]
        C = size[1]

        view_dims = np.ones((dims,), dtype=np.int64)
        view_dims[1] = -1
        indexes = torch.arange(C).view(*view_dims)
        indexes = indexes.int()

        return indexes.repeat(N, 1, *size[2:])

    @staticmethod
    def _extend_ndims(tensor, n):
        return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1)

    def compress(self, x, qs=None):
        indexes = self._build_indexes(x.size())
        medians = self._get_medians().detach()
        spatial_dims = len(x.size()) - 2
        medians = self._extend_ndims(medians, spatial_dims)
        medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
        return super().compress(x, indexes, medians, qs)

    def decompress(self, strings, size, qs=None):
        output_size = (len(strings), self._quantized_cdf.size(0), *size)
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
        return super().decompress(strings, indexes, medians.dtype, medians, qs)
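

# Minimal usage sketch, not part of the original module: exercises the
# variable-bitrate bottleneck end to end with an assumed scalar quantization
# step qs. Requires a compiled entropy coder backend (e.g. the default rANS
# coder shipped with CompressAI).
def _example_vbr_roundtrip() -> None:
    torch.manual_seed(0)
    eb = EntropyBottleneckVbr(channels=8)
    qs = torch.tensor(0.5)

    # Forward pass: quantized (here: rounded) outputs and their likelihoods.
    x = torch.randn(1, 8, 16, 16)
    _, likelihoods = eb(x, training=False, qs=qs)
    print(f"estimated rate: {-torch.log2(likelihoods).sum().item():.1f} bits")

    # Build the CDF tables on the qs-spaced grid, then entropy code.
    eb.update_variable(force=True, qs=qs.item())
    strings = eb.compress(x, qs=qs)
    x_hat = eb.decompress(strings, x.size()[2:], qs=qs)
    assert x_hat.shape == x.shape


if __name__ == "__main__":
    _example_vbr_roundtrip()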