compressai_trainer.config
Configuration system with dynamic object creation.
config
- compressai_trainer.config.config.create_criterion(conf: omegaconf.dictconfig.DictConfig) → torch.nn.modules.module.Module
- compressai_trainer.config.config.create_dataloaders(conf: omegaconf.dictconfig.DictConfig) → dict[str, torch.utils.data.dataloader.DataLoader]
- compressai_trainer.config.config.create_model(conf: omegaconf.dictconfig.DictConfig) → torch.nn.modules.module.Module
- compressai_trainer.config.config.create_module(conf: omegaconf.dictconfig.DictConfig) → torch.nn.modules.module.Module
- compressai_trainer.config.config.create_optimizer(conf: omegaconf.dictconfig.DictConfig, net: torch.nn.modules.module.Module) → Union[torch.optim.optimizer.Optimizer, Dict[str, torch.optim.optimizer.Optimizer]]
- compressai_trainer.config.config.create_scheduler(conf: omegaconf.dictconfig.DictConfig, optimizer: Union[torch.optim.optimizer.Optimizer, Dict[str, torch.optim.optimizer.Optimizer]]) → dict[str, Union[torch.optim.lr_scheduler.ReduceLROnPlateau, torch.optim.lr_scheduler._LRScheduler]]
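The create_* functions above build objects from a configuration. The general name-to-constructor pattern behind this kind of dynamic object creation can be sketched as follows. This is a minimal illustration, assuming a plain nested dict in place of DictConfig; MODELS, register_model, ToyModel and create_model_sketch are hypothetical names, not the library's API.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping names to constructors.
MODELS: Dict[str, Callable[..., Any]] = {}


def register_model(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that records a constructor under `name`."""
    def decorator(cls: Callable[..., Any]) -> Callable[..., Any]:
        MODELS[name] = cls
        return cls
    return decorator


@register_model("toy-model")
class ToyModel:
    def __init__(self, N: int = 128, M: int = 192) -> None:
        self.N = N
        self.M = M


def create_model_sketch(conf: Dict[str, Any]) -> Any:
    """Instantiate conf["model"]["name"] with keyword arguments taken from conf["hp"]."""
    name = conf["model"]["name"]
    try:
        constructor = MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model: {name!r}") from None
    return constructor(**conf.get("hp", {}))


conf = {"model": {"name": "toy-model"}, "hp": {"N": 64, "M": 96}}
net = create_model_sketch(conf)
print(type(net).__name__, net.N, net.M)  # ToyModel 64 96
```

Keeping constructors in a registry lets a config file select an object by name without the factory code importing and branching on every class explicitly.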
dataset
- class compressai_trainer.config.dataset.DatasetTuple(transform: 'transforms.Compose', dataset: 'TDataset', loader: 'TDataLoader')
  - dataset: torch.utils.data.dataset.Dataset
  - transform: torchvision.transforms.transforms.Compose
- compressai_trainer.config.dataset.create_data_transform(transform_conf: omegaconf.dictconfig.DictConfig) → Callable
- compressai_trainer.config.dataset.create_data_transform_composition(conf: omegaconf.dictconfig.DictConfig) → torchvision.transforms.transforms.Compose
- compressai_trainer.config.dataset.create_dataloader(conf: omegaconf.dictconfig.DictConfig, dataset: torch.utils.data.dataset.Dataset, device: str) → torch.utils.data.dataloader.DataLoader
- compressai_trainer.config.dataset.create_dataset(conf: omegaconf.dictconfig.DictConfig, transform: Callable) → torch.utils.data.dataset.Dataset
- compressai_trainer.config.dataset.create_dataset_tuple(conf: omegaconf.dictconfig.DictConfig, device: str) → compressai_trainer.config.dataset.DatasetTuple
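create_data_transform builds a single transform from its config, and create_data_transform_composition chains several into one Compose. The idea can be sketched without torchvision as below. The config shape (a list of single-key mappings from transform name to keyword arguments) is an assumption, and the toy numeric transforms Scale and Offset, along with the function names, are hypothetical stand-ins for real image transforms.

```python
from typing import Any, Callable, Dict, List

# Hypothetical transform constructors: each takes keyword arguments
# and returns a one-argument callable.
TRANSFORMS: Dict[str, Callable[..., Callable[[Any], Any]]] = {
    "Scale": lambda factor: (lambda x: x * factor),
    "Offset": lambda value: (lambda x: x + value),
}


def create_transform_sketch(name: str, kwargs: Dict[str, Any]) -> Callable[[Any], Any]:
    """Build a single transform from its name and keyword arguments."""
    return TRANSFORMS[name](**kwargs)


def create_composition_sketch(conf: List[Dict[str, Dict[str, Any]]]) -> Callable[[Any], Any]:
    """Chain transforms in list order, similar in spirit to transforms.Compose."""
    steps = [
        create_transform_sketch(name, kwargs)
        for entry in conf
        for name, kwargs in entry.items()
    ]

    def composed(x: Any) -> Any:
        for step in steps:
            x = step(x)
        return x

    return composed


pipeline = create_composition_sketch([{"Scale": {"factor": 2}}, {"Offset": {"value": 1}}])
print(pipeline(3))  # 7
```

Because the list is ordered, swapping two entries changes the result, just as the order of transforms in a Compose matters.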
engine
- compressai_trainer.config.engine.configure_engine(conf: omegaconf.dictconfig.DictConfig) → dict[str, Any]
- compressai_trainer.config.engine.create_callback(conf: omegaconf.dictconfig.DictConfig) → catalyst.core.callback.Callback
- compressai_trainer.config.engine.create_logger(conf: omegaconf.dictconfig.DictConfig, logger_type: str) → catalyst.core.logger.ILogger
- compressai_trainer.config.engine.create_runner(conf: omegaconf.dictconfig.DictConfig) → catalyst.runners.runner.Runner
env
load
- compressai_trainer.config.load.get_checkpoint_path(conf: Mapping[str, Any], epoch: int | str = 'best') → str
  Returns the checkpoint path for the given conf.
  Parameters:
  - conf – Configuration.
  - epoch – Exact epoch (int), or “best” or “last” (str).
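The epoch parameter accepts either an exact epoch number or one of two named checkpoints. A minimal sketch of that mapping follows; the directory key conf["paths"]["checkpoints"] and the file names runner.best.pth, runner.last.pth and runner.0042.pth are assumptions for illustration, not the library's confirmed layout, and get_checkpoint_path_sketch is a hypothetical name.

```python
import os
from typing import Any, Mapping, Union


def get_checkpoint_path_sketch(conf: Mapping[str, Any], epoch: Union[int, str] = "best") -> str:
    """Map an epoch selector to a checkpoint file path (hypothetical layout)."""
    if isinstance(epoch, int):
        tag = f"{epoch:04d}"  # exact epoch, zero-padded: 42 becomes "0042"
    elif epoch in ("best", "last"):
        tag = epoch
    else:
        raise ValueError(f"epoch must be an int, 'best' or 'last', got {epoch!r}")
    # Assumed directory key and file naming scheme.
    return os.path.join(conf["paths"]["checkpoints"], f"runner.{tag}.pth")


conf = {"paths": {"checkpoints": "/path/to/runs/abc/checkpoints"}}
print(get_checkpoint_path_sketch(conf))      # /path/to/runs/abc/checkpoints/runner.best.pth
print(get_checkpoint_path_sketch(conf, 42))  # /path/to/runs/abc/checkpoints/runner.0042.pth
```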
- compressai_trainer.config.load.load_checkpoint(conf: DictConfig, *, epoch: int | str = 'best', warn_only: bool = True) → nn.Module
  Loads a particular checkpoint for the given conf.
  A particular model is a function of:
  - Hyperparameters/configuration
  - Source code
  - Checkpoint file
  This function tries to reassemble or verify the same environment from these inputs.
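Because a trained model depends on its configuration, source code and checkpoint, loading it faithfully involves checking that the current environment matches the one recorded at training time. A minimal sketch of such a check is shown below. It assumes the recorded facts are simple string pairs (the keys source_commit and compressai are hypothetical) and assumes warn_only chooses between emitting a warning and raising an error; verify_environment_sketch is a hypothetical name, and the library's actual checks may differ.

```python
import warnings
from typing import Mapping


def verify_environment_sketch(
    saved: Mapping[str, str],
    current: Mapping[str, str],
    warn_only: bool = True,
) -> bool:
    """Compare recorded environment facts with the current ones.

    Returns True when everything matches. On a mismatch, warns and returns
    False if warn_only is True, otherwise raises RuntimeError.
    """
    mismatches = [
        f"{key}: saved={saved[key]!r}, current={current.get(key)!r}"
        for key in saved
        if current.get(key) != saved[key]
    ]
    if not mismatches:
        return True
    message = "Environment differs from checkpoint:\n" + "\n".join(mismatches)
    if warn_only:
        warnings.warn(message)
        return False
    raise RuntimeError(message)


saved = {"source_commit": "1a2b3c", "compressai": "1.2.4"}
print(verify_environment_sketch(saved, dict(saved)))  # True
```

Defaulting to a warning keeps old checkpoints loadable after harmless code changes, while warn_only=False gives a strict mode for reproducibility checks.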
- compressai_trainer.config.load.load_config(run_root: str) → omegaconf.dictconfig.DictConfig
  Returns the config for the given run root path.
  Example of a run root path:
  /path/to/runs/e4e6d4d5e5c59c69f3bd7be2
- compressai_trainer.config.load.load_model(conf: omegaconf.dictconfig.DictConfig) → torch.nn.modules.module.Module
  Loads a model from one of various sources.
  The source is determined by the config setting model.source, which can be one of the following:
  - “config”: Uses the standard CompressAI Trainer config (e.g. hp, paths.model_checkpoint, etc.).
  - “from_state_dict”: Uses the model’s from_state_dict() factory method. Requires model.name and paths.model_checkpoint to be set. For example:
      model:
        name: "bmshj2018-factorized"
      paths:
        model_checkpoint: "/home/user/.cache/torch/hub/checkpoints/bmshj2018-factorized-prior-3-5c6f152b.pth.tar"
  - “zoo”: Uses CompressAI’s zoo of models. Requires model.name, model.metric, model.quality, and model.pretrained to be set. For example:
      model:
        name: "bmshj2018-factorized"
        metric: "mse"
        quality: 3
        pretrained: True
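Each source needs a different set of settings. The sketch below checks that the settings listed above are present for the selected model.source before any loading happens. It assumes plain nested dicts in place of DictConfig; REQUIRED_KEYS, check_model_source and _lookup are hypothetical names, and treating a missing or unknown model.source as an error is an assumption of this sketch.

```python
from typing import Any, Dict, List

# Required settings per model.source, as listed in the load_model description.
# "config" lists no specific required keys there.
REQUIRED_KEYS: Dict[str, List[str]] = {
    "config": [],
    "from_state_dict": ["model.name", "paths.model_checkpoint"],
    "zoo": ["model.name", "model.metric", "model.quality", "model.pretrained"],
}


def _lookup(conf: Dict[str, Any], dotted: str) -> Any:
    """Resolve a dotted key such as "model.name" in nested dicts; None if absent."""
    node: Any = conf
    for part in dotted.split("."):
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node


def check_model_source(conf: Dict[str, Any]) -> str:
    """Return the selected source after verifying its required settings are present."""
    source = _lookup(conf, "model.source")
    if source not in REQUIRED_KEYS:
        raise ValueError(f"Unknown model.source: {source!r}")
    missing = [key for key in REQUIRED_KEYS[source] if _lookup(conf, key) is None]
    if missing:
        raise ValueError(f"model.source={source!r} requires: {', '.join(missing)}")
    return source


zoo_conf = {
    "model": {
        "source": "zoo",
        "name": "bmshj2018-factorized",
        "metric": "mse",
        "quality": 3,
        "pretrained": True,
    }
}
print(check_model_source(zoo_conf))  # zoo
```

After such a check, a “zoo” config like the one above could be passed to the matching CompressAI zoo constructor, for example compressai.zoo.bmshj2018_factorized(quality=3, metric="mse", pretrained=True).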
- compressai_trainer.config.load.state_dict_from_checkpoint(ckpt) → OrderedDict[str, torch.Tensor]
  Gets the model state dict from a checkpoint, with a fallback for models not trained with Catalyst.
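One plausible shape for this fallback is sketched below. It assumes Catalyst-trained checkpoints keep the weights under "model_state_dict" and that other checkpoints store them either under "state_dict" or at the top level; these key names are assumptions, not confirmed by this page. state_dict_from_checkpoint_sketch is a hypothetical name, and plain floats stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict
from typing import Any, Mapping


def state_dict_from_checkpoint_sketch(ckpt: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the model weights stored in a loaded checkpoint dict."""
    # Try the assumed wrapper keys in order of preference.
    for key in ("model_state_dict", "state_dict"):
        if key in ckpt:
            return ckpt[key]
    # No wrapper key: treat the checkpoint itself as a bare state dict.
    return ckpt


weights = OrderedDict([("g_a.0.weight", 0.5)])
catalyst_ckpt = {"model_state_dict": weights, "epoch": 10}
plain_ckpt = {"state_dict": weights}
print(state_dict_from_checkpoint_sketch(catalyst_ckpt) is weights)  # True
print(state_dict_from_checkpoint_sketch(plain_ckpt) is weights)     # True
print(state_dict_from_checkpoint_sketch(weights) is weights)        # True
```

In real use, ckpt would typically come from torch.load(path, map_location="cpu"), and the returned mapping would be passed to the model's load_state_dict().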