compressai.transforms

Transforms on Tensors

class compressai.transforms.RGB2YCbCr

Convert an RGB tensor to YCbCr. The tensor is expected to be in the [0, 1] floating point range, with a shape of (3xHxW) or (Nx3xHxW).

class compressai.transforms.YCbCr2RGB

Convert a YCbCr tensor to RGB. The tensor is expected to be in the [0, 1] floating point range, with a shape of (3xHxW) or (Nx3xHxW).
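
The per-pixel arithmetic behind these two classes can be sketched in plain Python. This sketch assumes the standard full-range ITU-R BT.709 formulas (the coefficients cited for the functional rgb2ycbcr and ycbcr2rgb below) and assumes Cb and Cr are offset by 0.5 so they share the [0, 1] range. The helper names are hypothetical, and this is not the library's tensor implementation.

```python
# Full-range ITU-R BT.709 weights; K_G is derived so the three weights sum to 1.
KR, KB = 0.2126, 0.0722
KG = 1.0 - KR - KB


def rgb_to_ycbcr(r: float, g: float, b: float) -> tuple:
    """Map one RGB pixel in [0, 1] to (Y, Cb, Cr), chroma offset by 0.5 (assumed)."""
    y = KR * r + KG * g + KB * b
    cb = 0.5 * (b - y) / (1.0 - KB) + 0.5
    cr = 0.5 * (r - y) / (1.0 - KR) + 0.5
    return y, cb, cr


def ycbcr_to_rgb(y: float, cb: float, cr: float) -> tuple:
    """Invert rgb_to_ycbcr for one pixel."""
    r = y + 2.0 * (1.0 - KR) * (cr - 0.5)
    b = y + 2.0 * (1.0 - KB) * (cb - 0.5)
    g = (y - KR * r - KB * b) / KG
    return r, g, b


if __name__ == "__main__":
    pixel = (0.8, 0.4, 0.1)
    restored = ycbcr_to_rgb(*rgb_to_ycbcr(*pixel))
    print(all(abs(p - q) < 1e-9 for p, q in zip(pixel, restored)))  # True
```

Neutral greys (R = G = B) map to Cb = Cr = 0.5, and the offset keeps both chroma channels inside [0, 1] for any RGB input in [0, 1].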

class compressai.transforms.YUV420To444(mode: str = 'bilinear', return_tuple: bool = False)

Convert a YUV 420 input to a 444 representation.

Parameters:
  • mode (str) – algorithm used for upsampling: 'bilinear' | 'nearest' (default: 'bilinear')

  • return_tuple (bool) – return the output as a tuple of three (Nx1xHxW) tensors instead of a single concatenated (Nx3xHxW) tensor (default: False)

Example

>>> import torch
>>> from compressai.transforms import YUV420To444
>>> y = torch.rand(1, 1, 32, 32)
>>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
>>> x = YUV420To444()((y, u, v))
>>> x.size()
torch.Size([1, 3, 32, 32])
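
With mode='nearest', upsampling by the 420-to-444 factor of 2 amounts to sample replication: each chroma sample fills a 2x2 block of the full-resolution grid. The pure-Python sketch below illustrates that on nested lists for a single frame; the function names are hypothetical, and the real transform operates on batched (Nx1xHxW) tensors.

```python
from typing import List

Plane = List[List[float]]


def upsample_nearest_2x(plane: Plane) -> Plane:
    """Replicate each sample into a 2x2 block (nearest-neighbour, factor 2)."""
    out: Plane = []
    for row in plane:
        wide = [value for value in row for _ in range(2)]  # repeat each column
        out.append(wide)
        out.append(list(wide))  # repeat the whole row (as a separate list)
    return out


def yuv420_to_444_nearest(y: Plane, u: Plane, v: Plane):
    """Return full-resolution (Y, U, V); luma is passed through unchanged."""
    u_full, v_full = upsample_nearest_2x(u), upsample_nearest_2x(v)
    if (len(u_full), len(u_full[0])) != (len(y), len(y[0])):
        raise ValueError("chroma planes must be half the luma height and width")
    return y, u_full, v_full


print(upsample_nearest_2x([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```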

class compressai.transforms.YUV444To420(mode: str = 'avg_pool')

Convert a YUV 444 tensor to a 420 representation.

Parameters:

mode (str) – algorithm used for downsampling; only 'avg_pool' is supported (default: 'avg_pool')

Example

>>> import torch
>>> from compressai.transforms import YUV444To420
>>> x = torch.rand(1, 3, 32, 32)
>>> y, u, v = YUV444To420()(x)
>>> y.size()
torch.Size([1, 1, 32, 32])
>>> u.size()
torch.Size([1, 1, 16, 16])
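
Assuming 'avg_pool' means a 2x2 average pool with stride 2 over the U and V planes (consistent with the 32-to-16 chroma sizes in the example), the downsampling step looks like the pure-Python sketch below. The function names are hypothetical, and the sketch rejects odd sizes instead of guessing how the library handles them.

```python
from typing import List

Plane = List[List[float]]


def avg_pool_2x2(plane: Plane) -> Plane:
    """Average each non-overlapping 2x2 block (stride 2)."""
    h, w = len(plane), len(plane[0])
    if h % 2 or w % 2:
        raise ValueError("this sketch expects an even height and width")
    return [
        [
            (plane[i][j] + plane[i][j + 1] + plane[i + 1][j] + plane[i + 1][j + 1]) / 4
            for j in range(0, w, 2)
        ]
        for i in range(0, h, 2)
    ]


def yuv444_to_420_avg(y: Plane, u: Plane, v: Plane):
    """Keep luma at full resolution; average-pool both chroma planes."""
    return y, avg_pool_2x2(u), avg_pool_2x2(v)


print(avg_pool_2x2([[1, 3, 5, 7], [1, 3, 5, 7]]))  # [[2.0, 6.0]]
```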

Functional Transforms

Functional transforms can be used to define custom transform classes.
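
A minimal sketch of that pattern: keep the computation in a plain function and wrap it in a small callable class that stores its options, mirroring how the classes above take the same options as their functional counterparts (for example mode and return_tuple). clamp and Clamp are hypothetical names used only for illustration.

```python
from typing import List, Sequence


def clamp(values: Sequence[float], low: float = 0.0, high: float = 1.0) -> List[float]:
    """Functional form: clip every value into [low, high]."""
    return [min(max(v, low), high) for v in values]


class Clamp:
    """Class form: store the options once, then delegate to the function."""

    def __init__(self, low: float = 0.0, high: float = 1.0):
        self.low, self.high = low, high

    def __call__(self, values: Sequence[float]) -> List[float]:
        return clamp(values, self.low, self.high)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(low={self.low}, high={self.high})"


print(Clamp()([-0.2, 0.5, 1.3]))  # [0.0, 0.5, 1.0]
```

Keeping the logic in the function means it can be reused directly or through the class, and the class instance can be configured once and passed around like any other transform.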

compressai.transforms.functional.rgb2ycbcr(rgb: Tensor) → Tensor

RGB to YCbCr conversion for torch Tensors, using ITU-R BT.709 coefficients.

Parameters:

rgb (torch.Tensor) – 3D or 4D floating point RGB tensor

Returns:

converted tensor

Return type:

ycbcr (torch.Tensor)
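
Written out, the standard full-range BT.709 forward transform is shown below, assuming R, G, B in [0, 1] and a +1/2 chroma offset so that the output also lies in [0, 1] (the offset is an assumption consistent with the [0, 1] range stated for the transform classes):

```latex
\begin{aligned}
Y   &= K_R R + K_G G + K_B B, \qquad (K_R, K_G, K_B) = (0.2126,\ 0.7152,\ 0.0722),\\
C_b &= \frac{B - Y}{2(1 - K_B)} + \frac{1}{2} = \frac{B - Y}{1.8556} + \frac{1}{2},\\
C_r &= \frac{R - Y}{2(1 - K_R)} + \frac{1}{2} = \frac{R - Y}{1.5748} + \frac{1}{2}.
\end{aligned}
```

For example, pure white (1, 1, 1) gives Y = 1 and Cb = Cr = 1/2.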

compressai.transforms.functional.ycbcr2rgb(ycbcr: Tensor) → Tensor

YCbCr to RGB conversion for torch Tensors, using ITU-R BT.709 coefficients.

Parameters:

ycbcr (torch.Tensor) – 3D or 4D floating point YCbCr tensor

Returns:

converted tensor

Return type:

rgb (torch.Tensor)
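
Under the standard full-range BT.709 model with a +1/2 chroma offset (an assumption consistent with the [0, 1] range stated for the transform classes), the inverse solves for R and B first, then recovers G from the luma equation Y = K_R R + K_G G + K_B B, with (K_R, K_G, K_B) = (0.2126, 0.7152, 0.0722):

```latex
\begin{aligned}
R &= Y + 2(1 - K_R)\left(C_r - \tfrac{1}{2}\right) = Y + 1.5748\left(C_r - \tfrac{1}{2}\right),\\
B &= Y + 2(1 - K_B)\left(C_b - \tfrac{1}{2}\right) = Y + 1.8556\left(C_b - \tfrac{1}{2}\right),\\
G &= \frac{Y - K_R R - K_B B}{K_G}.
\end{aligned}
```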

compressai.transforms.functional.yuv_420_to_444(yuv: Tuple[Tensor, Tensor, Tensor], mode: str = 'bilinear', return_tuple: bool = False) → Tensor | Tuple[Tensor, Tensor, Tensor]

Convert a 420 input to a 444 representation.

Parameters:
  • yuv (torch.Tensor, torch.Tensor, torch.Tensor) – 420 input frames: a (Nx1xHxW) Y frame and two (Nx1x(H/2)x(W/2)) U and V frames

  • mode (str) – algorithm used for upsampling: 'bilinear' | 'nearest' (default: 'bilinear')

  • return_tuple (bool) – return the output as a tuple of three (Nx1xHxW) tensors instead of a single concatenated (Nx3xHxW) tensor (default: False)

Returns:

Converted 444

Return type:

(torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor))
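
The default mode='bilinear' interpolates rather than replicates. The sketch below shows 2x bilinear upsampling of one chroma plane in pure Python, assuming PyTorch's half-pixel sampling convention (align_corners=False) with clamped borders; whether yuv_420_to_444 uses exactly this convention is an assumption, and the function names are hypothetical.

```python
from typing import List, Sequence

Plane = List[List[float]]


def upsample_linear_1d_2x(samples: Sequence[float]) -> List[float]:
    """2x linear upsampling with half-pixel centres and clamped borders."""
    n = len(samples)
    out = []
    for i in range(2 * n):
        src = (i + 0.5) / 2 - 0.5        # centre of output sample i in input coordinates
        src = min(max(src, 0.0), n - 1)  # clamp at both borders
        lo = int(src)                    # floor, because src >= 0
        hi = min(lo + 1, n - 1)
        w = src - lo                     # weight of the right-hand neighbour
        out.append((1 - w) * samples[lo] + w * samples[hi])
    return out


def upsample_bilinear_2x(plane: Plane) -> Plane:
    """Bilinear interpolation is separable: upsample rows, then columns."""
    wide = [upsample_linear_1d_2x(row) for row in plane]
    tall_columns = [upsample_linear_1d_2x(col) for col in zip(*wide)]
    return [list(row) for row in zip(*tall_columns)]  # transpose back


print(upsample_linear_1d_2x([0.0, 4.0]))  # [0.0, 1.0, 3.0, 4.0]
```

The new samples land a quarter of the way between the originals, and the outermost samples repeat the border values.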

compressai.transforms.functional.yuv_444_to_420(yuv: Tensor | Tuple[Tensor, Tensor, Tensor], mode: str = 'avg_pool') → Tuple[Tensor, Tensor, Tensor]

Convert a 444 tensor to a 420 representation.

Parameters:
  • yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)) – 444 input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple of 3 (Nx1xHxW) tensors.

  • mode (str) – algorithm used for downsampling; only 'avg_pool' is supported (default: 'avg_pool')

Returns:

Converted 420

Return type:

(torch.Tensor, torch.Tensor, torch.Tensor)

Point Cloud Transforms

class compressai.transforms.point.GeneratePositionNormals(*, method='any', **kwargs)

Generates normals from node positions (functional name: generate_position_normals).

class compressai.transforms.point.NormalizeScaleV2(*, center=True, scale_method='linf')

Centers and normalizes node positions (functional name: normalize_scale_v2).
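
Assuming center=True subtracts the centroid and scale_method='linf' divides by the largest absolute coordinate (so the cloud fits inside [-1, 1]^3), the transform reduces to the pure-Python sketch below. The exact centring and scaling reductions used by the library are assumptions, and normalize_scale_linf is a hypothetical name.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def normalize_scale_linf(points: Sequence[Point], center: bool = True) -> List[Point]:
    """Optionally subtract the centroid, then divide by the largest |coordinate|."""
    pts = [tuple(float(c) for c in p) for p in points]
    if center:
        n = len(pts)
        centroid = [sum(p[k] for p in pts) / n for k in range(3)]
        pts = [tuple(p[k] - centroid[k] for k in range(3)) for p in pts]
    scale = max(abs(c) for p in pts for c in p)  # L-infinity radius of the cloud
    if scale == 0.0:
        return pts  # degenerate cloud (all points coincide): nothing to rescale
    return [tuple(c / scale for c in p) for p in pts]
```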

class compressai.transforms.point.RandomPermutation(*, attrs=('pos',))

Randomly permutes points and associated attributes (functional name: random_permutation).
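
The key property is that every attribute listed in attrs is reordered with the same permutation, so row i of pos still matches row i of any other per-point attribute. The pure-Python sketch below works on a dict of lists, a hypothetical stand-in for the library's point-cloud data object; the function name and its seed argument are also hypothetical.

```python
import random
from typing import Dict, Optional, Sequence


def random_permutation(data: Dict[str, list], attrs: Sequence[str] = ("pos",),
                       seed: Optional[int] = None) -> Dict[str, list]:
    """Reorder every attribute in `attrs` with one shared random permutation."""
    n = len(data[attrs[0]])
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    out = dict(data)  # shallow copy: attributes not listed are carried over as-is
    for key in attrs:
        if len(data[key]) != n:
            raise ValueError(f"attribute {key!r} has {len(data[key])} rows, expected {n}")
        out[key] = [data[key][i] for i in perm]
    return out
```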

class compressai.transforms.point.RandomRotateFull

Randomly rotates node positions around the origin (functional name: random_rotate_full).
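
One standard way to draw a rotation uniformly over all 3D orientations is Shoemake's random-unit-quaternion method, sketched below in pure Python. It is not necessarily how RandomRotateFull samples its rotation; the function names and the seed argument are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def random_rotation_matrix(rng: random.Random) -> List[List[float]]:
    """Uniformly random 3x3 rotation built from a random unit quaternion (Shoemake)."""
    u1, u2, u3 = rng.random(), rng.random(), rng.random()
    a, b = math.sqrt(1.0 - u1), math.sqrt(u1)
    w = a * math.sin(2 * math.pi * u2)
    x = a * math.cos(2 * math.pi * u2)
    y = b * math.sin(2 * math.pi * u3)
    z = b * math.cos(2 * math.pi * u3)
    # Standard rotation matrix of the unit quaternion (w, x, y, z).
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def rotate_points(points: Sequence[Vec3], seed: Optional[int] = None) -> List[Vec3]:
    """Apply one random rotation about the origin to every point."""
    r = random_rotation_matrix(random.Random(seed))
    return [tuple(sum(r[i][k] * p[k] for k in range(3)) for i in range(3)) for p in points]
```

Because the same matrix is applied to every point, distances to the origin and between points are preserved, so the cloud's shape is unchanged.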

class compressai.transforms.point.RandomSample(num=None, *, attrs=('pos',), remove_duplicates_by=None, preserve_order=False, seed=None, static_seed=None)

Randomly samples points and associated attributes (functional name: random_sample).
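
The interplay of num, preserve_order, seed and static_seed can be sketched with a hypothetical stand-in class. It assumes seed seeds a generator once (successive calls differ, but the sequence is reproducible) and static_seed re-seeds before every call (every call returns the same subset), as documented for SamplePointsV2 below. It also assumes num is capped at the cloud size, and it omits remove_duplicates_by.

```python
import random
from typing import Dict, Optional, Sequence


class RandomSampleSketch:
    """Hypothetical stand-in: keep `num` randomly chosen points (no replacement)."""

    def __init__(self, num: int, attrs: Sequence[str] = ("pos",), preserve_order: bool = False,
                 seed: Optional[int] = None, static_seed: Optional[int] = None):
        self.num = num
        self.attrs = tuple(attrs)
        self.preserve_order = preserve_order
        self.static_seed = static_seed
        self.rng = random.Random(seed)  # seeded once; later calls continue the stream

    def __call__(self, data: Dict[str, list]) -> Dict[str, list]:
        if self.static_seed is not None:
            self.rng = random.Random(self.static_seed)  # reset: identical draw every call
        n = len(data[self.attrs[0]])
        idx = self.rng.sample(range(n), min(self.num, n))  # assumption: cap at n
        if self.preserve_order:
            idx.sort()  # keep the points in their original relative order
        out = dict(data)
        for key in self.attrs:
            out[key] = [data[key][i] for i in idx]
        return out
```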

class compressai.transforms.point.SamplePointsV2(num: int, *, remove_faces: bool = True, include_normals: bool = False, seed=None, static_seed=None)

Uniformly samples a fixed number of points on the mesh faces according to their face area (functional name: sample_points).

Adapted from PyTorch Geometric (pyg-team/pytorch_geometric) under the MIT license.

Parameters:
  • num (int) – The number of points to sample.

  • remove_faces (bool, optional) – If set to False, the face tensor will not be removed. (default: True)

  • include_normals (bool, optional) – If set to True, then compute normals for each sampled point. (default: False)

  • seed (int, optional) – Initial random seed.

  • static_seed (int, optional) – Reset random seed to this every call.
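
The area weighting can be sketched in plain Python: choose faces with probability proportional to their area, then draw two uniform numbers per sample and reflect pairs with s + t > 1 back into the triangle, which yields a uniform point on that face. This is the usual recipe, and the one PyTorch Geometric's SamplePoints follows; that SamplePointsV2 keeps it is assumed here. The function names are hypothetical, and the sketch omits remove_faces, include_normals and static_seed.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def triangle_area(p0: Vec3, p1: Vec3, p2: Vec3) -> float:
    """Half the norm of the cross product of two edge vectors."""
    e1 = [p1[d] - p0[d] for d in range(3)]
    e2 = [p2[d] - p0[d] for d in range(3)]
    cx = e1[1] * e2[2] - e1[2] * e2[1]
    cy = e1[2] * e2[0] - e1[0] * e2[2]
    cz = e1[0] * e2[1] - e1[1] * e2[0]
    return 0.5 * math.sqrt(cx * cx + cy * cy + cz * cz)


def sample_points_on_mesh(pos: Sequence[Vec3], faces: Sequence[Tuple[int, int, int]],
                          num: int, seed: Optional[int] = None) -> List[Vec3]:
    """Pick faces with probability proportional to area, then a uniform point on each."""
    rng = random.Random(seed)
    areas = [triangle_area(pos[i], pos[j], pos[k]) for i, j, k in faces]
    chosen = rng.choices(range(len(faces)), weights=areas, k=num)
    samples = []
    for f in chosen:
        p0, p1, p2 = (pos[v] for v in faces[f])
        s, t = rng.random(), rng.random()
        if s + t > 1.0:  # fold the far half of the unit square back into the triangle
            s, t = 1.0 - s, 1.0 - t
        samples.append(tuple(p0[d] + s * (p1[d] - p0[d]) + t * (p2[d] - p0[d]) for d in range(3)))
    return samples
```

Zero-area (degenerate) faces get zero weight and are never selected.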

class compressai.transforms.point.ToDict(*, wrapper='dict')

Convert Mapping[str, Any] (functional name: to_dict).