compressai_trainer.config#

Configuration system with dynamic object creation.

config#

compressai_trainer.config.config.configure_conf(conf: omegaconf.dictconfig.DictConfig)[source]#
compressai_trainer.config.config.create_criterion(conf: omegaconf.dictconfig.DictConfig) torch.nn.modules.module.Module[source]#
compressai_trainer.config.config.create_dataloaders(conf: omegaconf.dictconfig.DictConfig) dict[str, torch.utils.data.dataloader.DataLoader][source]#
compressai_trainer.config.config.create_model(conf: omegaconf.dictconfig.DictConfig) torch.nn.modules.module.Module[source]#
compressai_trainer.config.config.create_module(conf: omegaconf.dictconfig.DictConfig) torch.nn.modules.module.Module[source]#
compressai_trainer.config.config.create_optimizer(conf: omegaconf.dictconfig.DictConfig, net: torch.nn.modules.module.Module) Union[torch.optim.optimizer.Optimizer, Dict[str, torch.optim.optimizer.Optimizer]][source]#
compressai_trainer.config.config.create_scheduler(conf: omegaconf.dictconfig.DictConfig, optimizer: Union[torch.optim.optimizer.Optimizer, Dict[str, torch.optim.optimizer.Optimizer]]) dict[str, Union[torch.optim.lr_scheduler.ReduceLROnPlateau, torch.optim.lr_scheduler._LRScheduler]][source]#
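The `create_*` functions above all follow the same pattern: look up a name in the conf and instantiate the corresponding object with its configured kwargs. A minimal pure-Python sketch of that pattern, using a plain dict and a hypothetical registry in place of the real OmegaConf `DictConfig` and torch modules:

```python
# Sketch of config-driven object creation. The registry contents, config keys,
# and tuple return values are stand-ins, not the actual implementation.

CRITERION_REGISTRY = {
    "mse": lambda **kw: ("MSELoss", kw),
    "rate_distortion": lambda **kw: ("RateDistortionLoss", kw),
}

def create_criterion(conf):
    """Look up the criterion named in the conf and build it with its kwargs."""
    name = conf["criterion"]["type"]
    kwargs = conf["criterion"].get("kwargs", {})
    return CRITERION_REGISTRY[name](**kwargs)

conf = {"criterion": {"type": "rate_distortion", "kwargs": {"lmbda": 0.01}}}
criterion = create_criterion(conf)
```

The same lookup-then-instantiate shape underlies `create_model`, `create_optimizer`, and `create_scheduler`; only the registry and the config subtree differ.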

dataset#

class compressai_trainer.config.dataset.DatasetTuple(transform: 'transforms.Compose', dataset: 'TDataset', loader: 'TDataLoader')[source]#
dataset: torch.utils.data.dataset.Dataset#
loader: torch.utils.data.dataloader.DataLoader#
transform: torchvision.transforms.transforms.Compose#
compressai_trainer.config.dataset.create_data_transform(transform_conf: omegaconf.dictconfig.DictConfig) Callable[source]#
compressai_trainer.config.dataset.create_data_transform_composition(conf: omegaconf.dictconfig.DictConfig) torchvision.transforms.transforms.Compose[source]#
compressai_trainer.config.dataset.create_dataloader(conf: omegaconf.dictconfig.DictConfig, dataset: torch.utils.data.dataset.Dataset, device: str) torch.utils.data.dataloader.DataLoader[source]#
compressai_trainer.config.dataset.create_dataset(conf: omegaconf.dictconfig.DictConfig, transform: Callable) torch.utils.data.dataset.Dataset[source]#
compressai_trainer.config.dataset.create_dataset_tuple(conf: omegaconf.dictconfig.DictConfig, device: str) compressai_trainer.config.dataset.DatasetTuple[source]#
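`create_data_transform_composition` builds each transform listed in the conf via `create_data_transform` and chains them in order. A pure-Python sketch of that composition (hypothetical transform names and a plain dict conf stand in for torchvision transforms and a `DictConfig`):

```python
# Sketch of composing data transforms from a config list.
# "scale" and "shift" are invented stand-ins for torchvision transforms.

TRANSFORMS = {
    "scale": lambda factor: (lambda x: x * factor),
    "shift": lambda offset: (lambda x: x + offset),
}

def create_data_transform(transform_conf):
    """Build a single transform from its name and kwargs."""
    name = transform_conf["name"]
    kwargs = transform_conf.get("kwargs", {})
    return TRANSFORMS[name](**kwargs)

def create_data_transform_composition(conf):
    """Compose the transforms listed under conf['transforms'], in order."""
    fns = [create_data_transform(t) for t in conf["transforms"]]

    def composed(x):
        for fn in fns:
            x = fn(x)
        return x

    return composed

conf = {
    "transforms": [
        {"name": "scale", "kwargs": {"factor": 2}},
        {"name": "shift", "kwargs": {"offset": 1}},
    ]
}
transform = create_data_transform_composition(conf)
```

In the real code the composition is a `torchvision.transforms.Compose`, which is then stored on the resulting `DatasetTuple` alongside the dataset and dataloader.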

engine#

compressai_trainer.config.engine.configure_engine(conf: omegaconf.dictconfig.DictConfig) dict[str, Any][source]#
compressai_trainer.config.engine.create_callback(conf: omegaconf.dictconfig.DictConfig) catalyst.core.callback.Callback[source]#
compressai_trainer.config.engine.create_logger(conf: omegaconf.dictconfig.DictConfig, logger_type: str) catalyst.core.logger.ILogger[source]#
compressai_trainer.config.engine.create_runner(conf: omegaconf.dictconfig.DictConfig) catalyst.runners.runner.Runner[source]#


env#

compressai_trainer.config.env.get_env(conf: omegaconf.dictconfig.DictConfig) dict[str, Any][source]#

load#

compressai_trainer.config.load.get_checkpoint_path(conf: Mapping[str, Any], epoch: int | str = 'best') str[source]#

Returns the checkpoint path for the given conf.

Parameters
  • conf – Configuration.

  • epoch – Exact epoch (int) or “best” or “last” (str).
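A sketch of how such a resolver can map `epoch` to a concrete file path. The directory layout and filename scheme below are assumptions for illustration, not CompressAI Trainer's actual scheme:

```python
import os

# Sketch of checkpoint path resolution. The "checkpoints" conf key and the
# best.pth / epoch_NNNN.pth naming are hypothetical.

def get_checkpoint_path(conf, epoch="best"):
    """Resolve 'best', 'last', or an exact epoch number to a file path."""
    ckpt_dir = conf["paths"]["checkpoints"]
    if epoch in ("best", "last"):
        filename = f"{epoch}.pth"
    else:
        filename = f"epoch_{epoch:04d}.pth"
    return os.path.join(ckpt_dir, filename)

conf = {"paths": {"checkpoints": "/path/to/runs/abc123/checkpoints"}}
```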

compressai_trainer.config.load.load_checkpoint(conf: DictConfig, *, epoch: int | str = 'best', warn_only: bool = True) nn.Module[source]#

Loads a particular checkpoint for the given conf.

A particular model is a function of:

  • Hyperparameters/configuration

  • Source code

  • Checkpoint file

This tries to reassemble/verify the same environment.
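The three ingredients above can be sketched as: rebuild the model from the conf's hyperparameters (assuming the current source code matches), then restore its weights from the checkpoint. Pure-Python stand-ins are used below; the real function works with torch modules and `.pth` checkpoint files, and the `"hp"` / `"model_state_dict"` keys are assumptions:

```python
# Sketch of the load-checkpoint flow: conf -> model, checkpoint -> weights.

class Model:
    """Stand-in for an nn.Module built from hyperparameters."""
    def __init__(self, hp):
        self.hp = hp
        self.weights = {}

    def load_state_dict(self, state_dict):
        self.weights = dict(state_dict)

def load_checkpoint(conf, checkpoint):
    model = Model(conf["hp"])  # same hyperparameters as at training time
    model.load_state_dict(checkpoint["model_state_dict"])
    return model

conf = {"hp": {"N": 128}}
checkpoint = {"model_state_dict": {"w": 1.0}}
model = load_checkpoint(conf, checkpoint)
```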

compressai_trainer.config.load.load_config(run_root: str) omegaconf.dictconfig.DictConfig[source]#

Returns the config for the given run root path.

Example of run root path: /path/to/runs/e4e6d4d5e5c59c69f3bd7be2.

compressai_trainer.config.load.load_model(conf: omegaconf.dictconfig.DictConfig) torch.nn.modules.module.Module[source]#

Loads a model from one of various sources.

The source is determined by the model.source config setting, which must be one of the following:

  • “config”:

    Uses the CompressAI Trainer standard config (e.g. hp, paths.model_checkpoint, etc.).

  • “from_state_dict”:

    Uses the model’s from_state_dict() factory method. Requires model.name and paths.model_checkpoint to be set. For example:

    model:
      name: "bmshj2018-factorized"
    paths:
      model_checkpoint: "/home/user/.cache/torch/hub/checkpoints/bmshj2018-factorized-prior-3-5c6f152b.pth.tar"
    
  • “zoo”:

    Uses CompressAI’s zoo of models. Requires model.name, model.metric, model.quality, and model.pretrained to be set. For example:

    model:
      name: "bmshj2018-factorized"
      metric: "mse"
      quality: 3
      pretrained: True
    
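Under the hood, `load_model` dispatches on `model.source`. A simplified pure-Python sketch of that dispatch (the tuple return values are stand-ins; the real loaders build torch modules from the standard config, a state dict, or the CompressAI zoo):

```python
# Sketch of the model.source dispatch in load_model, using plain dicts.

def load_model(conf):
    source = conf["model"]["source"]
    if source == "config":
        return ("from_config", conf["model"]["name"])
    if source == "from_state_dict":
        return ("from_state_dict", conf["paths"]["model_checkpoint"])
    if source == "zoo":
        m = conf["model"]
        return ("zoo", m["name"], m["metric"], m["quality"], m["pretrained"])
    raise ValueError(f"Unknown model.source: {source!r}")

conf = {
    "model": {
        "source": "zoo",
        "name": "bmshj2018-factorized",
        "metric": "mse",
        "quality": 3,
        "pretrained": True,
    }
}
```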
compressai_trainer.config.load.state_dict_from_checkpoint(ckpt) OrderedDict[str, torch.Tensor][source]#

Gets the model state dict, with a fallback for models not trained with Catalyst.
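The fallback can be sketched as follows: Catalyst-style checkpoints nest the weights under a key (assumed to be "model_state_dict" here), while other checkpoints may store the weights under "state_dict" or be a raw state dict already. Both key names are assumptions for illustration:

```python
# Sketch of extracting a state dict from checkpoints with varying layouts.

def state_dict_from_checkpoint(ckpt):
    if "model_state_dict" in ckpt:   # Catalyst-style checkpoint (assumed key)
        return ckpt["model_state_dict"]
    if "state_dict" in ckpt:         # common alternative layout (assumed key)
        return ckpt["state_dict"]
    return ckpt                      # assume ckpt is the state dict itself

catalyst_ckpt = {"model_state_dict": {"w": 0.5}, "epoch": 10}
raw_ckpt = {"w": 0.5}
```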

outputs#

compressai_trainer.config.outputs.write_config(conf: omegaconf.dictconfig.DictConfig)[source]#
compressai_trainer.config.outputs.write_git_diff(conf: Mapping[str, Any], package: module) str[source]#
compressai_trainer.config.outputs.write_outputs(conf: omegaconf.dictconfig.DictConfig)[source]#
compressai_trainer.config.outputs.write_pip_list(conf: Mapping[str, Any]) str[source]#
compressai_trainer.config.outputs.write_pip_requirements(conf: Mapping[str, Any]) str[source]#