# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math
from typing import Any, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.autograd import Function
from .gdn import GDN
__all__ = [
"AttentionBlock",
"MaskedConv2d",
"CheckerboardMaskedConv2d",
"ResidualBlock",
"ResidualBlockUpsample",
"ResidualBlockWithStride",
"conv1x1",
"SpectralConv2d",
"SpectralConvTranspose2d",
"conv3x3",
"subpel_conv3x3",
"QReLU",
"sequential_channel_ramp",
]
class _SpectralConvNdMixin:
def __init__(self, dim: Tuple[int, ...]):
self.dim = dim
self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
del self._parameters["weight"] # Unregister weight, and fallback to property.
@property
def weight(self) -> Tensor:
return self._from_transform_domain(self.weight_transformed)
def _to_transform_domain(self, x: Tensor) -> Tensor:
return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")
def _from_transform_domain(self, x: Tensor) -> Tensor:
return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")
class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
r"""Spectral 2D convolution.
Introduced in [Balle2018efficient].
Reparameterizes the weights to be derived from weights stored in the
frequency domain.
In the original paper, this is referred to as "spectral Adam" or
"Sadam" due to its effect on the Adam optimizer update rule.
    The motivation behind representing the weights in the frequency
    domain is that optimizer updates/steps may now affect all
    frequencies equally.
This improves the gradient conditioning, thus leading to faster
convergence and increased stability at larger learning rates.
For comparison, see the TensorFlow Compression implementations of
`SignalConv2D
<https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61>`_
and
`RDFTParameter
<https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71>`_.
[Balle2018efficient]: `"Efficient Nonlinear Transforms for Lossy
Image Compression" <https://arxiv.org/abs/1802.00847>`_,
by Johannes Ballé, PCS 2018.
"""
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
_SpectralConvNdMixin.__init__(self, dim=(-2, -1))
class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin):
r"""Spectral 2D transposed convolution.
Transposed version of :class:`SpectralConv2d`.
"""
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
_SpectralConvNdMixin.__init__(self, dim=(-2, -1))
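# Illustrative usage sketch (shapes and layer sizes below are assumptions, not part of the API):
#   conv = SpectralConv2d(3, 64, kernel_size=5, padding=2)
#   y = conv(torch.rand(1, 3, 32, 32))  # -> (1, 64, 32, 32)
#   # The learnable parameter is `conv.weight_transformed` (frequency domain);
#   # `conv.weight` is reconstructed on the fly via an inverse real FFT.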
class MaskedConv2d(nn.Conv2d):
r"""Masked 2D convolution implementation, mask future "unseen" pixels.
Useful for building auto-regressive network components.
Introduced in `"Conditional Image Generation with PixelCNN Decoders"
<https://arxiv.org/abs/1606.05328>`_.
Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
first layer (which also masks the "current pixel"), `mask_type='B'` for the
following layers.
"""
def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
super().__init__(*args, **kwargs)
if mask_type not in ("A", "B"):
raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
self.register_buffer("mask", torch.ones_like(self.weight.data))
_, _, h, w = self.mask.size()
self.mask[:, :, h // 2, w // 2 + (mask_type == "B") :] = 0
self.mask[:, :, h // 2 + 1 :] = 0
def forward(self, x: Tensor) -> Tensor:
        # TODO(begaintj): weight assignment is not supported by torchscript
self.weight.data = self.weight.data * self.mask
return super().forward(x)
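# Illustrative usage sketch (layer sizes are assumptions): a tiny autoregressive stack where
# the first layer hides the "current pixel" (mask_type='A') and later layers may use it ('B').
#   context = nn.Sequential(
#       MaskedConv2d(192, 192, kernel_size=5, padding=2, mask_type="A"),
#       nn.ReLU(inplace=True),
#       MaskedConv2d(192, 192, kernel_size=5, padding=2, mask_type="B"),
#   )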
class CheckerboardMaskedConv2d(MaskedConv2d):
r"""Checkerboard masked 2D convolution; mask future "unseen" pixels.
Checkerboard mask variant used in
`"Checkerboard Context Model for Efficient Learned Image Compression"
<https://arxiv.org/abs/2103.15306>`_, by Dailan He, Yaoyan Zheng,
Baocheng Sun, Yan Wang, and Hongwei Qin, CVPR 2021.
Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
first layer (which also masks the "current pixel"), `mask_type='B'` for the
following layers.
"""
def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
super().__init__(*args, **kwargs)
if mask_type not in ("A", "B"):
raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
_, _, h, w = self.mask.size()
self.mask[:] = 1
self.mask[:, :, 0::2, 0::2] = 0
self.mask[:, :, 1::2, 1::2] = 0
self.mask[:, :, h // 2, w // 2] = mask_type == "B"
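# Illustrative sketch (kernel size is an assumption): the mask zeroes one colour of the
# checkerboard (and also the centre pixel for mask_type='A'), so each output position only
# sees the opposite checkerboard subset of its neighbourhood.
#   layer = CheckerboardMaskedConv2d(192, 384, kernel_size=5, padding=2, mask_type="A")
#   print(layer.mask[0, 0])  # inspect the 5x5 checkerboard pattern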
def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
"""3x3 convolution with padding."""
return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
def subpel_conv3x3(in_ch: int, out_ch: int, r: int = 1) -> nn.Sequential:
"""3x3 sub-pixel convolution for up-sampling."""
return nn.Sequential(
nn.Conv2d(in_ch, out_ch * r**2, kernel_size=3, padding=1), nn.PixelShuffle(r)
)
def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
"""1x1 convolution."""
return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
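# Illustrative usage sketch (channel counts are assumptions): subpel_conv3x3 upsamples by r
# by producing out_ch * r**2 channels and rearranging them with PixelShuffle.
#   up = subpel_conv3x3(192, 3, r=2)
#   y = up(torch.rand(1, 192, 16, 16))  # -> (1, 3, 32, 32)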
class ResidualBlockWithStride(nn.Module):
"""Residual block with a stride on the first convolution.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
stride (int): stride value (default: 2)
"""
def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
super().__init__()
self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv2 = conv3x3(out_ch, out_ch)
self.gdn = GDN(out_ch)
if stride != 1 or in_ch != out_ch:
self.skip = conv1x1(in_ch, out_ch, stride=stride)
else:
self.skip = None
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.leaky_relu(out)
out = self.conv2(out)
out = self.gdn(out)
if self.skip is not None:
identity = self.skip(x)
out += identity
return out
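# Illustrative usage sketch (sizes are assumptions): with stride=2 the block halves the
# spatial resolution, and a strided 1x1 skip projection matches the residual branch.
#   block = ResidualBlockWithStride(3, 128, stride=2)
#   y = block(torch.rand(1, 3, 64, 64))  # -> (1, 128, 32, 32)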
class ResidualBlockUpsample(nn.Module):
"""Residual block with sub-pixel upsampling on the last convolution.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
upsample (int): upsampling factor (default: 2)
"""
def __init__(self, in_ch: int, out_ch: int, upsample: int = 2):
super().__init__()
self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv = conv3x3(out_ch, out_ch)
self.igdn = GDN(out_ch, inverse=True)
self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.subpel_conv(x)
out = self.leaky_relu(out)
out = self.conv(out)
out = self.igdn(out)
identity = self.upsample(x)
out += identity
return out
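# Illustrative usage sketch (sizes are assumptions): upsample=2 doubles the spatial
# resolution; the skip path uses its own sub-pixel convolution to match shapes.
#   block = ResidualBlockUpsample(128, 64, upsample=2)
#   y = block(torch.rand(1, 128, 16, 16))  # -> (1, 64, 32, 32)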
class ResidualBlock(nn.Module):
"""Simple residual block with two 3x3 convolutions.
Args:
in_ch (int): number of input channels
out_ch (int): number of output channels
"""
def __init__(self, in_ch: int, out_ch: int):
super().__init__()
self.conv1 = conv3x3(in_ch, out_ch)
self.leaky_relu = nn.LeakyReLU(inplace=True)
self.conv2 = conv3x3(out_ch, out_ch)
if in_ch != out_ch:
self.skip = conv1x1(in_ch, out_ch)
else:
self.skip = None
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.leaky_relu(out)
out = self.conv2(out)
out = self.leaky_relu(out)
if self.skip is not None:
identity = self.skip(x)
out = out + identity
return out
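# Illustrative usage sketch (sizes are assumptions): spatial resolution is preserved; a 1x1
# skip projection is only added when in_ch != out_ch.
#   block = ResidualBlock(64, 128)
#   y = block(torch.rand(1, 64, 32, 32))  # -> (1, 128, 32, 32)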
class AttentionBlock(nn.Module):
"""Self attention block.
Simplified variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Args:
        N (int): Number of channels
"""
def __init__(self, N: int):
super().__init__()
class ResidualUnit(nn.Module):
"""Simple residual unit."""
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
conv1x1(N, N // 2),
nn.ReLU(inplace=True),
conv3x3(N // 2, N // 2),
nn.ReLU(inplace=True),
conv1x1(N // 2, N),
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv(x)
out += identity
out = self.relu(out)
return out
self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
self.conv_b = nn.Sequential(
ResidualUnit(),
ResidualUnit(),
ResidualUnit(),
conv1x1(N, N),
)
def forward(self, x: Tensor) -> Tensor:
identity = x
a = self.conv_a(x)
b = self.conv_b(x)
out = a * torch.sigmoid(b)
out += identity
return out
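# Illustrative usage sketch (channel count is an assumption): output shape equals input
# shape; conv_b acts as a sigmoid gate over the conv_a trunk before the residual addition.
#   attn = AttentionBlock(192)
#   y = attn(torch.rand(1, 192, 16, 16))  # -> (1, 192, 16, 16)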
class QReLU(Function):
"""QReLU
Clamping input with given bit-depth range.
Suppose that input data presents integer through an integer network
otherwise any precision of input will simply clamp without rounding
operation.
Pre-computed scale with gamma function is used for backward computation.
More details can be found in
`"Integer networks for data compression with latent-variable models"
<https://openreview.net/pdf?id=S1zz2i0cY7>`_,
by Johannes Ballé, Nick Johnston and David Minnen, ICLR in 2019
Args:
input: a tensor data
bit_depth: source bit-depth (used for clamping)
beta: a parameter for modeling the gradient during backward computation
"""
@staticmethod
def forward(ctx, input, bit_depth, beta):
# TODO(choih): allow to use adaptive scale instead of
# pre-computed scale with gamma function
ctx.alpha = 0.9943258522851727
ctx.beta = beta
ctx.max_value = 2**bit_depth - 1
ctx.save_for_backward(input)
return input.clamp(min=0, max=ctx.max_value)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
(input,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad_sub = (
torch.exp(
(-ctx.alpha**ctx.beta)
* torch.abs(2.0 * input / ctx.max_value - 1) ** ctx.beta
)
* grad_output.clone()
)
grad_input[input < 0] = grad_sub[input < 0]
grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value]
return grad_input, None, None
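# Illustrative usage sketch (bit_depth/beta values are assumptions): apply as an autograd
# Function; out-of-range entries receive the smooth surrogate gradient in the backward pass.
#   x = 300 * torch.randn(1, 3, 8, 8)
#   x.requires_grad_()
#   y = QReLU.apply(x, 8, 100)   # clamp to [0, 2**8 - 1]
#   y.sum().backward()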
def sequential_channel_ramp(
in_ch: int,
out_ch: int,
*,
min_ch: int = 0,
num_layers: int = 3,
interp: str = "linear",
make_layer=None,
make_act=None,
skip_last_act: bool = True,
**layer_kwargs,
) -> nn.Module:
"""Interleave layers of gradually ramping channels with nonlinearities."""
channels = ramp(in_ch, out_ch, num_layers + 1, method=interp).floor().int()
channels[1:-1] = channels[1:-1].clip(min=min_ch)
channels = channels.tolist()
layers = [
module
for ch_in, ch_out in zip(channels[:-1], channels[1:])
for module in [
make_layer(ch_in, ch_out, **layer_kwargs),
make_act(),
]
]
if skip_last_act:
layers = layers[:-1]
return nn.Sequential(*layers)
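# Illustrative usage sketch (the make_layer/make_act choices are assumptions): ramp from
# 320 to 192 channels over 3 conv layers with ReLUs in between (the last activation is skipped).
#   net = sequential_channel_ramp(
#       320, 192, num_layers=3,
#       make_layer=nn.Conv2d, make_act=lambda: nn.ReLU(inplace=True),
#       kernel_size=3, padding=1,
#   )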
def ramp(a, b, steps=None, method="linear", **kwargs):
if method == "linear":
return torch.linspace(a, b, steps, **kwargs)
if method == "log":
return torch.logspace(math.log10(a), math.log10(b), steps, **kwargs)
raise ValueError(f"Unknown ramp method: {method}")