Source code for compressai.models.google

# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import GDN, MaskedConv2d
from compressai.registry import register_model

from .base import (
    SCALES_LEVELS,
    SCALES_MAX,
    SCALES_MIN,
    CompressionModel,
    get_scale_table,
)
from .utils import conv, deconv

__all__ = [
    "CompressionModel",
    "FactorizedPrior",
    "FactorizedPriorReLU",
    "ScaleHyperprior",
    "MeanScaleHyperprior",
    "JointAutoregressiveHierarchicalPriors",
    "get_scale_table",
    "SCALES_MIN",
    "SCALES_MAX",
    "SCALES_LEVELS",
]


[docs] @register_model("bmshj2018-factorized") class FactorizedPrior(CompressionModel): r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston: `"Variational Image Compression with a Scale Hyperprior" <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations (ICLR), 2018. .. code-block:: none ┌───┐ y x ──►─┤g_a├──►─┐ └───┘ │ ┌─┴─┐ │ Q │ └─┬─┘ y_hat ▼ · EB : · y_hat ▼ ┌───┐ │ x_hat ──◄─┤g_s├────┘ └───┘ EB = Entropy bottleneck Args: N (int): Number of channels M (int): Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) """ def __init__(self, N, M, **kwargs): super().__init__(**kwargs) self.entropy_bottleneck = EntropyBottleneck(M) self.g_a = nn.Sequential( conv(3, N), GDN(N), conv(N, N), GDN(N), conv(N, N), GDN(N), conv(N, M), ) self.g_s = nn.Sequential( deconv(M, N), GDN(N, inverse=True), deconv(N, N), GDN(N, inverse=True), deconv(N, N), GDN(N, inverse=True), deconv(N, 3), ) self.N = N self.M = M @property def downsampling_factor(self) -> int: return 2**4 def forward(self, x): y = self.g_a(x) y_hat, y_likelihoods = self.entropy_bottleneck(y) x_hat = self.g_s(y_hat) return { "x_hat": x_hat, "likelihoods": { "y": y_likelihoods, }, } @classmethod def from_state_dict(cls, state_dict): """Return a new model instance from `state_dict`.""" N = state_dict["g_a.0.weight"].size(0) M = state_dict["g_a.6.weight"].size(0) net = cls(N, M) net.load_state_dict(state_dict) return net def compress(self, x): y = self.g_a(x) y_strings = self.entropy_bottleneck.compress(y) return {"strings": [y_strings], "shape": y.size()[-2:]} def decompress(self, strings, shape): assert isinstance(strings, list) and len(strings) == 1 y_hat = self.entropy_bottleneck.decompress(strings[0], shape) x_hat = self.g_s(y_hat).clamp_(0, 1) return {"x_hat": x_hat}
@register_model("bmshj2018-factorized-relu") class FactorizedPriorReLU(FactorizedPrior): r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston: `"Variational Image Compression with a Scale Hyperprior" <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations (ICLR), 2018. GDN activations are replaced by ReLU. Args: N (int): Number of channels M (int): Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) """ def __init__(self, N, M, **kwargs): super().__init__(N=N, M=M, **kwargs) self.g_a = nn.Sequential( conv(3, N), nn.ReLU(inplace=True), conv(N, N), nn.ReLU(inplace=True), conv(N, N), nn.ReLU(inplace=True), conv(N, M), ) self.g_s = nn.Sequential( deconv(M, N), nn.ReLU(inplace=True), deconv(N, N), nn.ReLU(inplace=True), deconv(N, N), nn.ReLU(inplace=True), deconv(N, 3), )
[docs] @register_model("bmshj2018-hyperprior") class ScaleHyperprior(CompressionModel): r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston: `"Variational Image Compression with a Scale Hyperprior" <https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations (ICLR), 2018. .. code-block:: none ┌───┐ y ┌───┐ z ┌───┐ z_hat z_hat ┌───┐ x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐ └───┘ │ └───┘ └───┘ EB └───┘ │ ▼ │ ┌─┴─┐ │ │ Q │ ▼ └─┬─┘ │ │ │ y_hat ▼ │ │ │ · │ GC : ◄─────────────────────◄────────────────────┘ · scales_hat y_hat ▼ ┌───┐ │ x_hat ──◄─┤g_s├────┘ └───┘ EB = Entropy bottleneck GC = Gaussian conditional Args: N (int): Number of channels M (int): Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) """ def __init__(self, N, M, **kwargs): super().__init__(**kwargs) self.entropy_bottleneck = EntropyBottleneck(N) self.g_a = nn.Sequential( conv(3, N), GDN(N), conv(N, N), GDN(N), conv(N, N), GDN(N), conv(N, M), ) self.g_s = nn.Sequential( deconv(M, N), GDN(N, inverse=True), deconv(N, N), GDN(N, inverse=True), deconv(N, N), GDN(N, inverse=True), deconv(N, 3), ) self.h_a = nn.Sequential( conv(M, N, stride=1, kernel_size=3), nn.ReLU(inplace=True), conv(N, N), nn.ReLU(inplace=True), conv(N, N), ) self.h_s = nn.Sequential( deconv(N, N), nn.ReLU(inplace=True), deconv(N, N), nn.ReLU(inplace=True), conv(N, M, stride=1, kernel_size=3), nn.ReLU(inplace=True), ) self.gaussian_conditional = GaussianConditional(None) self.N = int(N) self.M = int(M) @property def downsampling_factor(self) -> int: return 2 ** (4 + 2) def forward(self, x): y = self.g_a(x) z = self.h_a(torch.abs(y)) z_hat, z_likelihoods = self.entropy_bottleneck(z) scales_hat = self.h_s(z_hat) y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat) x_hat = self.g_s(y_hat) return { "x_hat": x_hat, "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, } @classmethod def from_state_dict(cls, state_dict): """Return a new model instance from `state_dict`.""" N = state_dict["g_a.0.weight"].size(0) M = state_dict["g_a.6.weight"].size(0) net = cls(N, M) net.load_state_dict(state_dict) return net def compress(self, x): y = self.g_a(x) z = self.h_a(torch.abs(y)) z_strings = self.entropy_bottleneck.compress(z) z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) scales_hat = self.h_s(z_hat) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_strings = self.gaussian_conditional.compress(y, indexes) return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} def decompress(self, strings, shape): assert isinstance(strings, list) and len(strings) == 2 z_hat = self.entropy_bottleneck.decompress(strings[1], shape) scales_hat = self.h_s(z_hat) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_hat = self.gaussian_conditional.decompress(strings[0], indexes, z_hat.dtype) x_hat = self.g_s(y_hat).clamp_(0, 1) return {"x_hat": x_hat}
[docs] @register_model("mbt2018-mean") class MeanScaleHyperprior(ScaleHyperprior): r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D. Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_, Adv. in Neural Information Processing Systems 31 (NeurIPS 2018). .. code-block:: none ┌───┐ y ┌───┐ z ┌───┐ z_hat z_hat ┌───┐ x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐ └───┘ │ └───┘ └───┘ EB └───┘ │ ▼ │ ┌─┴─┐ │ │ Q │ ▼ └─┬─┘ │ │ │ y_hat ▼ │ │ │ · │ GC : ◄─────────────────────◄────────────────────┘ · scales_hat │ means_hat y_hat ▼ ┌───┐ │ x_hat ──◄─┤g_s├────┘ └───┘ EB = Entropy bottleneck GC = Gaussian conditional Args: N (int): Number of channels M (int): Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) """ def __init__(self, N, M, **kwargs): super().__init__(N=N, M=M, **kwargs) self.h_a = nn.Sequential( conv(M, N, stride=1, kernel_size=3), nn.LeakyReLU(inplace=True), conv(N, N), nn.LeakyReLU(inplace=True), conv(N, N), ) self.h_s = nn.Sequential( deconv(N, M), nn.LeakyReLU(inplace=True), deconv(M, M * 3 // 2), nn.LeakyReLU(inplace=True), conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), ) def forward(self, x): y = self.g_a(x) z = self.h_a(y) z_hat, z_likelihoods = self.entropy_bottleneck(z) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) x_hat = self.g_s(y_hat) return { "x_hat": x_hat, "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, } def compress(self, x): y = self.g_a(x) z = self.h_a(y) z_strings = self.entropy_bottleneck.compress(z) z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat) return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} def decompress(self, strings, shape): assert isinstance(strings, list) and len(strings) == 2 z_hat = self.entropy_bottleneck.decompress(strings[1], shape) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_hat = self.gaussian_conditional.decompress( strings[0], indexes, means=means_hat ) x_hat = self.g_s(y_hat).clamp_(0, 1) return {"x_hat": x_hat}
[docs] @register_model("mbt2018") class JointAutoregressiveHierarchicalPriors(MeanScaleHyperprior): r"""Joint Autoregressive Hierarchical Priors model from D. Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_, Adv. in Neural Information Processing Systems 31 (NeurIPS 2018). .. code-block:: none ┌───┐ y ┌───┐ z ┌───┐ z_hat z_hat ┌───┐ x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐ └───┘ │ └───┘ └───┘ EB └───┘ │ ▼ │ ┌─┴─┐ │ │ Q │ params ▼ └─┬─┘ │ y_hat ▼ ┌─────┐ │ ├──────────►───────┤ CP ├────────►──────────┤ │ └─────┘ │ ▼ ▼ │ │ · ┌─────┐ │ GC : ◄────────◄───────┤ EP ├────────◄──────────┘ · scales_hat └─────┘ │ means_hat y_hat ▼ ┌───┐ │ x_hat ──◄─┤g_s├────┘ └───┘ EB = Entropy bottleneck GC = Gaussian conditional EP = Entropy parameters network CP = Context prediction (masked convolution) Args: N (int): Number of channels M (int): Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder) """ def __init__(self, N=192, M=192, **kwargs): super().__init__(N=N, M=M, **kwargs) self.g_a = nn.Sequential( conv(3, N, kernel_size=5, stride=2), GDN(N), conv(N, N, kernel_size=5, stride=2), GDN(N), conv(N, N, kernel_size=5, stride=2), GDN(N), conv(N, M, kernel_size=5, stride=2), ) self.g_s = nn.Sequential( deconv(M, N, kernel_size=5, stride=2), GDN(N, inverse=True), deconv(N, N, kernel_size=5, stride=2), GDN(N, inverse=True), deconv(N, N, kernel_size=5, stride=2), GDN(N, inverse=True), deconv(N, 3, kernel_size=5, stride=2), ) self.h_a = nn.Sequential( conv(M, N, stride=1, kernel_size=3), nn.LeakyReLU(inplace=True), conv(N, N, stride=2, kernel_size=5), nn.LeakyReLU(inplace=True), conv(N, N, stride=2, kernel_size=5), ) self.h_s = nn.Sequential( deconv(N, M, stride=2, kernel_size=5), nn.LeakyReLU(inplace=True), deconv(M, M * 3 // 2, stride=2, kernel_size=5), nn.LeakyReLU(inplace=True), conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), ) self.entropy_parameters = nn.Sequential( nn.Conv2d(M * 12 // 3, M * 10 // 3, 1), nn.LeakyReLU(inplace=True), nn.Conv2d(M * 10 // 3, M * 8 // 3, 1), nn.LeakyReLU(inplace=True), nn.Conv2d(M * 8 // 3, M * 6 // 3, 1), ) self.context_prediction = MaskedConv2d( M, 2 * M, kernel_size=5, padding=2, stride=1 ) self.gaussian_conditional = GaussianConditional(None) self.N = int(N) self.M = int(M) @property def downsampling_factor(self) -> int: return 2 ** (4 + 2) def forward(self, x): y = self.g_a(x) z = self.h_a(y) z_hat, z_likelihoods = self.entropy_bottleneck(z) params = self.h_s(z_hat) y_hat = self.gaussian_conditional.quantize( y, "noise" if self.training else "dequantize" ) ctx_params = self.context_prediction(y_hat) gaussian_params = self.entropy_parameters( torch.cat((params, ctx_params), dim=1) ) scales_hat, means_hat = gaussian_params.chunk(2, 1) _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) x_hat = self.g_s(y_hat) return { "x_hat": x_hat, "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, } @classmethod def from_state_dict(cls, state_dict): """Return a new model instance from `state_dict`.""" N = state_dict["g_a.0.weight"].size(0) M = state_dict["g_a.6.weight"].size(0) net = cls(N, M) net.load_state_dict(state_dict) return net def compress(self, x): if next(self.parameters()).device != torch.device("cpu"): warnings.warn( "Inference on GPU is not recommended for the autoregressive " "models (the entropy coder is run sequentially on CPU).", stacklevel=2, ) y = self.g_a(x) z = 
self.h_a(y) z_strings = self.entropy_bottleneck.compress(z) z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) params = self.h_s(z_hat) s = 4 # scaling factor between z and y kernel_size = 5 # context prediction kernel size padding = (kernel_size - 1) // 2 y_height = z_hat.size(2) * s y_width = z_hat.size(3) * s y_hat = F.pad(y, (padding, padding, padding, padding)) y_strings = [] for i in range(y.size(0)): string = self._compress_ar( y_hat[i : i + 1], params[i : i + 1], y_height, y_width, kernel_size, padding, ) y_strings.append(string) return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} def _compress_ar(self, y_hat, params, height, width, kernel_size, padding): cdf = self.gaussian_conditional.quantized_cdf.tolist() cdf_lengths = self.gaussian_conditional.cdf_length.tolist() offsets = self.gaussian_conditional.offset.tolist() encoder = BufferedRansEncoder() symbols_list = [] indexes_list = [] # Warning, this is slow... # TODO: profile the calls to the bindings... masked_weight = self.context_prediction.weight * self.context_prediction.mask for h in range(height): for w in range(width): y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size] ctx_p = F.conv2d( y_crop, masked_weight, bias=self.context_prediction.bias, ) # 1x1 conv for the entropy parameters prediction network, so # we only keep the elements in the "center" p = params[:, :, h : h + 1, w : w + 1] gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1)) gaussian_params = gaussian_params.squeeze(3).squeeze(2) scales_hat, means_hat = gaussian_params.chunk(2, 1) indexes = self.gaussian_conditional.build_indexes(scales_hat) y_crop = y_crop[:, :, padding, padding] y_q = self.gaussian_conditional.quantize(y_crop, "symbols", means_hat) y_hat[:, :, h + padding, w + padding] = y_q + means_hat symbols_list.extend(y_q.squeeze().tolist()) indexes_list.extend(indexes.squeeze().tolist()) encoder.encode_with_indexes( symbols_list, indexes_list, cdf, cdf_lengths, offsets ) string = encoder.flush() return string def decompress(self, strings, shape): assert isinstance(strings, list) and len(strings) == 2 if next(self.parameters()).device != torch.device("cpu"): warnings.warn( "Inference on GPU is not recommended for the autoregressive " "models (the entropy coder is run sequentially on CPU).", stacklevel=2, ) # FIXME: we don't respect the default entropy coder and directly call the # range ANS decoder z_hat = self.entropy_bottleneck.decompress(strings[1], shape) params = self.h_s(z_hat) s = 4 # scaling factor between z and y kernel_size = 5 # context prediction kernel size padding = (kernel_size - 1) // 2 y_height = z_hat.size(2) * s y_width = z_hat.size(3) * s # initialize y_hat to zeros, and pad it so we can directly work with # sub-tensors of size (N, C, kernel size, kernel_size) y_hat = torch.zeros( (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding), device=z_hat.device, ) for i, y_string in enumerate(strings[0]): self._decompress_ar( y_string, y_hat[i : i + 1], params[i : i + 1], y_height, y_width, kernel_size, padding, ) y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding)) x_hat = self.g_s(y_hat).clamp_(0, 1) return {"x_hat": x_hat} def _decompress_ar( self, y_string, y_hat, params, height, width, kernel_size, padding ): cdf = self.gaussian_conditional.quantized_cdf.tolist() cdf_lengths = self.gaussian_conditional.cdf_length.tolist() offsets = self.gaussian_conditional.offset.tolist() decoder = RansDecoder() decoder.set_stream(y_string) # Warning: this is 
slow due to the auto-regressive nature of the # decoding... See more recent publication where they use an # auto-regressive module on chunks of channels for faster decoding... for h in range(height): for w in range(width): # only perform the 5x5 convolution on a cropped tensor # centered in (h, w) y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size] ctx_p = F.conv2d( y_crop, self.context_prediction.weight, bias=self.context_prediction.bias, ) # 1x1 conv for the entropy parameters prediction network, so # we only keep the elements in the "center" p = params[:, :, h : h + 1, w : w + 1] gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1)) scales_hat, means_hat = gaussian_params.chunk(2, 1) indexes = self.gaussian_conditional.build_indexes(scales_hat) rv = decoder.decode_stream( indexes.squeeze().tolist(), cdf, cdf_lengths, offsets ) rv = torch.Tensor(rv).reshape(1, -1, 1, 1) rv = self.gaussian_conditional.dequantize(rv, means_hat) hp = h + padding wp = w + padding y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv
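# A minimal usage sketch (added for illustration, not part of the original
# module): as the warnings above note, the autoregressive entropy coder runs
# sequentially on CPU, so the model is kept on CPU for compress/decompress.
# `update()` is assumed to come from the CompressionModel base class.
def _example_joint_autoregressive():  # pragma: no cover - illustrative only
    x = torch.rand(1, 3, 256, 256)
    net = JointAutoregressiveHierarchicalPriors(N=192, M=192).eval().to("cpu")
    net.update()  # assumed: populates the entropy coder CDFs (see .base)
    enc = net.compress(x)  # slow: one rANS symbol batch per (h, w) position
    dec = net.decompress(enc["strings"], enc["shape"])
    assert dec["x_hat"].shape == x.shape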