compressai.transforms
Transforms on Tensors
- class compressai.transforms.RGB2YCbCr
Convert an RGB tensor to YCbCr. The tensor is expected to be in the [0, 1] floating point range, with a shape of (3xHxW) or (Nx3xHxW).
- class compressai.transforms.YCbCr2RGB
Convert a YCbCr tensor to RGB. The tensor is expected to be in the [0, 1] floating point range, with a shape of (3xHxW) or (Nx3xHxW).
- class compressai.transforms.YUV420To444(mode: str = 'bilinear', return_tuple: bool = False)
Convert a YUV 420 input to a 444 representation.
- Parameters:
mode (str) – algorithm used for upsampling: 'bilinear' | 'nearest'. Default: 'bilinear'
return_tuple (bool) – return the output as a tuple of 3 (Nx1xHxW) tensors instead of one concatenated (Nx3xHxW) tensor (default: False)
Example
>>> import torch
>>> from compressai.transforms import YUV420To444
>>> y = torch.rand(1, 1, 32, 32)
>>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
>>> x = YUV420To444()((y, u, v))
>>> x.size()  # 1, 3, 32, 32
- class compressai.transforms.YUV444To420(mode: str = 'avg_pool')
Convert a YUV 444 tensor to a 420 representation.
- Parameters:
mode (str) – algorithm used for downsampling: 'avg_pool'. Default: 'avg_pool'
Example
>>> import torch
>>> from compressai.transforms import YUV444To420
>>> x = torch.rand(1, 3, 32, 32)
>>> y, u, v = YUV444To420()(x)
>>> y.size()  # 1, 1, 32, 32
>>> u.size()  # 1, 1, 16, 16
Functional Transforms
Functional transforms can be used to define custom transform classes.
- compressai.transforms.functional.rgb2ycbcr(rgb: Tensor) -> Tensor
RGB to YCbCr conversion for torch Tensor, using ITU-R BT.709 coefficients.
- Parameters:
rgb (torch.Tensor) – 3D or 4D floating point RGB tensor
- Returns:
ycbcr – converted tensor
- Return type:
torch.Tensor
- compressai.transforms.functional.ycbcr2rgb(ycbcr: Tensor) -> Tensor
YCbCr to RGB conversion for torch Tensor, using ITU-R BT.709 coefficients.
- Parameters:
ycbcr (torch.Tensor) – 3D or 4D floating point YCbCr tensor
- Returns:
rgb – converted tensor
- Return type:
torch.Tensor
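To make the role of the BT.709 coefficients concrete, here is a plain-Python sketch of the per-pixel arithmetic for a full-range signal in [0, 1]. The constants Kr = 0.2126, Kg = 0.7152 and Kb = 0.0722 are the BT.709 luma weights. The 0.5 chroma offset and the helper names `rgb_to_ycbcr_pixel` and `ycbcr_to_rgb_pixel` are assumptions made for illustration; the library's tensor implementation may differ in detail.

```python
# BT.709 luma weights (Kr + Kg + Kb = 1).
KR, KG, KB = 0.2126, 0.7152, 0.0722


def rgb_to_ycbcr_pixel(r, g, b):
    """Full-range RGB in [0, 1] -> (Y, Cb, Cr) in [0, 1]."""
    y = KR * r + KG * g + KB * b
    cb = 0.5 * (b - y) / (1 - KB) + 0.5  # blue-difference chroma, centred at 0.5
    cr = 0.5 * (r - y) / (1 - KR) + 0.5  # red-difference chroma, centred at 0.5
    return y, cb, cr


def ycbcr_to_rgb_pixel(y, cb, cr):
    """Inverse of rgb_to_ycbcr_pixel."""
    r = y + 2 * (1 - KR) * (cr - 0.5)
    b = y + 2 * (1 - KB) * (cb - 0.5)
    g = (y - KR * r - KB * b) / KG  # solve the luma equation for green
    return r, g, b


white = rgb_to_ycbcr_pixel(1.0, 1.0, 1.0)  # luma 1, neutral chroma
red = rgb_to_ycbcr_pixel(1.0, 0.0, 0.0)    # luma KR, Cr at its maximum
round_trip = ycbcr_to_rgb_pixel(*rgb_to_ycbcr_pixel(0.2, 0.5, 0.8))
```

Grey pixels (r = g = b) map to Cb = Cr = 0.5, which is why neutral chroma sits at the middle of the [0, 1] range.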
- compressai.transforms.functional.yuv_420_to_444(yuv: Tuple[Tensor, Tensor, Tensor], mode: str = 'bilinear', return_tuple: bool = False) -> Tensor | Tuple[Tensor, Tensor, Tensor]
Convert a 420 input to a 444 representation.
- Parameters:
yuv (torch.Tensor, torch.Tensor, torch.Tensor) – 420 input frames in (Nx1xHxW) format
mode (str) – algorithm used for upsampling: 'bilinear' | 'nearest'. Default: 'bilinear'
return_tuple (bool) – return the output as a tuple of 3 (Nx1xHxW) tensors instead of one concatenated (Nx3xHxW) tensor (default: False)
- Returns:
Converted 444
- Return type:
(torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor))
- compressai.transforms.functional.yuv_444_to_420(yuv: Tensor | Tuple[Tensor, Tensor, Tensor], mode: str = 'avg_pool') -> Tuple[Tensor, Tensor, Tensor]
Convert a 444 tensor to a 420 representation.
- Parameters:
yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)) – 444 input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple of 3 (Nx1xHxW) tensors.
mode (str) – algorithm used for downsampling: 'avg_pool'. Default: 'avg_pool'
- Returns:
Converted 420
- Return type:
(torch.Tensor, torch.Tensor, torch.Tensor)
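The 'avg_pool' mode can be pictured as replacing each 2x2 block of a chroma plane by its mean, which halves both dimensions. The plain-Python sketch below shows that arithmetic on a single plane stored as nested lists. `avg_pool_2x2` is a hypothetical helper, and the assumption that pooling runs over non-overlapping 2x2 blocks is inferred from the 32x32 to 16x16 shapes in the YUV444To420 example on this page, not from the library source.

```python
def avg_pool_2x2(plane):
    """Average non-overlapping 2x2 blocks of a 2D plane (list of rows).

    Assumes an even number of rows and columns.
    """
    height, width = len(plane), len(plane[0])
    return [
        [
            (plane[i][j] + plane[i][j + 1] + plane[i + 1][j] + plane[i + 1][j + 1]) / 4
            for j in range(0, width, 2)
        ]
        for i in range(0, height, 2)
    ]


# A 4x4 chroma plane becomes 2x2; the luma plane would be left untouched.
u_444 = [
    [0.0, 0.2, 0.4, 0.4],
    [0.2, 0.4, 0.4, 0.4],
    [1.0, 1.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
]
u_420 = avg_pool_2x2(u_444)
```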
Point Cloud Transforms
- class compressai.transforms.point.GeneratePositionNormals(*, method='any', **kwargs)
Generates normals from node positions (functional name: generate_position_normals).
- class compressai.transforms.point.NormalizeScaleV2(*, center=True, scale_method='linf')
Centers and normalizes node positions (functional name: normalize_scale_v2).
- class compressai.transforms.point.RandomPermutation(*, attrs=('pos',))
Randomly permutes points and associated attributes (functional name: random_permutation).
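One way to read "points and associated attributes" is that a single random permutation is applied to every attribute named in `attrs`, so that per-point data stays aligned. The plain-Python sketch below illustrates that reading on a dict of lists. The dict container, the `seed` argument and this standalone `random_permutation` function are assumptions for illustration, not the library's implementation.

```python
import random


def random_permutation(data, attrs=("pos",), seed=None):
    """Apply one shared random permutation to every attribute in `attrs`.

    `data` is a plain dict of equal-length lists standing in for a point cloud.
    """
    rng = random.Random(seed)
    num_points = len(data[attrs[0]])
    perm = list(range(num_points))
    rng.shuffle(perm)

    out = dict(data)  # shallow copy; keys not listed in attrs pass through
    for name in attrs:
        out[name] = [data[name][i] for i in perm]
    return out


cloud = {
    "pos": [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)],
    "color": ["black", "red", "green", "blue"],
}
shuffled = random_permutation(cloud, attrs=("pos", "color"), seed=0)
```

Because both attributes share the same permutation, each position keeps its original colour after shuffling.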
- class compressai.transforms.point.RandomRotateFull
Randomly rotates node positions around the origin (functional name: random_rotate_full).
- class compressai.transforms.point.RandomSample(num=None, *, attrs=('pos',), remove_duplicates_by=None, preserve_order=False, seed=None, static_seed=None)
Randomly samples points and associated attributes (functional name: random_sample).
- class compressai.transforms.point.SamplePointsV2(num: int, *, remove_faces: bool = True, include_normals: bool = False, seed=None, static_seed=None)
Uniformly samples a fixed number of points on the mesh faces according to their face area (functional name: sample_points).
Adapted from PyTorch Geometric under MIT license at pyg-team/pytorch_geometric.
- Parameters:
num (int) – The number of points to sample.
remove_faces (bool, optional) – If set to False, the face tensor will not be removed. (default: True)
include_normals (bool, optional) – If set to True, then compute normals for each sampled point. (default: False)
seed (int, optional) – Initial random seed.
static_seed (int, optional) – Reset the random seed to this value on every call.
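Area-weighted sampling of this kind is commonly done in two steps: pick a face with probability proportional to its area, then pick a uniform point inside that triangle. The plain-Python sketch below illustrates the technique without torch or PyTorch Geometric. `sample_points_on_mesh` is a hypothetical helper, it omits the `remove_faces` and `include_normals` options, and it should not be taken as the library's implementation.

```python
import math
import random


def sample_points_on_mesh(pos, faces, num, seed=None):
    """Sample `num` points on a triangle mesh, proportionally to face area.

    pos:   list of (x, y, z) vertex positions
    faces: list of (i, j, k) vertex indices, one triple per triangle
    """
    rng = random.Random(seed)

    def sub(p, q):
        return tuple(a - b for a, b in zip(p, q))

    def cross(p, q):
        return (
            p[1] * q[2] - p[2] * q[1],
            p[2] * q[0] - p[0] * q[2],
            p[0] * q[1] - p[1] * q[0],
        )

    # Triangle area = half the norm of the cross product of two edges.
    edges = [(sub(pos[j], pos[i]), sub(pos[k], pos[i])) for i, j, k in faces]
    areas = [0.5 * math.sqrt(sum(c * c for c in cross(e1, e2))) for e1, e2 in edges]

    # Pick faces with probability proportional to their area.
    chosen = rng.choices(range(len(faces)), weights=areas, k=num)

    points = []
    for f in chosen:
        e1, e2 = edges[f]
        a, b = rng.random(), rng.random()
        if a + b > 1.0:  # reflect the sample back into the triangle
            a, b = 1.0 - a, 1.0 - b
        origin = pos[faces[f][0]]
        points.append(tuple(o + a * x + b * y for o, x, y in zip(origin, e1, e2)))
    return points


# Unit square in the z = 0 plane, split into two triangles.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0), (0.0, 1.0, 0.0)]
faces = [(0, 1, 2), (0, 2, 3)]
points = sample_points_on_mesh(pos, faces, num=100, seed=0)
```

Weighting by area keeps the point density uniform over the surface, so large faces are not undersampled relative to small ones.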