Train your own model#
In this tutorial, we are going to implement a custom autoencoder architecture using some modules and layers predefined in CompressAI.
For a complete runnable example, check out the train.py script in the examples/ folder of the CompressAI source tree.
Defining a custom model#
Let’s build a simple autoencoder with an EntropyBottleneck module, 3 convolutions for the encoder, 3 transposed convolutions for the decoder, and GDN activation functions:
import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN


class Network(nn.Module):
    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)
        self.encode = nn.Sequential(
            nn.Conv2d(3, N, stride=2, kernel_size=5, padding=2),
            GDN(N),
            nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2),
            GDN(N),
            nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2),
        )

        self.decode = nn.Sequential(
            nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, 3, kernel_size=5, padding=2, output_padding=1, stride=2),
        )

    def forward(self, x):
        y = self.encode(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods
The convolutions are strided to reduce the spatial dimensions of the tensor while increasing the number of channels (which helps to learn a better latent representation). The bottleneck module is used to obtain a differentiable estimate of the entropy of the latent tensor during training.
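For instance, here is a quick shape check of the latent produced by this network (a minimal sketch; the 64x64 dummy input and the printed shapes simply follow from the Network class defined above):

import torch

net = Network()
x = torch.rand(1, 3, 64, 64)  # dummy RGB batch
y = net.encode(x)  # three stride-2 convolutions: 64 -> 32 -> 16 -> 8
print(y.shape)  # torch.Size([1, 128, 8, 8])
y_hat, y_likelihoods = net.entropy_bottleneck(y)
print(y_hat.shape, y_likelihoods.shape)  # both torch.Size([1, 128, 8, 8])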
Note
See the original paper: “Variational image compression with a scale hyperprior”, and the tensorflow/compression documentation for a detailed explanation of the EntropyBottleneck module.
Loss functions#
1. Rate distortion loss#
We are going to define a simple rate-distortion loss, which maximizes the reconstruction quality (PSNR, in the RGB domain) and minimizes the length (in bits) of the quantized latent tensor (y_hat).
A scalar weight is used to balance reconstruction quality against bit-rate (similar to the JPEG quality parameter or the QP of HEVC):
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)
net = Network()
x_hat, y_likelihoods = net(x)

# bit-rate of the quantized latent
N, _, H, W = x.size()
num_pixels = N * H * W
bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)

# mean square error (distortion)
mse_loss = F.mse_loss(x, x_hat)

# final loss term: distortion + lmbda * rate
lmbda = 1e-2  # example trade-off weight; tune it to target a given bit-rate
loss = mse_loss + lmbda * bpp_loss
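Since the pixel values here lie in [0, 1], the MSE can also be turned into a PSNR value, which is convenient for monitoring training (a small illustrative snippet, not part of the loss itself):

# PSNR (in dB) of the reconstruction, assuming pixel values in [0, 1]
psnr = -10 * math.log10(mse_loss.item())
print(f"PSNR: {psnr:.2f} dB | bpp: {bpp_loss.item():.4f}")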
Note
It’s possible to train architectures that can handle multiple rate-distortion points, but that’s outside the scope of this tutorial. See the paper “Variable Rate Deep Image Compression With a Conditional Autoencoder” for a good example.
2. Auxiliary loss#
The entropy bottleneck parameters need to be trained to minimize the density model evaluation of the latent elements. The auxiliary loss is accessible through the entropy_bottleneck layer:
aux_loss = net.entropy_bottleneck.loss()
The auxiliary loss must be minimized during or after the training of the network.
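For example, one way to minimize it after the main training has finished is to run a few extra optimization steps on the quantile parameters only (a sketch; the 100 iterations and the 1e-3 learning rate are arbitrary choices, not values prescribed by CompressAI):

import torch.optim as optim

# the auxiliary loss only depends on the ".quantiles" parameters of the entropy bottleneck
aux_parameters = [p for n, p in net.named_parameters() if n.endswith(".quantiles")]
aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)

for _ in range(100):
    aux_optimizer.zero_grad()
    aux_loss = net.entropy_bottleneck.loss()
    aux_loss.backward()
    aux_optimizer.step()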
Optimizers#
To train both the compression network and the entropy bottleneck density estimation, we thus need two optimizers. To simplify the implementation, CompressAI provides a CompressionModel base class, which includes an EntropyBottleneck module and some helper methods. Let’s rewrite our network:
from compressai.models import CompressionModel
from compressai.models.utils import conv, deconv


class Network(CompressionModel):
    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)
        self.encode = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
        )

        self.decode = nn.Sequential(
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

    def forward(self, x):
        y = self.encode(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods
Now, we can simply access the two sets of trainable parameters:
import torch.optim as optim

net = Network()  # instance of the CompressionModel-based network defined above

parameters = set(p for n, p in net.named_parameters() if not n.endswith(".quantiles"))
aux_parameters = set(p for n, p in net.named_parameters() if n.endswith(".quantiles"))
optimizer = optim.Adam(parameters, lr=1e-4)
aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)
Note
You can also use torch.optim.Optimizer parameter groups to define a single optimizer, as sketched below.
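Here is a sketch of that single-optimizer alternative (the per-group learning rates simply mirror the two-optimizer setup above; treating the two setups as interchangeable is an assumption, not a CompressAI recommendation):

single_optimizer = optim.Adam(
    [
        # main network parameters, trained with the rate-distortion loss
        {"params": [p for n, p in net.named_parameters() if not n.endswith(".quantiles")], "lr": 1e-4},
        # entropy bottleneck quantiles, driven by the auxiliary loss
        {"params": [p for n, p in net.named_parameters() if n.endswith(".quantiles")], "lr": 1e-3},
    ]
)

With a single optimizer, both the rate-distortion loss and the auxiliary loss can be backpropagated before the single step() call.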
Training loop#
We can then write a training loop:
x = torch.rand(1, 3, 64, 64)

for i in range(10):
    optimizer.zero_grad()
    aux_optimizer.zero_grad()

    x_hat, y_likelihoods = net(x)
    # ...
    # compute loss as before
    # ...
    loss.backward()
    optimizer.step()

    aux_loss = net.aux_loss()
    aux_loss.backward()
    aux_optimizer.step()
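Putting the pieces together, the elided loss computation can be filled in by reusing the rate-distortion snippet from above (a sketch; the lmbda value is an arbitrary example and the input is still a random tensor rather than a real dataset):

lmbda = 1e-2  # example rate-distortion trade-off weight

for i in range(10):
    optimizer.zero_grad()
    aux_optimizer.zero_grad()

    x_hat, y_likelihoods = net(x)

    # rate term: bits per pixel of the quantized latent
    N, _, H, W = x.size()
    num_pixels = N * H * W
    bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)

    # distortion term: mean square error
    mse_loss = F.mse_loss(x, x_hat)

    loss = mse_loss + lmbda * bpp_loss
    loss.backward()
    optimizer.step()

    # auxiliary loss for the entropy bottleneck quantiles
    aux_loss = net.aux_loss()
    aux_loss.backward()
    aux_optimizer.step()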