Source code for compressai.models.base

# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import math
import warnings

from typing import cast

import torch
import torch.nn as nn

from torch import Tensor

from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.latent_codecs import LatentCodec
from compressai.models.utils import remap_old_keys, update_registered_buffers

__all__ = [
    "CompressionModel",
    "SimpleVAECompressionModel",
    "get_scale_table",
    "SCALES_MIN",
    "SCALES_MAX",
    "SCALES_LEVELS",
]


# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Returns a table of logarithmically spaced scales."""
    return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
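
# Illustrative check (an added sketch, not part of the original module): the
# returned table has `levels` entries spaced evenly in log-space, so
# consecutive scales differ by the constant ratio
# (max / min) ** (1 / (levels - 1)).
def _scale_table_example():
    table = get_scale_table()
    assert table.shape == (SCALES_LEVELS,)
    ratio = (SCALES_MAX / SCALES_MIN) ** (1 / (SCALES_LEVELS - 1))
    assert torch.allclose(
        table[1:] / table[:-1], torch.full((SCALES_LEVELS - 1,), ratio)
    )
    return table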


class CompressionModel(nn.Module):
    """Base class for constructing an auto-encoder with any number of
    EntropyBottleneck or GaussianConditional modules.
    """

    def __init__(self, entropy_bottleneck_channels=None, init_weights=None):
        super().__init__()

        if entropy_bottleneck_channels is not None:
            warnings.warn(
                "The entropy_bottleneck_channels parameter is deprecated. "
                "Create an entropy_bottleneck in your model directly instead:\n\n"
                "class YourModel(CompressionModel):\n"
                "    def __init__(self):\n"
                "        super().__init__()\n"
                "        self.entropy_bottleneck = "
                "EntropyBottleneck(entropy_bottleneck_channels)\n",
                DeprecationWarning,
                stacklevel=2,
            )
            self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels)

        if init_weights is not None:
            warnings.warn(
                "The init_weights parameter was removed as it was never functional.",
                DeprecationWarning,
                stacklevel=2,
            )
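
    # Subclass sketch (an added illustration, not part of the original
    # module): a minimal complete subclass declares its transforms and
    # entropy model directly, as the deprecation message above recommends.
    # The layer shapes below are hypothetical placeholders:
    #
    #     class ToyModel(CompressionModel):
    #         def __init__(self, N=128):
    #             super().__init__()
    #             self.g_a = nn.Conv2d(3, N, 5, stride=2, padding=2)
    #             self.g_s = nn.ConvTranspose2d(
    #                 N, 3, 5, stride=2, padding=2, output_padding=1
    #             )
    #             self.entropy_bottleneck = EntropyBottleneck(N)
    #
    #         def forward(self, x):
    #             y = self.g_a(x)
    #             y_hat, y_likelihoods = self.entropy_bottleneck(y)
    #             return {
    #                 "x_hat": self.g_s(y_hat),
    #                 "likelihoods": {"y": y_likelihoods},
    #             }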

    def load_state_dict(self, state_dict, strict=True):
        for name, module in self.named_modules():
            if not any(x.startswith(name) for x in state_dict.keys()):
                continue

            if isinstance(module, EntropyBottleneck):
                update_registered_buffers(
                    module,
                    name,
                    ["_quantized_cdf", "_offset", "_cdf_length"],
                    state_dict,
                )
                state_dict = remap_old_keys(name, state_dict)

            if isinstance(module, GaussianConditional):
                update_registered_buffers(
                    module,
                    name,
                    ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
                    state_dict,
                )

        return nn.Module.load_state_dict(self, state_dict, strict=strict)
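
    # Usage note (an added illustration): this override resizes the entropy
    # models' registered CDF buffers before delegating to
    # nn.Module.load_state_dict, so checkpoints whose CDF buffer shapes differ
    # from a freshly constructed model still load cleanly, e.g.:
    #
    #     checkpoint = torch.load("checkpoint.pth.tar", map_location="cpu")
    #     model.load_state_dict(checkpoint["state_dict"])  # key layout assumed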

    def update(self, scale_table=None, force=False, update_quantiles: bool = False):
        """Updates EntropyBottleneck and GaussianConditional CDFs.

        Needs to be called once after training to be able to later perform the
        evaluation with an actual entropy coder.

        Args:
            scale_table (torch.Tensor): table of scales (i.e. stdev)
                for initializing the Gaussian distributions
                (default: 64 logarithmically spaced scales from 0.11 to 256)
            force (bool): overwrite previous values (default: False)
            update_quantiles (bool): fast update quantiles (default: False)

        Returns:
            updated (bool): True if at least one of the modules was updated.
        """
        if scale_table is None:
            scale_table = get_scale_table()
        updated = False
        for _, module in self.named_modules():
            if isinstance(module, EntropyBottleneck):
                updated |= module.update(force=force, update_quantiles=update_quantiles)
            if isinstance(module, GaussianConditional):
                updated |= module.update_scale_table(scale_table, force=force)
        return updated
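
    # Usage note (an added illustration): call update() once after training,
    # or after loading trained weights, so that subclasses' compress() and
    # decompress() see up-to-date CDF tables, e.g.:
    #
    #     model.update(force=True)
    #     out = model.compress(x)  # assumes the subclass defines compress()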

    def aux_loss(self) -> Tensor:
        r"""Returns the total auxiliary loss over all ``EntropyBottleneck``\s.

        In contrast to the primary "net" loss used by the "net"
        optimizer, the "aux" loss is only used by the "aux" optimizer to
        update *only* the ``EntropyBottleneck.quantiles`` parameters. In
        fact, the "aux" loss does not depend on image data at all.

        The purpose of the "aux" loss is to determine the range within
        which most of the mass of a given distribution is contained, as
        well as its median (i.e. 50% probability). That is, for a given
        distribution, the "aux" loss converges towards satisfying the
        following conditions for some chosen ``tail_mass`` probability:

        * ``cdf(quantiles[0]) = tail_mass / 2``
        * ``cdf(quantiles[1]) = 0.5``
        * ``cdf(quantiles[2]) = 1 - tail_mass / 2``

        This ensures that the concrete ``_quantized_cdf``\s operate
        primarily within a finitely supported region. Any symbols outside
        this range must be coded using some alternative method that does
        *not* involve the ``_quantized_cdf``\s. Luckily, one may choose a
        ``tail_mass`` probability that is sufficiently small so that this
        rarely occurs. It is important that we work with
        ``_quantized_cdf``\s that have a small finite support; otherwise,
        entropy coding runtime performance would suffer. Thus,
        ``tail_mass`` should not be too small, either!
        """
        loss = sum(m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck))
        return cast(Tensor, loss)
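

# Training sketch (an added illustration, not part of the original module):
# the "net" optimizer trains the main parameters with a rate-distortion loss,
# while a separate "aux" optimizer minimizes aux_loss() to fit the
# EntropyBottleneck quantiles. Splitting parameters on the ".quantiles" name
# follows the usual CompressAI convention; the model, lambda, and learning
# rates here are hypothetical, and the optimizers would normally be created
# once outside the training loop.
def _train_step_sketch(model: CompressionModel, x: Tensor, lmbda: float = 0.01):
    aux_params = [p for n, p in model.named_parameters() if n.endswith(".quantiles")]
    net_params = [p for n, p in model.named_parameters() if not n.endswith(".quantiles")]
    net_optimizer = torch.optim.Adam(net_params, lr=1e-4)
    aux_optimizer = torch.optim.Adam(aux_params, lr=1e-3)

    out = model(x)  # assumes a subclass forward() returning x_hat + likelihoods

    # Rate: bits per pixel, summed over all likelihood tensors.
    num_pixels = x.numel() / x.shape[1]  # N * H * W
    bpp = sum(
        likelihoods.log2().sum() / -num_pixels
        for likelihoods in out["likelihoods"].values()
    )
    # Distortion: MSE scaled to the usual 255-range convention.
    distortion = nn.functional.mse_loss(out["x_hat"], x)
    net_loss = lmbda * 255**2 * distortion + bpp

    net_optimizer.zero_grad()
    net_loss.backward()
    net_optimizer.step()

    # The aux loss depends only on the quantile parameters, not on x.
    aux_optimizer.zero_grad()
    model.aux_loss().backward()
    aux_optimizer.step()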


class SimpleVAECompressionModel(CompressionModel):
    """Simple VAE model with arbitrary latent codec.

    .. code-block:: none

               ┌───┐  y  ┌────┐ y_hat ┌───┐
        x ──►──┤g_a├──►──┤ lc ├───►───┤g_s├──►── x_hat
               └───┘     └────┘       └───┘
    """

    g_a: nn.Module
    g_s: nn.Module
    latent_codec: LatentCodec

    def __getitem__(self, key: str) -> LatentCodec:
        return self.latent_codec[key]

    def forward(self, x):
        y = self.g_a(x)
        y_out = self.latent_codec(y)
        y_hat = y_out["y_hat"]
        x_hat = self.g_s(y_hat)
        return {
            "x_hat": x_hat,
            "likelihoods": y_out["likelihoods"],
        }

    def compress(self, x):
        y = self.g_a(x)
        outputs = self.latent_codec.compress(y)
        return outputs

    def decompress(self, *args, **kwargs):
        y_out = self.latent_codec.decompress(*args, **kwargs)
        y_hat = y_out["y_hat"]
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {
            "x_hat": x_hat,
        }
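

# Usage sketch (an added illustration, not part of the original module):
# wiring hypothetical transforms into SimpleVAECompressionModel with an
# EntropyBottleneckLatentCodec from compressai.latent_codecs (its constructor
# arguments and the compress() output layout are assumed here). After
# training, update() must be called once so compress()/decompress() can drive
# the entropy coder.
def _simple_vae_sketch():
    from compressai.latent_codecs import EntropyBottleneckLatentCodec

    N = 128
    model = SimpleVAECompressionModel()
    model.g_a = nn.Conv2d(3, N, 5, stride=2, padding=2)
    model.g_s = nn.ConvTranspose2d(N, 3, 5, stride=2, padding=2, output_padding=1)
    model.latent_codec = EntropyBottleneckLatentCodec(
        entropy_bottleneck=EntropyBottleneck(N)  # keyword name assumed
    )

    x = torch.rand(1, 3, 64, 64)
    out = model(x)  # {"x_hat": ..., "likelihoods": {"y": ...}}

    model.update(force=True)
    enc = model.compress(x)  # output dict layout assumed: strings + shape
    dec = model.decompress(enc["strings"], enc["shape"])
    return out, dec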