compressai_trainer.runners

Catalyst Runners describe the process for training a model.

The following functions are called during the training loop:

Runner training loop call order.
on_experiment_start   # Once, at the beginning.
  on_epoch_start      # Beginning of an epoch.
    on_loader_start   # For each loader (train / valid / infer).
      on_batch_start  # Before each batch.
        handle_batch  # For each image batch.
      on_batch_end
    on_loader_end
  on_epoch_end
on_experiment_end

The training loop is effectively equivalent to:

Runner training loop pseudo-code.
on_experiment_start()

for epoch in range(1, num_epochs + 1):  # epochs are counted from 1
    on_epoch_start()

    for loader_key, loader in loaders.items():  # e.g. "train", "valid", "infer"
        on_loader_start()

        for batch in loader:
            on_batch_start()
            handle_batch(batch)
            on_batch_end()

        on_loader_end()

    on_epoch_end()

on_experiment_end()
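The loop above can be traced without Catalyst. The sketch below is a hypothetical helper (`trace_training_loop` is not part of Catalyst or CompressAI Trainer) that mirrors the simplified pseudo-code and records hook names in call order; real runners pass a `runner` argument to each handler and obtain their loaders from Catalyst.

```python
def trace_training_loop(loaders: dict[str, list], num_epochs: int) -> list[str]:
    """Return hook names in the order the simplified loop above calls them."""
    calls = ["on_experiment_start"]
    for epoch in range(1, num_epochs + 1):  # epochs are counted from 1
        calls.append("on_epoch_start")
        for loader_key, loader in loaders.items():  # e.g. "train", "valid", "infer"
            calls.append(f"on_loader_start:{loader_key}")
            for _batch in loader:
                calls += ["on_batch_start", "handle_batch", "on_batch_end"]
            calls.append(f"on_loader_end:{loader_key}")
        calls.append("on_epoch_end")
    calls.append("on_experiment_end")
    return calls


calls = trace_training_loop({"train": [0, 1], "valid": [2]}, num_epochs=2)
print(calls.count("handle_batch"))  # 6: two epochs x three batches
print(calls[:4])
```

With two epochs and three batches in total per epoch, `handle_batch` runs six times, always bracketed by `on_batch_start` and `on_batch_end`.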

Please see the Catalyst documentation for more information.

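We provide the following pre-made runners, both built on BaseRunner and documented below: ImageCompressionRunner and GVAEImageCompressionRunner.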

For guidance on defining your own runner, see: Defining a custom Runner training loop.

class compressai_trainer.runners.BaseRunner(*args, **kwargs)

Bases: catalyst.runners.runner.Runner, compressai_trainer.utils.catalyst.loggers.logger.AllSuperlogger

Generic runner for all CompressAI Trainer experiments.

See the catalyst.dl.Runner documentation for info on runners.

BaseRunner provides functionality for common tasks such as:

  • Logging environment: git hashes/diff, pip list, YAML config.

  • Logging basic model info: number of parameters, weight shapes, etc.

  • Batch meters that aggregate (e.g. average) per-loader metrics (e.g. loss) which are collected per-batch.

  • Calls model.update() before inference (i.e. test).
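The batch-meter bullet can be illustrated with a sample-weighted running average. This is a minimal sketch: `AverageMeter` and the metric names are hypothetical, and the concrete `metrics.IMetric` implementations used by BaseRunner may aggregate differently. The assumption here is that each meter is reset when a loader starts, updated once per batch, and read when the loader ends.

```python
class AverageMeter:
    """Hypothetical sample-weighted running average of a per-batch metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Called at the start of each loader so metrics do not leak across loaders.
        self.total = 0.0
        self.count = 0

    def update(self, value: float, num_samples: int = 1):
        # `value` is the batch mean; weight it by the batch size.
        self.total += value * num_samples
        self.count += num_samples

    def compute(self) -> float:
        # Mean over all samples seen so far; NaN if nothing was recorded.
        return self.total / self.count if self.count else float("nan")


# One meter per metric name, as in a dict[str, meter] mapping.
meters = {"loss": AverageMeter(), "bpp_loss": AverageMeter()}
for batch_loss, batch_bpp, batch_size in [(2.0, 0.5, 8), (1.0, 0.3, 8), (4.0, 0.9, 4)]:
    meters["loss"].update(batch_loss, batch_size)
    meters["bpp_loss"].update(batch_bpp, batch_size)

print(meters["loss"].compute())  # (16 + 8 + 16) / 20 = 2.0
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one.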

batch_meters: dict[str, metrics.IMetric]
criterion: TorchCriterion
log_image(*args, **kwargs)

Logs an image to all available loggers.

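model: CompressionModel | DataParallel | DistributedDataParallel
property model_module: compressai.models.base.CompressionModel

Returns the underlying CompressionModel instance, i.e. model with any DataParallel or DistributedDataParallel wrapper removed.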
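Unwrapping a parallel wrapper can be sketched as follows. This is an illustrative helper, not the library's implementation: `unwrap_model`, `FakeParallel`, and `FakeCompressionModel` are hypothetical names, and the stand-in classes only exist so the sketch runs without PyTorch. It relies on the fact that torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel both keep the wrapped model in their `module` attribute.

```python
def unwrap_model(model, wrapper_types: tuple):
    """Strip parallel wrappers, which store the original model as `.module`."""
    while isinstance(model, wrapper_types):
        model = model.module
    return model


# Hypothetical stand-ins. With PyTorch, pass
# (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel) instead.
class FakeParallel:
    def __init__(self, module):
        self.module = module


class FakeCompressionModel:
    pass


net = FakeCompressionModel()
print(unwrap_model(FakeParallel(net), (FakeParallel,)) is net)  # True
print(unwrap_model(net, (FakeParallel,)) is net)  # True
```

Checking wrapper types explicitly, rather than testing for any `module` attribute, avoids mistaking a model that happens to own a submodule named `module` for a wrapper.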

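on_epoch_end(runner)

Event handler called at the end of each epoch.

on_epoch_start(runner)

Event handler called at the start of each epoch.

on_experiment_end(runner)

Event handler called once, at the end of the experiment.

on_experiment_start(runner)

Event handler called once, at the start of the experiment.

on_loader_end(runner)

Event handler called after each loader (train / valid / infer) finishes.

on_loader_start(runner)

Event handler called before each loader (train / valid / infer) starts.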

optimizer: dict[str, TorchOptimizer]
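BaseRunner's "calls model.update() before inference" behavior can be pictured as a loader-start hook. This is a sketch under two assumptions not stated above: that the call happens when the `infer` loader starts, and that it uses `force=True`. `StubCompressionModel` and this free-function `on_loader_start` are hypothetical; in CompressAI, `CompressionModel.update()` rebuilds the entropy coder's CDF tables so that compression matches the trained weights.

```python
class StubCompressionModel:
    """Hypothetical stand-in for a CompressAI CompressionModel."""

    def __init__(self):
        self.update_calls = 0

    def update(self, scale_table=None, force=False) -> bool:
        # The real method recomputes entropy-coder CDF tables; here we only count calls.
        self.update_calls += 1
        return True


def on_loader_start(loader_key: str, model) -> None:
    """Refresh entropy-coder tables before running the inference loader."""
    if loader_key == "infer":
        model.update(force=True)


model = StubCompressionModel()
for key in ["train", "valid", "infer"]:
    on_loader_start(key, model)
print(model.update_calls)  # 1: only the "infer" loader triggers an update
```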
class compressai_trainer.runners.GVAEImageCompressionRunner(inference: dict[str, Any], meters: dict[str, list[str]], *args, **kwargs)

Bases: compressai_trainer.runners.base.BaseRunner

Runner for image compression experiments.

Reimplementation of CompressAI’s examples/train.py, with additional functionality such as:

  • Plots RD curves, learned entropy bottleneck distributions, and histograms for latent channel-wise rate distributions.

  • Saves inference outputs, including images and feature maps.

Set the input arguments by overriding the defaults in conf/runner/ImageCompressionRunner.yaml.
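For reference, the rate-distortion objective in CompressAI's examples/train.py combines an estimated bitrate with a scaled MSE: loss = lmbda * 255**2 * mse + bpp. The pure-Python sketch below reproduces that arithmetic on flat lists; the function name is hypothetical, and it is not necessarily the exact criterion these runners are configured with.

```python
import math


def rate_distortion_loss(likelihoods, x, x_hat, num_pixels, lmbda):
    """Sketch of the examples/train.py loss on flat lists of floats."""
    # Rate: -log2(p) summed over all latent symbols, divided by the pixel count.
    bpp = sum(-math.log2(p) for p in likelihoods) / num_pixels
    # Distortion: mean squared error over all elements, inputs scaled to [0, 1].
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    # The 255**2 factor expresses the MSE on the 8-bit pixel scale.
    loss = lmbda * 255**2 * mse + bpp
    return {"loss": loss, "bpp_loss": bpp, "mse_loss": mse}


# A single-channel, 4-pixel toy example.
out = rate_distortion_loss(
    likelihoods=[0.5, 0.25, 0.5, 0.125],  # 1 + 2 + 1 + 3 = 7 bits
    x=[0.0, 0.5, 1.0, 0.5],
    x_hat=[0.0, 0.5, 1.0, 0.6],
    num_pixels=4,
    lmbda=0.01,
)
print(out["bpp_loss"])  # 1.75
```

A larger `lmbda` weights distortion more heavily, trading a higher bitrate for better reconstruction quality.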

handle_batch(batch)

Handles a single data batch, performing one train, valid, or infer step during the experiment run.

Parameters

batch (Mapping[str, Any]) – dictionary with data batches from DataLoader.

on_loader_end(runner)

Event handler called after each loader (train / valid / infer) finishes.

on_loader_start(runner)

Event handler called before each loader (train / valid / infer) starts.

predict_batch(batch, lmbda_idx=None, lmbda=None, **kwargs)

Run model inference on specified data batch.

Parameters
  • batch – dictionary with data batches from DataLoader.

  • **kwargs – additional kwargs to pass to the model

Returns

Mapping: model output dictionary

Raises

NotImplementedError – if not implemented yet

class compressai_trainer.runners.ImageCompressionRunner(inference: dict[str, Any], meters: dict[str, list[str]], *args, **kwargs)

Bases: compressai_trainer.runners.base.BaseRunner

Runner for image compression experiments.

Reimplementation of CompressAI’s examples/train.py, with additional functionality such as:

  • Plots RD curves, learned entropy bottleneck distributions, and histograms for latent channel-wise rate distributions.

  • Saves inference outputs, including images and feature maps.

Set the input arguments by overriding the defaults in conf/runner/ImageCompressionRunner.yaml.
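Each point on an RD curve pairs an actual bitrate with a distortion measure. The sketch below is a hypothetical helper (`rd_point` is not part of the library) that computes bits per pixel from the byte strings an entropy coder produces, and PSNR from an MSE measured on inputs in [0, 1]; the plotting itself is left out.

```python
import math


def rd_point(strings: list[bytes], height: int, width: int, mse: float, max_val: float = 1.0):
    """Return (bpp, psnr_db) for one compressed image."""
    num_pixels = height * width
    # Actual rate: total bitstream length in bits divided by the pixel count.
    bpp = sum(len(s) for s in strings) * 8 / num_pixels
    # PSNR in dB relative to the peak signal value.
    psnr_db = 10 * math.log10(max_val**2 / mse)
    return bpp, psnr_db


# Two bitstreams totalling 512 bytes for a 64x64 image.
bpp, psnr_db = rd_point([b"\x00" * 400, b"\x00" * 112], height=64, width=64, mse=1e-3)
print(bpp)  # 1.0
print(round(psnr_db, 2))  # about 30 dB
```

Collecting one such point per trained model (for example, one per lambda value) and sorting by bpp gives the curve.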

handle_batch(batch)

Handles a single data batch, performing one train, valid, or infer step during the experiment run.

Parameters

batch (Mapping[str, Any]) – dictionary with data batches from DataLoader.

on_loader_end(runner)

Event handler called after each loader (train / valid / infer) finishes.

on_loader_start(runner)

Event handler called before each loader (train / valid / infer) starts.

predict_batch(batch, **kwargs)

Run model inference on specified data batch.

Parameters
  • batch – dictionary with data batches from DataLoader.

  • **kwargs – additional kwargs to pass to the model

Returns

Mapping: model output dictionary

Raises

NotImplementedError – if not implemented yet