compressai_trainer.runners#
Catalyst Runners describe the process for training a model.
The following functions are called during the training loop:
on_experiment_start  # Once, at the beginning.
on_epoch_start       # At the beginning of each epoch.
on_loader_start      # For each loader (train / valid / infer).
on_batch_start       # Before each batch.
handle_batch         # For each image batch.
on_batch_end         # After each batch.
on_loader_end        # After each loader.
on_epoch_end         # At the end of each epoch.
on_experiment_end    # Once, at the end.
The training loop is effectively equivalent to:
on_experiment_start()
for epoch in range(1, num_epochs):
    on_epoch_start()
    for loader in ["train", "valid", "infer"]:
        on_loader_start()
        for batch in loader:
            on_batch_start()
            handle_batch(batch)
            on_batch_end()
        on_loader_end()
    on_epoch_end()
on_experiment_end()
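The dispatch order above can be sketched as a minimal stand-alone Python class. The `MiniRunner` class below is purely illustrative (it is not the Catalyst API); it only records which hooks fire, and in what order:

```python
# Minimal stand-alone sketch of the hook dispatch order described above.
# MiniRunner is illustrative only -- it is not the Catalyst Runner API.

class MiniRunner:
    def __init__(self):
        self.calls = []  # record of hook invocations, for inspection

    def _hook(self, name):
        self.calls.append(name)

    def handle_batch(self, batch):
        # A real runner would run the model on the batch here.
        self.calls.append("handle_batch")

    def run(self, loaders, num_epochs):
        self._hook("on_experiment_start")
        for epoch in range(1, num_epochs + 1):
            self._hook("on_epoch_start")
            for name, batches in loaders.items():
                self._hook("on_loader_start")
                for batch in batches:
                    self._hook("on_batch_start")
                    self.handle_batch(batch)
                    self._hook("on_batch_end")
                self._hook("on_loader_end")
            self._hook("on_epoch_end")
        self._hook("on_experiment_end")
        return self.calls

runner = MiniRunner()
calls = runner.run({"train": [0, 1], "valid": [0]}, num_epochs=1)
```

Running this with two train batches and one valid batch yields one `handle_batch` call per batch, bracketed by the corresponding start/end hooks.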
Please see the Catalyst documentation for more information.
We provide the following pre-made runners:
- BaseRunner (base compression runner class)
- ImageCompressionRunner
- GVAEImageCompressionRunner
- VideoCompressionRunner (future release)
For guidance on defining your own runner, see: Defining a custom Runner training loop.
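A custom runner typically overrides `handle_batch` (and, optionally, the loader-level hooks) to define what happens per batch. The skeleton below sketches that pattern with a stand-in class; the names and the stand-in loss computation are assumptions for illustration, and a real custom runner would subclass `compressai_trainer.runners.BaseRunner` instead:

```python
# Illustrative skeleton of the custom-runner pattern: override
# handle_batch for the per-batch step and the loader hooks for
# per-loader setup/teardown. Stand-in class, not the real API.

class CustomRunnerSketch:
    def on_loader_start(self):
        self.batch_losses = []  # reset per-loader state

    def handle_batch(self, batch):
        # In a real runner, this would be something like:
        #   out = self.model(batch["x"]); loss = self.criterion(out, ...)
        loss = float(batch["loss"])  # stand-in computation
        self.batch_losses.append(loss)

    def on_loader_end(self):
        # Reduce the per-batch values collected during the loader run.
        self.loader_loss = sum(self.batch_losses) / len(self.batch_losses)

runner = CustomRunnerSketch()
runner.on_loader_start()
for b in [{"loss": 0.2}, {"loss": 0.4}]:
    runner.handle_batch(b)
runner.on_loader_end()
```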
- class compressai_trainer.runners.BaseRunner(*args, **kwargs)[source]#
Bases: catalyst.runners.runner.Runner, compressai_trainer.utils.catalyst.loggers.logger.AllSuperlogger
Generic runner for all CompressAI Trainer experiments.
See the catalyst.dl.Runner documentation for info on runners.
BaseRunner provides functionality for common tasks, such as:
- Logging the environment: git hashes/diff, pip list, YAML config.
- Logging basic model info: number of parameters, weight shapes, etc.
- Batch meters that aggregate per-batch metrics (e.g. loss) into per-loader metrics (e.g. average loss).
- Calling model.update() before inference (i.e. test).
- criterion: TorchCriterion#
- model: CompressionModel | DataParallel | DistributedDataParallel#
- property model_module: compressai.models.base.CompressionModel#
Returns the underlying model instance (unwrapping DataParallel / DistributedDataParallel if necessary).
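The batch-meter behavior described above follows a standard pattern: accumulate a per-batch metric during a loader run, weighted by batch size, then reduce it to a per-loader average. The `AverageMeter` below is a minimal sketch of that pattern, not BaseRunner's actual implementation:

```python
# Minimal sketch of the batch-meter pattern: per-batch metric values
# are accumulated during a loader run and reduced to a per-loader
# average. Illustrative only -- not BaseRunner's actual code.

class AverageMeter:
    """Tracks a running average of a per-batch metric (e.g. loss)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # n is the batch size, so batches of different sizes are
        # weighted correctly in the loader-level average.
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0

meter = AverageMeter()
for batch_loss, batch_size in [(0.5, 8), (0.3, 8), (0.1, 4)]:
    meter.update(batch_loss, n=batch_size)
```

Here the loader-level average is (0.5·8 + 0.3·8 + 0.1·4) / 20 = 0.34, which weights the smaller final batch correctly.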
- class compressai_trainer.runners.GVAEImageCompressionRunner(inference: dict[str, Any], meters: dict[str, list[str]], *args, **kwargs)[source]#
Bases:
compressai_trainer.runners.base.BaseRunner
Runner for image compression experiments.
Reimplementation of CompressAI's examples/train.py, with additional functionality, such as:
- Plotting RD curves, learned entropy bottleneck distributions, and histograms of latent channel-wise rate distributions.
- Saving inference outputs, including images and featuremaps.
Set the input arguments by overriding the defaults in conf/runner/ImageCompressionRunner.yaml.
- handle_batch(batch)[source]#
Inner method that handles the specified data batch. Used to make a train/valid/infer step during an Experiment run.
- Parameters
batch (Mapping[str, Any]) – dictionary with data batches from DataLoader.
- predict_batch(batch, lmbda_idx=None, lmbda=None, **kwargs)[source]#
Run model inference on specified data batch.
- Parameters
batch – dictionary with data batches from DataLoader.
**kwargs – additional kwargs to pass to the model
- Returns
Mapping: model output dictionary
- Raises
NotImplementedError – if not implemented yet
- class compressai_trainer.runners.ImageCompressionRunner(inference: dict[str, Any], meters: dict[str, list[str]], *args, **kwargs)[source]#
Bases:
compressai_trainer.runners.base.BaseRunner
Runner for image compression experiments.
Reimplementation of CompressAI's examples/train.py, with additional functionality, such as:
- Plotting RD curves, learned entropy bottleneck distributions, and histograms of latent channel-wise rate distributions.
- Saving inference outputs, including images and featuremaps.
Set the input arguments by overriding the defaults in conf/runner/ImageCompressionRunner.yaml.
- handle_batch(batch)[source]#
Inner method that handles the specified data batch. Used to make a train/valid/infer step during an Experiment run.
- Parameters
batch (Mapping[str, Any]) – dictionary with data batches from DataLoader.
- predict_batch(batch, **kwargs)[source]#
Run model inference on specified data batch.
- Parameters
batch – dictionary with data batches from DataLoader.
**kwargs – additional kwargs to pass to the model
- Returns
Mapping: model output dictionary
- Raises
NotImplementedError – if not implemented yet
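During training, `handle_batch` optimizes a rate-distortion criterion. The sketch below shows the standard objective used in CompressAI-style image compression, loss = λ · 255² · MSE + bpp, where the rate term is the total negative log2-likelihood of the latents divided by the number of pixels. It is a hedged, plain-float stand-in for illustration; the actual criterion is configured separately and operates on tensors:

```python
# Hedged sketch of the standard rate-distortion objective typically
# optimized during handle_batch: loss = lmbda * 255^2 * MSE + bpp.
# Plain-float stand-in; the real criterion operates on torch tensors.

import math

def rd_loss(likelihoods, mse, num_pixels, lmbda=0.01):
    # Rate: total bits = sum of -log2(likelihood) over latent elements,
    # normalized by the number of image pixels to give bits per pixel.
    bits = sum(-math.log2(p) for p in likelihoods)
    bpp = bits / num_pixels
    # Distortion: MSE scaled to the 0-255 pixel value range.
    return lmbda * (255.0 ** 2) * mse + bpp

loss = rd_loss(likelihoods=[0.5, 0.25], mse=0.001, num_pixels=4, lmbda=0.01)
```

With these toy values, the rate term is (1 + 2) / 4 = 0.75 bpp and the distortion term is 0.01 · 65025 · 0.001 = 0.65025, so the loss is 1.40025. Increasing λ trades more rate for lower distortion.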