# Copyright (c) 2021-2023, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import random
import time
from itertools import chain
from typing import Any, List, Optional, TypeVar, cast
import pandas as pd
import torch
import torch.nn.functional as F
from catalyst import metrics
from compressai.models.base import CompressionModel
from compressai.typing import TCriterion
from compressai_trainer.registry import register_runner
from compressai_trainer.utils.metrics import compute_metrics
from compressai_trainer.utils.utils import compute_padding, flatten_values, ld_to_dl
from .base import BaseRunner
from .image_compression import (
RD_PLOT_DESCRIPTIONS,
RD_PLOT_METRICS,
RD_PLOT_SETTINGS_COMMON,
RD_PLOT_TITLE,
)
from .utils import (
ChannelwiseBppMeter,
DebugOutputsLogger,
EbDistributionsFigureLogger,
GradientClipper,
RdFigureLogger,
)
K = TypeVar("K")
V = TypeVar("V")
@register_runner("GVAEImageCompressionRunner")
class GVAEImageCompressionRunner(BaseRunner):
"""Runner for image compression experiments.
Reimplementation of CompressAI's `examples/train.py
<https://github.com/InterDigitalInc/CompressAI/blob/master/examples/train.py>`_,
with additional functionality such as:
- Plots RD curves, learned entropy bottleneck distributions,
and histograms for latent channel-wise rate distributions.
- Saves inference outputs including images and featuremaps.
Set the input arguments by overriding the defaults in
``conf/runner/ImageCompressionRunner.yaml``.
"""
def __init__(
self,
inference: dict[str, Any],
meters: dict[str, list[str]],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._inference_kwargs = inference
self._meter_keys = meters
self._grad_clip = GradientClipper(self)
self._debug_outputs_logger = DebugOutputsLogger(self)
self._eb_distributions_figure_logger = EbDistributionsFigureLogger(self)
self._rd_figure_logger = RdFigureLogger(self)
    def on_loader_start(self, runner):
super().on_loader_start(runner)
self._setup_metrics()
    def handle_batch(self, batch):
if self.loader_key == "infer":
return self._handle_batch_infer(batch)
# Choose random lambda for this batch.
lmbda_idx = random.randint(0, len(self._lmbdas) - 1)
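        # The sampled index is passed to the model below, conditioning it on
        # the corresponding lambda (e.g. via per-rate gain vectors in a
        # GVAE-style model), so one network is trained over all rate points.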
x = batch
out_net = self.model(x, lmbda_idx=lmbda_idx)
out_criterion = self.criterion(out_net, x)
loss = {
"net": out_criterion["loss"],
"aux": self.model_module.aux_loss(),
}
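        # "net" is the rate-distortion objective; "aux" is CompressAI's
        # auxiliary loss, which fits the entropy bottleneck quantiles and is
        # optimized by a separate optimizer.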
if self.loader_key == "train":
loss["net"].backward()
loss["aux"].backward()
self._grad_clip()
self.optimizer["net"].step()
self.optimizer["aux"].step()
self.optimizer["net"].zero_grad()
self.optimizer["aux"].zero_grad()
batch_metrics = {
"loss": loss["net"],
"aux_loss": loss["aux"],
**out_criterion,
"lmbda": self._lmbdas[lmbda_idx],
}
# Append suffixes, i.e. f"{...}_{lmbda_idx}".
batch_metrics = self._flatten_batch_metricses([batch_metrics], [lmbda_idx])
self._update_batch_metrics(batch_metrics)
def _handle_batch_infer(self, batch):
x = batch.to(self.engine.device)
# Run inference for each lambda, then flatten results.
batch_metrics = self._flatten_batch_metricses(
[
self._handle_batch_infer_lmbda(x, lmbda_idx=lmbda_idx)
for lmbda_idx in self._lmbda_idxs
],
self._lmbda_idxs,
)
self._update_batch_metrics(batch_metrics)
# Save per-sample metrics, too.
for metric in self._loader_metrics.keys():
if metric in batch_metrics:
self._loader_metrics[metric].append(batch_metrics[metric])
def _handle_batch_infer_lmbda(self, x, lmbda_idx):
out_infer = self.predict_batch(x, lmbda_idx=lmbda_idx, **self._inference_kwargs)
out_net = out_infer["out_net"]
out_criterion = self.criterion(out_net, x)
out_metrics = compute_metrics(x, out_net["x_hat"], RD_PLOT_METRICS)
out_metrics["bpp"] = out_infer["bpp"]
loss = {
"net": out_criterion["loss"],
"aux": self.model_module.aux_loss(),
}
batch_metrics = {
"loss": loss["net"],
"aux_loss": loss["aux"],
**out_criterion,
**out_metrics,
"bpp": out_infer["bpp"],
"lmbda": self._lmbdas[lmbda_idx],
}
self._debug_outputs_logger.log(x, out_infer, context={"lmbda_idx": lmbda_idx})
self._loader_metrics[f"chan_bpp_{lmbda_idx}"].update(out_net)
return batch_metrics
    def predict_batch(self, batch, lmbda_idx=None, lmbda=None, **kwargs):
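        """Run inference on a single batch.

        Either ``lmbda_idx`` (an index into the model's lambdas) or an
        explicit ``lmbda`` may be given; if both are given, they must agree.
        Remaining keyword arguments (e.g. ``skip_compress``) are forwarded to
        :func:`inference`.
        """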
x = batch.to(self.engine.device)
if lmbda_idx is not None:
assert lmbda is None or lmbda == self._lmbdas[lmbda_idx]
lmbda = self._lmbdas[lmbda_idx]
return inference(
self.model_module, x, criterion=self.criterion, lmbda=lmbda, **kwargs
)
    def on_loader_end(self, runner):
super().on_loader_end(runner)
if self.loader_key == "infer":
self._log_rd_curves()
self._eb_distributions_figure_logger.log(
log_kwargs=dict(track_kwargs=dict(step=0))
)
for i in range(len(self._lmbdas)):
self._loader_metrics[f"chan_bpp_{i}"].log(context={"lmbda_idx": i})
@property
def _current_dataframe(self):
r = lambda x: float(f"{x:.6g}") # noqa: E731
d = {
"name": [self.hparams["model"]["name"] + "*" for _ in self._lmbda_idxs],
"epoch": [self.epoch_step for _ in self._lmbda_idxs],
"criterion.lmbda": self._lmbdas,
"loss": [r(self.loader_metrics[f"loss_{i}"]) for i in self._lmbda_idxs],
"bpp": [r(self.loader_metrics[f"bpp_{i}"]) for i in self._lmbda_idxs],
"psnr": [r(self.loader_metrics[f"psnr_{i}"]) for i in self._lmbda_idxs],
"ms-ssim": [
r(self.loader_metrics[f"ms-ssim_{i}"]) for i in self._lmbda_idxs
],
# dB of the mean of MS-SSIM samples:
# "ms-ssim-db": [
# r(db(1 - self.loader_metrics[f"ms-ssim_{i}"]))
# for i in self._lmbda_idxs
# ],
# Mean of MS-SSIM dB samples:
"ms-ssim-db": [
r(self.loader_metrics[f"ms-ssim-db_{i}"]) for i in self._lmbda_idxs
],
}
return pd.DataFrame.from_dict(d)
def _current_traces(self, metric):
return [
trace
for lmbda_idx, lmbda in enumerate(self._lmbdas)
for trace in self._rd_figure_logger.current_rd_traces(
x=f"bpp_{lmbda_idx}", y=f"{metric}_{lmbda_idx}", lmbda=lmbda
)
]
def _log_rd_curves(self, **kwargs):
return [
self._log_rd_curves_figure(metric, description, **kwargs)
for metric, description in zip(RD_PLOT_METRICS, RD_PLOT_DESCRIPTIONS)
]
def _log_rd_curves_figure(
self, metric, description, df=None, traces=None, **kwargs
):
if df is None:
df = self._current_dataframe
if traces is None:
traces = self._current_traces(metric)
meta = self.hparams["dataset"]["infer"]["meta"]
return self._rd_figure_logger.log(
df=df,
traces=traces,
metric=metric,
dataset=meta["identifier"],
**RD_PLOT_SETTINGS_COMMON,
layout_kwargs=dict(
title=RD_PLOT_TITLE.format(
dataset=meta["name"],
metric=description,
),
),
**kwargs,
)
def _flatten_batch_metricses(self, batch_metricses, lmbda_idxs):
# Flatten metrics for each lambda.
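        # e.g. batch_metricses=[{"bpp": 0.3}, {"bpp": 0.5}] with lmbda_idxs=[0, 1]
        # yields {"bpp_0": 0.3, "bpp_1": 0.5, "bpp": 0.4}.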
singles = {
f"{metric_name}_{lmbda_idx}": value
for lmbda_idx, batch_metrics in zip(lmbda_idxs, batch_metricses)
for metric_name, value in batch_metrics.items()
}
# Average metrics over lambdas.
averages = {
metric_name: sum(values) / len(values)
for metric_name, values in ld_to_dl(batch_metricses).items()
}
return {**singles, **averages}
def _setup_metrics(self):
# Expand any meters containing a * to work with multiple lmbdas.
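        # e.g. with 2 lambdas, "bpp_*" expands to ["bpp_0", "bpp_1"],
        # while a name without "*" (e.g. "loss") is kept as-is.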
meter_keys = [
[meter_name]
if "*" not in meter_name
else [meter_name.replace("*", str(i)) for i in self._lmbda_idxs]
for meter_name in self._meter_keys[self.loader_key]
]
meter_keys = list(chain(*meter_keys)) # Flatten list of lists.
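        # ``batch_meters`` accumulate running averages of per-batch scalars for
        # logging; ``_loader_metrics`` (below) collect per-sample values and
        # channel-wise bpp meters used to build figures at loader end.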
self.batch_meters = {
key: metrics.AdditiveMetric(compute_on_call=False) for key in meter_keys
}
self._loader_metrics = {
**{f"chan_bpp_{i}": ChannelwiseBppMeter(self) for i in self._lmbda_idxs},
**{k: [] for k in ["bpp", *RD_PLOT_METRICS]},
**{
f"{meter_name}_{i}": []
for meter_name in ["bpp", *RD_PLOT_METRICS]
for i in self._lmbda_idxs
},
}
@property
def _lmbdas(self) -> List[float]:
# Alternative:
# return cast(List[float], list(self.hparams["hp"]["lambdas"]))
return cast(List[float], list(self.model_module.lambdas))
@property
def _lmbda_idxs(self) -> List[int]:
return list(range(len(self._lmbdas)))
@torch.no_grad()
def inference(
model: CompressionModel,
x: torch.Tensor,
skip_compress: bool = False,
skip_decompress: bool = False,
criterion: Optional[TCriterion] = None,
min_div: int = 64,
*,
    lmbda: Optional[float] = None,
) -> dict[str, Any]:
"""Run compression model on image batch."""
n, _, h, w = x.shape
pad, unpad = compute_padding(h, w, min_div=min_div)
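    # Pad spatial dims up to a multiple of ``min_div`` (typically required so
    # that the model's strided downsampling/upsampling reproduces the input
    # resolution); ``unpad`` crops the reconstruction back afterwards.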
x_padded = F.pad(x, pad, mode="constant", value=0)
# Compress using forward.
out_net = model(x_padded, lmbda=lmbda)
out_net["x_hat"] = F.pad(out_net["x_hat"], unpad)
out_net["x_hat"] = out_net["x_hat"].clamp_(0, 1)
# Compress using compress/decompress.
if not skip_compress:
start = time.time()
out_enc = model.compress(x_padded, lmbda=lmbda)
enc_time = time.time() - start
else:
out_enc = {}
enc_time = None
if not skip_decompress:
assert not skip_compress
start = time.time()
out_dec = model.decompress(out_enc["strings"], out_enc["shape"], lmbda=lmbda)
dec_time = time.time() - start
out_dec["x_hat"] = F.pad(out_dec["x_hat"], unpad)
else:
out_dec = dict(out_net)
del out_dec["likelihoods"]
dec_time = None
# Compute bpp.
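    # Prefer the actual bitstream size when compression was run; otherwise fall
    # back to the criterion's estimated rate (``bpp_loss``), which requires a
    # criterion to be provided.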
if not skip_compress:
num_bits = sum(len(s) for s in flatten_values(out_enc["strings"], bytes)) * 8.0
num_pixels = n * h * w
bpp = num_bits / num_pixels
else:
out_criterion = criterion(out_net, x, lmbda=lmbda)
bpp = out_criterion["bpp_loss"].item()
return {
"out_net": out_net,
"out_enc": out_enc,
"out_dec": out_dec,
"bpp": bpp,
"encoding_time": enc_time,
"decoding_time": dec_time,
}
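

# Example usage of ``inference`` (illustrative sketch; assumes ``model`` is a
# trained multi-rate CompressionModel on the same device as ``x``, and that
# ``lmbda`` is one of the lambdas the model was trained for):
#
#   out = inference(model.eval(), x, criterion=criterion, lmbda=lmbda)
#   print(out["bpp"], out["encoding_time"], out["decoding_time"])
#   x_hat = out["out_dec"]["x_hat"]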