Source code for compressai.models.video.google

# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import math

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.cuda import amp

from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import QReLU
from compressai.ops import quantize_ste
from compressai.registry import register_model

from ..base import CompressionModel
from ..utils import conv, deconv, gaussian_blur, gaussian_kernel2d, meshgrid2d


@register_model("ssf2020")
class ScaleSpaceFlow(CompressionModel):
    r"""Google's first end-to-end optimized video compression model, from
    E. Agustsson, D. Minnen, N. Johnston, J. Balle, S. J. Hwang, G. Toderici:
    `"Scale-space flow for end-to-end optimized video compression"
    <https://openaccess.thecvf.com/content_CVPR_2020/html/Agustsson_Scale-Space_Flow_for_End-to-End_Optimized_Video_Compression_CVPR_2020_paper.html>`_,
    IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2020).

    Args:
        num_levels (int): number of scale-space levels.
        sigma0 (float): standard deviation of the Gaussian kernel for the
            first scale-space level.
        scale_field_shift (float):
    """

    def __init__(
        self,
        num_levels: int = 5,
        sigma0: float = 1.5,
        scale_field_shift: float = 1.0,
    ):
        super().__init__()

        class Encoder(nn.Sequential):
            def __init__(
                self, in_planes: int, mid_planes: int = 128, out_planes: int = 192
            ):
                super().__init__(
                    conv(in_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    conv(mid_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    conv(mid_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    conv(mid_planes, out_planes, kernel_size=5, stride=2),
                )

        class Decoder(nn.Sequential):
            def __init__(
                self, out_planes: int, in_planes: int = 192, mid_planes: int = 128
            ):
                super().__init__(
                    deconv(in_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    deconv(mid_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    deconv(mid_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    deconv(mid_planes, out_planes, kernel_size=5, stride=2),
                )

        class HyperEncoder(nn.Sequential):
            def __init__(
                self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192
            ):
                super().__init__(
                    conv(in_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    conv(mid_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    conv(mid_planes, mid_planes, kernel_size=5, stride=2),
                )

        class HyperDecoder(nn.Sequential):
            def __init__(
                self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192
            ):
                super().__init__(
                    deconv(in_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    deconv(mid_planes, mid_planes, kernel_size=5, stride=2),
                    nn.ReLU(inplace=True),
                    deconv(mid_planes, out_planes, kernel_size=5, stride=2),
                )

        class HyperDecoderWithQReLU(nn.Module):
            def __init__(
                self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192
            ):
                super().__init__()

                def qrelu(input, bit_depth=8, beta=100):
                    return QReLU.apply(input, bit_depth, beta)

                self.deconv1 = deconv(in_planes, mid_planes, kernel_size=5, stride=2)
                self.qrelu1 = qrelu
                self.deconv2 = deconv(mid_planes, mid_planes, kernel_size=5, stride=2)
                self.qrelu2 = qrelu
                self.deconv3 = deconv(mid_planes, out_planes, kernel_size=5, stride=2)
                self.qrelu3 = qrelu

            def forward(self, x):
                x = self.qrelu1(self.deconv1(x))
                x = self.qrelu2(self.deconv2(x))
                x = self.qrelu3(self.deconv3(x))

                return x

        class Hyperprior(CompressionModel):
            def __init__(self, planes: int = 192, mid_planes: int = 192):
                super().__init__()
                self.entropy_bottleneck = EntropyBottleneck(mid_planes)
                self.hyper_encoder = HyperEncoder(planes, mid_planes, planes)
                self.hyper_decoder_mean = HyperDecoder(planes, mid_planes, planes)
                self.hyper_decoder_scale = HyperDecoderWithQReLU(
                    planes, mid_planes, planes
                )
                self.gaussian_conditional = GaussianConditional(None)

            def forward(self, y):
                z = self.hyper_encoder(y)
                z_hat, z_likelihoods = self.entropy_bottleneck(z)

                scales = self.hyper_decoder_scale(z_hat)
                means = self.hyper_decoder_mean(z_hat)
                _, y_likelihoods = self.gaussian_conditional(y, scales, means)
                y_hat = quantize_ste(y - means) + means
                return y_hat, {"y": y_likelihoods, "z": z_likelihoods}

            def compress(self, y):
                z = self.hyper_encoder(y)

                z_string = self.entropy_bottleneck.compress(z)
                z_hat = self.entropy_bottleneck.decompress(z_string, z.size()[-2:])

                scales = self.hyper_decoder_scale(z_hat)
                means = self.hyper_decoder_mean(z_hat)

                indexes = self.gaussian_conditional.build_indexes(scales)
                y_string = self.gaussian_conditional.compress(y, indexes, means)
                y_hat = self.gaussian_conditional.quantize(y, "dequantize", means)

                return y_hat, {"strings": [y_string, z_string], "shape": z.size()[-2:]}

            def decompress(self, strings, shape):
                assert isinstance(strings, list) and len(strings) == 2
                z_hat = self.entropy_bottleneck.decompress(strings[1], shape)

                scales = self.hyper_decoder_scale(z_hat)
                means = self.hyper_decoder_mean(z_hat)
                indexes = self.gaussian_conditional.build_indexes(scales)
                y_hat = self.gaussian_conditional.decompress(
                    strings[0], indexes, z_hat.dtype, means
                )

                return y_hat

        self.img_encoder = Encoder(3)
        self.img_decoder = Decoder(3)
        self.img_hyperprior = Hyperprior()

        self.res_encoder = Encoder(3)
        self.res_decoder = Decoder(3, in_planes=384)
        self.res_hyperprior = Hyperprior()

        self.motion_encoder = Encoder(2 * 3)
        self.motion_decoder = Decoder(2 + 1)
        self.motion_hyperprior = Hyperprior()

        self.sigma0 = sigma0
        self.num_levels = num_levels
        self.scale_field_shift = scale_field_shift

    def forward(self, frames):
        if not isinstance(frames, List):
            raise RuntimeError(f"Invalid number of frames: {len(frames)}.")

        reconstructions = []
        frames_likelihoods = []

        x_hat, likelihoods = self.forward_keyframe(frames[0])
        reconstructions.append(x_hat)
        frames_likelihoods.append(likelihoods)
        x_ref = x_hat.detach()  # stop gradient flow (cf: google2020 paper)

        for i in range(1, len(frames)):
            x = frames[i]
            x_ref, likelihoods = self.forward_inter(x, x_ref)
            reconstructions.append(x_ref)
            frames_likelihoods.append(likelihoods)

        return {
            "x_hat": reconstructions,
            "likelihoods": frames_likelihoods,
        }

    def forward_keyframe(self, x):
        y = self.img_encoder(x)
        y_hat, likelihoods = self.img_hyperprior(y)
        x_hat = self.img_decoder(y_hat)
        return x_hat, {"keyframe": likelihoods}

    def encode_keyframe(self, x):
        y = self.img_encoder(x)
        y_hat, out_keyframe = self.img_hyperprior.compress(y)
        x_hat = self.img_decoder(y_hat)

        return x_hat, out_keyframe

    def decode_keyframe(self, strings, shape):
        y_hat = self.img_hyperprior.decompress(strings, shape)
        x_hat = self.img_decoder(y_hat)

        return x_hat

    def forward_inter(self, x_cur, x_ref):
        # encode the motion information
        x = torch.cat((x_cur, x_ref), dim=1)
        y_motion = self.motion_encoder(x)
        y_motion_hat, motion_likelihoods = self.motion_hyperprior(y_motion)

        # decode the space-scale flow information
        motion_info = self.motion_decoder(y_motion_hat)
        x_pred = self.forward_prediction(x_ref, motion_info)

        # residual
        x_res = x_cur - x_pred
        y_res = self.res_encoder(x_res)
        y_res_hat, res_likelihoods = self.res_hyperprior(y_res)

        # y_combine
        y_combine = torch.cat((y_res_hat, y_motion_hat), dim=1)
        x_res_hat = self.res_decoder(y_combine)

        # final reconstruction: prediction + residual
        x_rec = x_pred + x_res_hat

        return x_rec, {"motion": motion_likelihoods, "residual": res_likelihoods}

    def encode_inter(self, x_cur, x_ref):
        # encode the motion information
        x = torch.cat((x_cur, x_ref), dim=1)
        y_motion = self.motion_encoder(x)
        y_motion_hat, out_motion = self.motion_hyperprior.compress(y_motion)

        # decode the space-scale flow information
        motion_info = self.motion_decoder(y_motion_hat)
        x_pred = self.forward_prediction(x_ref, motion_info)

        # residual
        x_res = x_cur - x_pred
        y_res = self.res_encoder(x_res)
        y_res_hat, out_res = self.res_hyperprior.compress(y_res)

        # y_combine
        y_combine = torch.cat((y_res_hat, y_motion_hat), dim=1)
        x_res_hat = self.res_decoder(y_combine)

        # final reconstruction: prediction + residual
        x_rec = x_pred + x_res_hat

        return x_rec, {
            "strings": {
                "motion": out_motion["strings"],
                "residual": out_res["strings"],
            },
            "shape": {"motion": out_motion["shape"], "residual": out_res["shape"]},
        }

    def decode_inter(self, x_ref, strings, shapes):
        key = "motion"
        y_motion_hat = self.motion_hyperprior.decompress(strings[key], shapes[key])

        # decode the space-scale flow information
        motion_info = self.motion_decoder(y_motion_hat)
        x_pred = self.forward_prediction(x_ref, motion_info)

        # residual
        key = "residual"
        y_res_hat = self.res_hyperprior.decompress(strings[key], shapes[key])

        # y_combine
        y_combine = torch.cat((y_res_hat, y_motion_hat), dim=1)
        x_res_hat = self.res_decoder(y_combine)

        # final reconstruction: prediction + residual
        x_rec = x_pred + x_res_hat

        return x_rec

    @staticmethod
    def gaussian_volume(x, sigma: float, num_levels: int):
        """Efficient gaussian volume construction.

        From: "Generative Video Compression as Hierarchical Variational
        Inference", by Yang et al.
        """
        k = 2 * int(math.ceil(3 * sigma)) + 1
        device = x.device
        dtype = x.dtype if torch.is_floating_point(x) else torch.float32

        kernel = gaussian_kernel2d(k, sigma, device=device, dtype=dtype)
        volume = [x.unsqueeze(2)]
        x = gaussian_blur(x, kernel=kernel)
        volume += [x.unsqueeze(2)]
        for i in range(1, num_levels):
            x = F.avg_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
            x = gaussian_blur(x, kernel=kernel)
            interp = x
            for _ in range(0, i):
                interp = F.interpolate(
                    interp, scale_factor=2, mode="bilinear", align_corners=False
                )
            volume.append(interp.unsqueeze(2))
        return torch.cat(volume, dim=2)

    @amp.autocast(enabled=False)
    def warp_volume(self, volume, flow, scale_field, padding_mode: str = "border"):
        """3D volume warping."""
        if volume.ndimension() != 5:
            raise ValueError(
                f"Invalid number of dimensions for volume {volume.ndimension()}"
            )

        N, C, _, H, W = volume.size()

        grid = meshgrid2d(N, C, H, W, volume.device)
        update_grid = grid + flow.permute(0, 2, 3, 1).float()
        update_scale = scale_field.permute(0, 2, 3, 1).float()
        volume_grid = torch.cat((update_grid, update_scale), dim=-1).unsqueeze(1)

        out = F.grid_sample(
            volume.float(), volume_grid, padding_mode=padding_mode, align_corners=False
        )
        return out.squeeze(2)

    def forward_prediction(self, x_ref, motion_info):
        flow, scale_field = motion_info.chunk(2, dim=1)

        volume = self.gaussian_volume(x_ref, self.sigma0, self.num_levels)
        x_pred = self.warp_volume(volume, flow, scale_field)
        return x_pred

    def aux_loss(self):
        """Return a list of the auxiliary entropy bottleneck over module(s)."""

        aux_loss_list = []
        for m in self.modules():
            if isinstance(m, CompressionModel) and m is not self:
                aux_loss_list.append(m.aux_loss())
        return aux_loss_list

    def compress(self, frames):
        if not isinstance(frames, List):
            raise RuntimeError(f"Invalid number of frames: {len(frames)}.")

        frame_strings = []
        shape_infos = []

        x_ref, out_keyframe = self.encode_keyframe(frames[0])

        frame_strings.append(out_keyframe["strings"])
        shape_infos.append(out_keyframe["shape"])

        for i in range(1, len(frames)):
            x = frames[i]
            x_ref, out_interframe = self.encode_inter(x, x_ref)

            frame_strings.append(out_interframe["strings"])
            shape_infos.append(out_interframe["shape"])

        return frame_strings, shape_infos

    def decompress(self, strings, shapes):
        if not isinstance(strings, List) or not isinstance(shapes, List):
            raise RuntimeError(f"Invalid number of frames: {len(strings)}.")

        assert len(strings) == len(
            shapes
        ), f"Number of information should match {len(strings)} != {len(shapes)}."

        dec_frames = []

        x_ref = self.decode_keyframe(strings[0], shapes[0])
        dec_frames.append(x_ref)

        for i in range(1, len(strings)):
            string = strings[i]
            shape = shapes[i]
            x_ref = self.decode_inter(x_ref, string, shape)
            dec_frames.append(x_ref)

        return dec_frames

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        net = cls()
        net.load_state_dict(state_dict)
        return net
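

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above): run the model in
# inference mode on a short sequence of frames. The 256x256 frame size, the
# three-frame sequence, and the random input tensors are illustrative
# assumptions chosen for the example, not requirements of the model beyond
# the usual stride-related divisibility of the spatial dimensions.
#
#   net = ScaleSpaceFlow(num_levels=5, sigma0=1.5)
#   net.eval()
#   frames = [torch.rand(1, 3, 256, 256) for _ in range(3)]
#   with torch.no_grad():
#       out = net(frames)
#   # out["x_hat"]: list of reconstructed frames (keyframe first)
#   # out["likelihoods"]: per-frame dicts of latent likelihoods
#
# For actual entropy coding, the CDF tables of the entropy models are assumed
# to be built first (e.g. net.update(force=True)) before calling
# net.compress(frames) and net.decompress(strings, shapes).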