from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
# Luma weighting coefficients, keyed by the name of the standard.
# Each value is the triple (K_r, K_g, K_b), where K_g = 1 - K_r - K_b.
YCBCR_WEIGHTS = {
    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722),
}
def _check_input_tensor(tensor: Tensor) -> None:
    """Validate that ``tensor`` is a floating point 3-channel image tensor.

    The input is accepted when it is a ``torch.Tensor`` of floating point
    dtype with 3 or 4 dimensions whose third-from-last dimension has size 3.

    Args:
        tensor (torch.Tensor): candidate tensor to validate

    Raises:
        ValueError: if any of the conditions above does not hold.
    """
    # The checks short-circuit left to right, so ``size(-3)`` is only
    # queried once the tensor is known to have at least 3 dimensions.
    is_valid = (
        isinstance(tensor, Tensor)
        and tensor.is_floating_point()
        and tensor.dim() in (3, 4)
        and tensor.size(-3) == 3
    )
    if not is_valid:
        raise ValueError(
            "Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
        )
# [docs]
def rgb2ycbcr(rgb: Tensor) -> Tensor:
    """Convert an RGB tensor to YCbCr.

    The ITU-R BT.709 coefficients are used. The chroma planes are offset by
    0.5 after scaling.

    Args:
        rgb (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        ycbcr (torch.Tensor): converted tensor, with the Y, Cb and Cr planes
        concatenated along the channel dimension (dim -3)
    """
    _check_input_tensor(rgb)
    k_r, k_g, k_b = YCBCR_WEIGHTS["ITU-R_BT.709"]
    # Split along the channel axis, keeping it as a size-1 dimension.
    red, green, blue = rgb.chunk(3, -3)
    luma = k_r * red + k_g * green + k_b * blue
    chroma_blue = 0.5 * (blue - luma) / (1 - k_b) + 0.5
    chroma_red = 0.5 * (red - luma) / (1 - k_r) + 0.5
    return torch.cat((luma, chroma_blue, chroma_red), dim=-3)
# [docs]
def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
    """Convert a YCbCr tensor to RGB.

    The ITU-R BT.709 coefficients are used. The chroma planes are read with
    a 0.5 offset removed before scaling.

    Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

    Returns:
        rgb (torch.Tensor): converted tensor, with the R, G and B planes
        concatenated along the channel dimension (dim -3)
    """
    _check_input_tensor(ycbcr)
    k_r, k_g, k_b = YCBCR_WEIGHTS["ITU-R_BT.709"]
    # Split along the channel axis, keeping it as a size-1 dimension.
    luma, chroma_blue, chroma_red = ycbcr.chunk(3, -3)
    red = luma + (2 - 2 * k_r) * (chroma_red - 0.5)
    blue = luma + (2 - 2 * k_b) * (chroma_blue - 0.5)
    # Green is recovered from the luma equation once red and blue are known.
    green = (luma - k_r * red - k_b * blue) / k_g
    return torch.cat((red, green, blue), dim=-3)
# [docs]
def yuv_444_to_420(
    yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
    mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
    """Downsample the chroma planes of a 444 input to obtain a 420 layout.

    The first plane is returned untouched; the second and third planes are
    each reduced with a 2x2 average pooling of stride 2.

    Args:
        yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
            input to be downsampled. Takes either a (Nx3xHxW) tensor or a
            tuple of 3 (Nx1xHxW) tensors.
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420

    Raises:
        ValueError: if ``mode`` is not a supported downsampling mode.
    """
    # The mode is validated before the input is touched.
    if mode != "avg_pool":
        raise ValueError(f'Invalid downsampling mode "{mode}".')

    # A single tensor is split along dim 1; anything else is unpacked as-is.
    planes = yuv.chunk(3, 1) if isinstance(yuv, torch.Tensor) else yuv
    luma, chroma_u, chroma_v = planes

    pooled_u = F.avg_pool2d(chroma_u, kernel_size=2, stride=2)
    pooled_v = F.avg_pool2d(chroma_v, kernel_size=2, stride=2)
    return (luma, pooled_u, pooled_v)
# [docs]
def yuv_420_to_444(
    yuv: Tuple[Tensor, Tensor, Tensor],
    mode: str = "bilinear",
    return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Upsample the chroma planes of a 420 input to obtain a 444 layout.

    The first plane is passed through unchanged; the second and third planes
    are each interpolated with a scale factor of 2.

    Args:
        yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
            (Nx1xHxW) format
        mode (str): algorithm used for upsampling: ``'bilinear'`` |
            ``'bicubic'`` | ``'nearest'``. Default ``'bilinear'``
        return_tuple (bool): return the output as a tuple of tensors instead
            of a concatenated tensor, 3 (Nx1xHxW) tensors instead of one
            (Nx3xHxW) tensor (default: False)

    Returns:
        (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
        444

    Raises:
        ValueError: if ``yuv`` is not made of exactly 3 tensors, or if
            ``mode`` is not a supported upsampling mode.
    """
    # The length is checked first; elements are only inspected when it is 3.
    if not (len(yuv) == 3 and all(isinstance(p, torch.Tensor) for p in yuv)):
        raise ValueError("Expected a tuple of 3 torch tensors")

    supported_modes = ("bilinear", "bicubic", "nearest")
    if mode not in supported_modes:
        raise ValueError(f'Invalid upsampling mode "{mode}".')

    # ``align_corners`` is only passed along for the non-nearest modes.
    interp_kwargs = {} if mode == "nearest" else {"align_corners": False}

    luma, chroma_u, chroma_v = yuv
    up_u = F.interpolate(chroma_u, scale_factor=2, mode=mode, **interp_kwargs)
    up_v = F.interpolate(chroma_v, scale_factor=2, mode=mode, **interp_kwargs)

    if return_tuple:
        return luma, up_u, up_v
    return torch.cat((luma, up_u, up_v), dim=1)