Source code for compressai_trainer.runners.image_compression

# Copyright (c) 2021-2023, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import annotations

import time
from typing import Any, Optional

import pandas as pd
import torch
import torch.nn.functional as F
from catalyst import metrics
from compressai.models.base import CompressionModel
from compressai.typing import TCriterion

from compressai_trainer.registry import register_runner
from compressai_trainer.utils.metrics import compute_metrics
from compressai_trainer.utils.utils import compute_padding, flatten_values

from .base import BaseRunner
from .utils import (
    ChannelwiseBppMeter,
    DebugOutputsLogger,
    EbDistributionsFigureLogger,
    GradientClipper,
    RdFigureLogger,
)

# Title template for RD figures; filled in via ``str.format`` with the
# ``dataset`` and ``metric`` fields.
RD_PLOT_TITLE = "Performance evaluation on {dataset} - {metric}"

# Metric keys that get one RD figure each.
RD_PLOT_METRICS = ["psnr", "ms-ssim", "ms-ssim-db"]

# Human-readable metric names for the figure titles. Paired by position
# with RD_PLOT_METRICS, so both lists must keep the same length and order.
RD_PLOT_DESCRIPTIONS = ["PSNR (RGB)", "MS-SSIM (RGB)", "MS-SSIM (RGB)"]

# Keyword arguments forwarded unchanged to every RD figure logging call.
RD_PLOT_SETTINGS_COMMON: dict[str, Any] = {
    # Reference codec result files drawn alongside the current model.
    "codecs": [
        "image/kodak/compressai-bmshj2018-factorized_mse_cuda.json",
        "image/kodak/compressai-bmshj2018-hyperprior_mse_cuda.json",
        "image/kodak/compressai-mbt2018-mean_mse_cuda.json",
        "image/kodak/compressai-mbt2018_mse_cuda.json",
        "image/kodak/compressai-cheng2020-anchor_mse_cuda.json",
        "image/kodak/vtm.json",
    ],
    # Columns listed as hover data for the scatter points.
    "scatter_kwargs": {
        "hover_data": [
            "name",
            "bpp",
            "psnr",
            "ms-ssim",
            "loss",
            "epoch",
            "criterion.lmbda",
        ],
    },
}


@register_runner("ImageCompressionRunner")
class ImageCompressionRunner(BaseRunner):
    """Runner for image compression experiments.

    Reimplementation of CompressAI's `examples/train.py
    <https://github.com/InterDigitalInc/CompressAI/blob/master/examples/train.py>`_,
    with additional functionality such as:

    - Plots RD curves, learned entropy bottleneck distributions,
      and histograms for latent channel-wise rate distributions.
    - Saves inference outputs including images and featuremaps.

    Set the input arguments by overriding the defaults in
    ``conf/runner/ImageCompressionRunner.yaml``.
    """

    def __init__(
        self,
        inference: dict[str, Any],
        meters: dict[str, list[str]],
        *args,
        **kwargs,
    ):
        """Store the runner configuration and create the logging helpers.

        Args:
            inference: Keyword arguments forwarded to ``predict_batch``
                for every batch of the ``"infer"`` loader.
            meters: Maps a loader key to the metric names for which a
                batch meter is created when that loader starts.
            *args: Forwarded to the parent runner.
            **kwargs: Forwarded to the parent runner.
        """
        super().__init__(*args, **kwargs)
        # NOTE: the ``inference`` parameter shadows the module-level
        # ``inference`` function, but only inside this method.
        self._inference_kwargs = inference
        self._meter_keys = meters
        self._grad_clip = GradientClipper(self)
        self._debug_outputs_logger = DebugOutputsLogger(self)
        self._eb_distributions_figure_logger = EbDistributionsFigureLogger(self)
        self._rd_figure_logger = RdFigureLogger(self)

    def on_loader_start(self, runner):
        """Run the parent hook, then create fresh meters for this loader."""
        super().on_loader_start(runner)
        self._setup_metrics()

    def handle_batch(self, batch):
        """Process one batch; the ``"infer"`` loader takes a separate path.

        For every other loader the model and criterion are evaluated, and
        only for the ``"train"`` loader are the gradients back-propagated
        and both optimizers stepped.
        """
        if self.loader_key == "infer":
            return self._handle_batch_infer(batch)

        out_net = self.model(batch)
        out_criterion = self.criterion(out_net, batch)
        net_loss = out_criterion["loss"]
        aux_loss = self.model_module.aux_loss()

        if self.loader_key == "train":
            net_loss.backward()
            aux_loss.backward()
            self._grad_clip()
            optimizers = (self.optimizer["net"], self.optimizer["aux"])
            # Step both optimizers first, then clear both sets of gradients.
            for optimizer in optimizers:
                optimizer.step()
            for optimizer in optimizers:
                optimizer.zero_grad()

        self._update_batch_metrics(
            {
                "loss": net_loss,
                "aux_loss": aux_loss,
                **out_criterion,
            }
        )

    def _handle_batch_infer(self, batch):
        """Run inference on one batch and record its metrics and outputs."""
        x = batch.to(self.engine.device)
        out_infer = self.predict_batch(x, **self._inference_kwargs)
        out_net = out_infer["out_net"]
        x_hat = out_infer["out_dec"]["x_hat"]
        bpp = out_infer["bpp"]

        out_criterion = self.criterion(out_net, x)
        out_metrics = compute_metrics(x, x_hat, RD_PLOT_METRICS)
        out_metrics["bpp"] = bpp
        aux_loss = self.model_module.aux_loss()

        # Later entries win on key clashes, so the "bpp" reported by
        # ``predict_batch`` is the value that ends up being recorded.
        self._update_batch_metrics(
            {
                "loss": out_criterion["loss"],
                "aux_loss": aux_loss,
                **out_criterion,
                **out_metrics,
                "bpp": bpp,
            }
        )
        self._handle_custom_metrics(out_net, out_metrics)
        self._debug_outputs_logger.log(x, out_infer)

    def predict_batch(self, batch, **kwargs):
        """Move ``batch`` to the engine device and run ``inference`` on it.

        Extra keyword arguments are forwarded to the module-level
        ``inference`` function, which also receives the runner's criterion.
        """
        on_device = batch.to(self.engine.device)
        return inference(
            self.model_module,
            on_device,
            criterion=self.criterion,
            **kwargs,
        )

    def on_loader_end(self, runner):
        """Run the parent hook; after the ``"infer"`` loader, log figures."""
        super().on_loader_end(runner)
        if self.loader_key != "infer":
            return
        self._log_rd_curves()
        self._eb_distributions_figure_logger.log(
            log_kwargs={"track_kwargs": {"step": 0}}
        )
        self._loader_metrics["chan_bpp"].log()

    @property
    def _current_dataframe(self):
        """Single-row DataFrame describing the current loader results."""

        def to_6_significant(value):
            return float(f"{value:.6g}")

        hparams = self.hparams
        # NOTE: "ms-ssim-db" below is the mean of the per-sample MS-SSIM dB
        # values; the alternative (dB of the mean of the MS-SSIM samples)
        # would be db(1 - self.loader_metrics["ms-ssim"]).
        scalar_keys = ("loss", "bpp", "psnr", "ms-ssim", "ms-ssim-db")
        record = {
            "name": hparams["model"]["name"] + "*",
            "epoch": self.epoch_step,
            "criterion.lmbda": hparams["criterion"]["lmbda"],
            **{
                key: to_6_significant(self.loader_metrics[key])
                for key in scalar_keys
            },
        }
        return pd.DataFrame.from_records([record])

    def _current_traces(self, metric):
        """Return the RD traces of ``metric`` against bpp for the current lambda."""
        lmbda = self.hparams["criterion"]["lmbda"]
        return self._rd_figure_logger.current_rd_traces(
            x="bpp", y=metric, lmbda=lmbda
        )

    def _log_rd_curves(self, **kwargs):
        """Log one RD figure per entry of ``RD_PLOT_METRICS``.

        Returns:
            A list with the result of each figure logging call.
        """
        logged = []
        for metric, description in zip(RD_PLOT_METRICS, RD_PLOT_DESCRIPTIONS):
            logged.append(
                self._log_rd_curves_figure(metric, description, **kwargs)
            )
        return logged

    def _log_rd_curves_figure(
        self, metric, description, df=None, traces=None, **kwargs
    ):
        """Log the RD figure for a single metric.

        ``df`` and ``traces`` default to the current results when ``None``.
        """
        frame = self._current_dataframe if df is None else df
        rd_traces = self._current_traces(metric) if traces is None else traces
        meta = self.hparams["dataset"]["infer"]["meta"]
        title = RD_PLOT_TITLE.format(dataset=meta["name"], metric=description)
        return self._rd_figure_logger.log(
            df=frame,
            traces=rd_traces,
            metric=metric,
            dataset=meta["identifier"],
            **RD_PLOT_SETTINGS_COMMON,
            layout_kwargs={"title": title},
            **kwargs,
        )

    def _handle_custom_metrics(self, out_net, out_metrics):
        """Feed the channel-wise bpp meter and collect per-batch metrics."""
        store = self._loader_metrics
        store["chan_bpp"].update(out_net)
        for key in ("bpp", *RD_PLOT_METRICS):
            store[key].append(out_metrics[key])

    def _setup_metrics(self):
        """Create fresh batch meters and loader-level stores."""
        batch_meters = {}
        for key in self._meter_keys[self.loader_key]:
            batch_meters[key] = metrics.AdditiveMetric(compute_on_call=False)
        self.batch_meters = batch_meters

        loader_metrics = {"chan_bpp": ChannelwiseBppMeter(self)}
        for key in ("bpp", *RD_PLOT_METRICS):
            loader_metrics[key] = []
        self._loader_metrics = loader_metrics
@torch.no_grad()
def inference(
    model: CompressionModel,
    x: torch.Tensor,
    skip_compress: bool = False,
    skip_decompress: bool = False,
    criterion: Optional[TCriterion] = None,
    min_div: int = 64,
) -> dict[str, Any]:
    """Run compression model on image batch.

    The batch is padded with zeros using the padding returned by
    ``compute_padding(h, w, min_div=min_div)``, run through the model's
    forward pass, and optionally through ``compress``/``decompress``.
    Reconstructions are un-padded before being returned.

    Args:
        model: Compression model to evaluate.
        x: Input batch; must be 4-D, unpacked as ``(n, _, h, w)``.
        skip_compress: If true, do not call ``model.compress``; the bpp is
            then taken from ``criterion``. Requires ``skip_decompress``.
        skip_decompress: If true, do not call ``model.decompress``; the
            forward-pass output (minus ``"likelihoods"``) is used as the
            decoder output instead.
        criterion: Loss callable. Only used, and then required, when
            ``skip_compress`` is true; its ``"bpp_loss"`` output is the bpp.
        min_div: Forwarded to ``compute_padding``.

    Returns:
        A dict with the keys ``"out_net"``, ``"out_enc"``, ``"out_dec"``,
        ``"bpp"``, ``"encoding_time"`` and ``"decoding_time"``. The times
        are in seconds, or ``None`` for a skipped stage.

    Raises:
        ValueError: If ``skip_compress`` is set without ``skip_decompress``,
            or if ``skip_compress`` is set and ``criterion`` is ``None``.
    """
    # Validate the arguments before doing any work. An ``assert`` would be
    # stripped under ``python -O`` and would only fire after the forward
    # pass and compression had already been paid for.
    if skip_compress and not skip_decompress:
        raise ValueError(
            "skip_compress=True requires skip_decompress=True: "
            "there is no bitstream to decompress."
        )
    if skip_compress and criterion is None:
        raise ValueError(
            "criterion must be provided when skip_compress=True: "
            "it is used to estimate the bpp."
        )

    n, _, h, w = x.shape
    pad, unpad = compute_padding(h, w, min_div=min_div)
    x_padded = F.pad(x, pad, mode="constant", value=0)

    # Compress using forward.
    out_net = model(x_padded)
    out_net["x_hat"] = F.pad(out_net["x_hat"], unpad)
    out_net["x_hat"] = out_net["x_hat"].clamp_(0, 1)

    # Compress using compress/decompress.
    # Intervals are measured with perf_counter: unlike time.time() it is
    # monotonic, so system clock adjustments cannot distort the result.
    if not skip_compress:
        start = time.perf_counter()
        out_enc = model.compress(x_padded)
        enc_time = time.perf_counter() - start
    else:
        out_enc = {}
        enc_time = None

    if not skip_decompress:
        start = time.perf_counter()
        out_dec = model.decompress(**out_enc)
        dec_time = time.perf_counter() - start
        out_dec["x_hat"] = F.pad(out_dec["x_hat"], unpad)
    else:
        # Reuse the forward-pass output as the decoder output.
        out_dec = dict(out_net)
        del out_dec["likelihoods"]
        dec_time = None

    # Compute bpp.
    if not skip_compress:
        # Actual bitstream size, relative to the un-padded pixel count.
        num_bits = sum(len(s) for s in flatten_values(out_enc["strings"], bytes)) * 8.0
        num_pixels = n * h * w
        bpp = num_bits / num_pixels
    else:
        out_criterion = criterion(out_net, x)
        bpp = out_criterion["bpp_loss"].item()

    return {
        "out_net": out_net,
        "out_enc": out_enc,
        "out_dec": out_dec,
        "bpp": bpp,
        "encoding_time": enc_time,
        "decoding_time": dec_time,
    }