Training#
An example training script, train.py, is provided in the
examples/
folder of the CompressAI source tree.
Example:
python3 examples/train.py -m mbt2018-mean -d /path/to/image/dataset \
--batch-size 16 -lr 1e-4 --save --cuda
Run train.py --help to list the available options. See also the model zoo training section to reproduce the performance of the pretrained models.
Model update#
Once a model has been trained, you need to run the update_model script to update the internal parameters of the entropy bottlenecks:
python -m compressai.utils.update_model --architecture ARCH checkpoint_best_loss.pth.tar
This will modify the buffers related to the learned cumulative distribution functions (CDFs) required to perform the actual entropy coding.
You can run python -m compressai.utils.update_model --help to get the complete list of options.
Alternatively, you can call the update() method of a CompressionModel or EntropyBottleneck instance at the end of your training script, before saving the model checkpoint.
Model evaluation#
Once a model checkpoint has been updated, you can use eval_model to measure its performance on an image dataset:
python -m compressai.utils.eval_model checkpoint /path/to/image/dataset \
-a ARCH -p path/to/checkpoint-xxxxxxxx.pth.tar
You can run python -m compressai.utils.eval_model --help to get the complete list of options.
Entropy coding#
By default CompressAI uses a range Asymmetric Numeral Systems (ANS) entropy
coder. You can use compressai.available_entropy_coders() to get a list of the implemented entropy coders and change the default entropy coder via compressai.set_entropy_coder().
Compress an image tensor to a bit-stream:
x = torch.rand(1, 3, 64, 64)  # dummy image batch (N, C, H, W)
y = net.encode(x)  # transform the image to its latent representation
strings = net.entropy_bottleneck.compress(y)  # entropy-code the latents to byte strings
Decompress a bit-stream to an image tensor:
shape = y.size()[2:]  # spatial dimensions of the latents
y_hat = net.entropy_bottleneck.decompress(strings, shape)  # decode the byte strings
x_hat = net.decode(y_hat)  # reconstruct the image from the latents