Source code for compressai.latent_codecs.hyperprior

# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any, Dict, List, Mapping, Optional, Tuple

from torch import Tensor

from compressai.registry import register_module

from .base import LatentCodec
from .gaussian_conditional import GaussianConditionalLatentCodec
from .hyper import HyperLatentCodec

# Names exported by ``from compressai.latent_codecs.hyperprior import *``.
__all__ = ["HyperpriorLatentCodec"]


@register_module("HyperpriorLatentCodec")
class HyperpriorLatentCodec(LatentCodec):
    """Hyperprior codec constructed from latent codec for ``y`` that
    compresses ``y`` using ``params`` from ``hyper`` branch.

    Hyperprior entropy modeling introduced in
    `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_,
    by J. Balle, D. Minnen, S. Singh, S.J. Hwang, and N. Johnston,
    International Conference on Learning Representations (ICLR), 2018.

    .. code-block:: none

                 ┌──────────┐
            ┌─►──┤ lc_hyper ├──►─┐
            │    └──────────┘    │
            │                    ▼ params
            │                    │
            │                 ┌──┴───┐
        y ──┴───────►─────────┤ lc_y ├───►── y_hat
                              └──────┘

    By default, the following codec is constructed:

    .. code-block:: none

                 ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
            ┌─►──┤h_a├──►──┤ Q ├───►───····───►───┤h_s├──►─┐
            │    └───┘     └───┘        EB        └───┘    │
            │                                              │
            │                  ┌──────────────◄────────────┘
            │                  │            params
            │               ┌──┴──┐
            │               │  EP │
            │               └──┬──┘
            │                  │
            │   ┌───┐  y_hat   ▼
        y ──┴─►─┤ Q ├────►────····────►── y_hat
                └───┘          GC

    Common configurations of latent codecs include:

    - entropy bottleneck ``hyper`` (default) and gaussian conditional ``y`` (default)
    - entropy bottleneck ``hyper`` (default) and autoregressive ``y``
    """

    # Sub-codecs keyed by branch name: ``"y"`` and ``"hyper"``.
    latent_codec: Mapping[str, LatentCodec]

    def __init__(
        self, latent_codec: Optional[Mapping[str, LatentCodec]] = None, **kwargs
    ):
        """
        Args:
            latent_codec: Optional mapping of sub-codecs. The ``"y"`` entry
                codes ``y`` conditioned on ``params``; the ``"hyper"`` entry
                produces ``params`` from ``y``. The defaults handed to
                ``_set_group_defaults`` are
                :class:`GaussianConditionalLatentCodec` for ``"y"`` and
                :class:`HyperLatentCodec` for ``"hyper"`` (presumably used
                for any entry not supplied -- see
                ``LatentCodec._set_group_defaults``).
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self._set_group_defaults(
            "latent_codec",
            latent_codec,
            defaults={
                "y": GaussianConditionalLatentCodec,
                "hyper": HyperLatentCodec,
            },
            save_direct=True,
        )

    def __getitem__(self, key: str) -> LatentCodec:
        """Return the sub-codec registered under ``key`` (``"y"`` or ``"hyper"``)."""
        return self.latent_codec[key]

    def forward(self, y: Tensor) -> Dict[str, Any]:
        """Run the hyper branch on ``y``, then code ``y`` conditioned on its ``params``.

        Returns:
            Dict with ``"likelihoods"`` (``"y"`` taken from the ``y`` codec,
            ``"z"`` taken from the hyper codec) and ``"y_hat"`` from the
            ``y`` codec.
        """
        hyper_out = self.latent_codec["hyper"](y)
        y_out = self.latent_codec["y"](y, hyper_out["params"])
        return {
            "likelihoods": {
                "y": y_out["likelihoods"]["y"],
                "z": hyper_out["likelihoods"]["z"],
            },
            "y_hat": y_out["y_hat"],
        }

    def compress(self, y: Tensor) -> Dict[str, Any]:
        """Entropy-code ``y`` together with its hyperlatent.

        Returns:
            Dict with:

            - ``"strings"``: the ``y`` codec's string streams, followed by
              the hyper codec's single string stream as the last element;
            - ``"shape"``: ``{"y": ..., "hyper": ...}``, the shapes reported
              by each sub-codec, needed by :meth:`decompress`;
            - ``"y_hat"``: the ``y`` codec's reconstruction of ``y``.
        """
        hyper_out = self.latent_codec["hyper"].compress(y)
        y_out = self.latent_codec["y"].compress(y, hyper_out["params"])
        # The hyper branch is expected to emit exactly one stream (for z);
        # this unpacking raises if it emits any other number.
        [z_strings] = hyper_out["strings"]
        return {
            "strings": [*y_out["strings"], z_strings],
            "shape": {"y": y_out["shape"], "hyper": hyper_out["shape"]},
            "y_hat": y_out["y_hat"],
        }

    def decompress(
        self, strings: List[List[bytes]], shape: Dict[str, Tuple[int, ...]], **kwargs
    ) -> Dict[str, Any]:
        """Reconstruct ``y_hat`` from the output of :meth:`compress`.

        Args:
            strings: The ``y`` string streams followed by the ``z`` stream
                as the last element (the layout produced by :meth:`compress`).
            shape: Mapping with ``"y"`` and ``"hyper"`` shapes from
                :meth:`compress`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Dict with ``"y_hat"`` decoded by the ``y`` codec.

        Raises:
            ValueError: If any ``y`` stream has a different number of entries
                than the ``z`` stream.
        """
        *y_strings_, z_strings = strings
        # Every y stream must line up entry-for-entry with the z stream
        # (presumably one entry per batch element). An explicit raise keeps
        # this check active under ``python -O``, unlike ``assert``.
        if any(len(y_strings) != len(z_strings) for y_strings in y_strings_):
            raise ValueError(
                "Each y string stream must have the same length as the "
                "z string stream."
            )
        hyper_out = self.latent_codec["hyper"].decompress([z_strings], shape["hyper"])
        y_out = self.latent_codec["y"].decompress(
            y_strings_, shape["y"], hyper_out["params"]
        )
        return {"y_hat": y_out["y_hat"]}