# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math
import warnings
from typing import cast
import torch
import torch.nn as nn
from torch import Tensor
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.latent_codecs import LatentCodec
from compressai.models.utils import remap_old_keys, update_registered_buffers
__all__ = [
"CompressionModel",
"SimpleVAECompressionModel",
"get_scale_table",
"SCALES_MIN",
"SCALES_MAX",
"SCALES_LEVELS",
]
# From Ballé's TensorFlow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64
def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
"""Returns table of logarithmically scales."""
return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
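
# Usage sketch (illustrative): the default table has 64 logarithmically
# spaced entries in [SCALES_MIN, SCALES_MAX]. A custom table may be passed
# to ``CompressionModel.update`` below, e.g. for a trained ``model``:
#
#     model.update(scale_table=get_scale_table(levels=32), force=True)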
class CompressionModel(nn.Module):
"""Base class for constructing an auto-encoder with any number of
EntropyBottleneck or GaussianConditional modules.
"""
def __init__(self, entropy_bottleneck_channels=None, init_weights=None):
super().__init__()
if entropy_bottleneck_channels is not None:
warnings.warn(
"The entropy_bottleneck_channels parameter is deprecated. "
"Create an entropy_bottleneck in your model directly instead:\n\n"
"class YourModel(CompressionModel):\n"
" def __init__(self):\n"
" super().__init__()\n"
" self.entropy_bottleneck = "
"EntropyBottleneck(entropy_bottleneck_channels)\n",
DeprecationWarning,
stacklevel=2,
)
self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels)
if init_weights is not None:
warnings.warn(
"The init_weights parameter was removed as it was never functional.",
DeprecationWarning,
stacklevel=2,
)
def load_state_dict(self, state_dict, strict=True):
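        """Loads a state dict, resizing entropy model buffers first.

        The registered CDF buffers (``_quantized_cdf``, ``_offset``,
        ``_cdf_length``) have data-dependent shapes, so they are resized to
        match ``state_dict`` before delegating to
        ``nn.Module.load_state_dict``.
        """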
for name, module in self.named_modules():
if not any(x.startswith(name) for x in state_dict.keys()):
continue
if isinstance(module, EntropyBottleneck):
update_registered_buffers(
module,
name,
["_quantized_cdf", "_offset", "_cdf_length"],
state_dict,
)
state_dict = remap_old_keys(name, state_dict)
if isinstance(module, GaussianConditional):
update_registered_buffers(
module,
name,
["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
state_dict,
)
return nn.Module.load_state_dict(self, state_dict, strict=strict)
def update(self, scale_table=None, force=False, update_quantiles: bool = False):
"""Updates EntropyBottleneck and GaussianConditional CDFs.
Needs to be called once after training to be able to later perform the
evaluation with an actual entropy coder.
Args:
scale_table (torch.Tensor): table of scales (i.e. stdev)
for initializing the Gaussian distributions
(default: 64 logarithmically spaced scales from 0.11 to 256)
force (bool): overwrite previous values (default: False)
update_quantiles (bool): fast update quantiles (default: False)
Returns:
updated (bool): True if at least one of the modules was updated.
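
        Example (a minimal sketch; ``model`` is assumed to be a trained
        subclass of this class)::

            model.update(force=True)  # rebuild CDFs with the default table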
"""
if scale_table is None:
scale_table = get_scale_table()
updated = False
for _, module in self.named_modules():
if isinstance(module, EntropyBottleneck):
updated |= module.update(force=force, update_quantiles=update_quantiles)
if isinstance(module, GaussianConditional):
updated |= module.update_scale_table(scale_table, force=force)
return updated
def aux_loss(self) -> Tensor:
r"""Returns the total auxiliary loss over all ``EntropyBottleneck``\s.
In contrast to the primary "net" loss used by the "net"
optimizer, the "aux" loss is only used by the "aux" optimizer to
update *only* the ``EntropyBottleneck.quantiles`` parameters. In
fact, the "aux" loss does not depend on image data at all.
The purpose of the "aux" loss is to determine the range within
which most of the mass of a given distribution is contained, as
well as its median (i.e. 50% probability). That is, for a given
distribution, the "aux" loss converges towards satisfying the
following conditions for some chosen ``tail_mass`` probability:
* ``cdf(quantiles[0]) = tail_mass / 2``
* ``cdf(quantiles[1]) = 0.5``
* ``cdf(quantiles[2]) = 1 - tail_mass / 2``
This ensures that the concrete ``_quantized_cdf``\s operate
primarily within a finitely supported region. Any symbols
outside this range must be coded using some alternative method
that does *not* involve the ``_quantized_cdf``\s. Luckily, one
may choose a ``tail_mass`` probability that is sufficiently
small so that this rarely occurs. It is important that we work
with ``_quantized_cdf``\s that have a small finite support;
otherwise, entropy coding runtime performance would suffer.
Thus, ``tail_mass`` should not be too small, either!
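
        Example (a sketch of the usual separate "aux" optimizer; selecting
        parameters by the ``.quantiles`` name suffix follows the convention
        of this codebase's training examples)::

            aux_params = (
                p
                for n, p in model.named_parameters()
                if n.endswith(".quantiles")
            )
            aux_optimizer = torch.optim.Adam(aux_params, lr=1e-3)

            aux_loss = model.aux_loss()
            aux_loss.backward()
            aux_optimizer.step()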
"""
loss = sum(m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck))
return cast(Tensor, loss)
class SimpleVAECompressionModel(CompressionModel):
"""Simple VAE model with arbitrary latent codec.
.. code-block:: none
┌───┐ y ┌────┐ y_hat ┌───┐
x ──►──┤g_a├──►──┤ lc ├───►───┤g_s├──►── x_hat
└───┘ └────┘ └───┘
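
    A minimal subclass sketch (the single-layer transforms and the
    ``EntropyBottleneckLatentCodec`` configuration are illustrative
    assumptions, not a recommended architecture):

    .. code-block:: python

        from compressai.latent_codecs import EntropyBottleneckLatentCodec

        class ExampleModel(SimpleVAECompressionModel):
            def __init__(self, N=128):
                super().__init__()
                self.g_a = nn.Conv2d(3, N, 5, stride=2, padding=2)
                self.g_s = nn.ConvTranspose2d(
                    N, 3, 5, stride=2, padding=2, output_padding=1
                )
                self.latent_codec = EntropyBottleneckLatentCodec(channels=N)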
"""
g_a: nn.Module
g_s: nn.Module
latent_codec: LatentCodec
def __getitem__(self, key: str) -> LatentCodec:
return self.latent_codec[key]
def forward(self, x):
y = self.g_a(x)
y_out = self.latent_codec(y)
y_hat = y_out["y_hat"]
x_hat = self.g_s(y_hat)
return {
"x_hat": x_hat,
"likelihoods": y_out["likelihoods"],
}
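
    # Rate sketch (illustrative): the returned likelihoods give the rate
    # term of the loss, e.g. bits-per-pixel for an input x of shape
    # (B, C, H, W):
    #
    #     out = model(x)
    #     num_pixels = x.size(0) * x.size(2) * x.size(3)
    #     bpp = sum(
    #         -l.log2().sum() / num_pixels
    #         for l in out["likelihoods"].values()
    #     )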
def compress(self, x):
y = self.g_a(x)
outputs = self.latent_codec.compress(y)
return outputs
def decompress(self, *args, **kwargs):
y_out = self.latent_codec.decompress(*args, **kwargs)
y_hat = y_out["y_hat"]
x_hat = self.g_s(y_hat).clamp_(0, 1)
return {
"x_hat": x_hat,
}
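
# Round-trip sketch (illustrative; assumes a concrete subclass ``model`` with
# g_a, g_s, and latent_codec defined, and a latent codec that follows the
# common ``{"strings", "shape"}`` compress-output convention):
#
#     model.update(force=True)
#     enc = model.compress(x)
#     dec = model.decompress(enc["strings"], enc["shape"])
#     x_hat = dec["x_hat"]  # reconstruction is clamped to [0, 1]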