# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import annotations

import torch
import torch.nn as nn

from compressai.latent_codecs import EntropyBottleneckLatentCodec
from compressai.layers.basic import Gain, Interleave, Reshape, Transpose
from compressai.layers.pointcloud.pointnet import GAIN
from compressai.layers.pointcloud.pointnet2 import PointNetSetAbstraction
from compressai.layers.pointcloud.pointnet2_sfu import UpsampleBlock
from compressai.models import CompressionModel
from compressai.registry import register_model

__all__ = [
    "PointNet2SsgReconstructionPccModel",
]


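# Overview: "down" applies a hierarchy of PointNet++ set abstraction levels;
# "h_a" maps each level's grouped features to latents, which are coded by
# per-level entropy bottlenecks; "h_s" and the "up" blocks then invert the
# hierarchy to reconstruct the point cloud.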
@register_model("sfu2024-pcc-rec-pointnet2-ssg")
class PointNet2SsgReconstructionPccModel(CompressionModel):
    """PointNet++-based PCC reconstruction model.

    Model based on PointNet++ [Qi2017PointNetPlusPlus]_, and modified for
    compression by [Ulhaq2024]_. Uses single-scale grouping (SSG) for point
    set abstraction.

    References:

        .. [Qi2017PointNetPlusPlus] `"PointNet++: Deep Hierarchical Feature
            Learning on Point Sets in a Metric Space"
            <https://arxiv.org/abs/1706.02413>`_, by Charles R. Qi, Li Yi,
            Hao Su, and Leonidas J. Guibas, NIPS 2017.

        .. [Ulhaq2024] `"Scalable Human-Machine Point Cloud Compression"
            <TODO>`_, by Mateen Ulhaq and Ivan V. Bajić, PCS 2024.
    """

    def __init__(
        self,
        num_points=1024,
        num_classes=40,
        D=(0, 128, 192, 256),
        P=(1024, 256, 64, 1),
        S=(None, 4, 4, 64),
        R=(None, 0.2, 0.4, None),
        E=(3, 64, 32, 16, 0),
        M=(0, 0, 64, 64),
        normal_channel=False,
    ):
        """
        Args:
            num_points: Number of input points. [unused]
            num_classes: Number of object classes. [unused]
            D: Number of feature channels at each downsampling level.
            P: Number of points remaining after each downsampling level.
            S: Number of samples per centroid at each level.
            R: Radius of the ball used to query points at each level.
            E: Number of output feature channels after each upsample.
            M: Number of latent channels at each level's bottleneck.
            normal_channel: Whether the input includes normals.
        """
        super().__init__()

        self.num_points = num_points
        self.num_classes = num_classes
        self.D = D
        self.P = P
        self.S = S
        self.R = R
        self.E = E
        self.M = M
        self.normal_channel = bool(normal_channel)

        # Original PointNet++ architecture:
        # D = [3 * self.normal_channel, 128, 256, 1024]
        # P = [None, 512, 128, 1]
        # S = [None, 32, 64, 128]
        # R = [None, 0.2, 0.4, None]

        # NOTE: P[0] is only used to determine the number of output points.
        # assert P[0] == num_points
        assert P[0] == P[1] * S[1]
        assert P[1] == P[2] * S[2]
        assert P[2] == P[3] * S[3]

        self.levels = 4

        self.down = nn.ModuleDict(
            {
                "_1": PointNetSetAbstraction(
                    npoint=P[1],
                    radius=R[1],
                    nsample=S[1],
                    in_channel=D[0] + 3,
                    mlp=[D[1] // 2, D[1] // 2, D[1]],
                    group_all=False,
                ),
                "_2": PointNetSetAbstraction(
                    npoint=P[2],
                    radius=R[2],
                    nsample=S[2],
                    in_channel=D[1] + 3,
                    mlp=[D[1], D[1], D[2]],
                    group_all=False,
                ),
                "_3": PointNetSetAbstraction(
                    npoint=None,
                    radius=None,
                    nsample=None,
                    in_channel=D[2] + 3,
                    mlp=[D[2], D[3], D[3]],
                    group_all=True,
                ),
            }
        )

        i_final = self.levels - 1
        groups_h_final = 1 if D[i_final] * M[i_final] <= 2**16 else 4

        self.h_a = nn.ModuleDict(
            {
                **{
                    f"_{i}": nn.Sequential(
                        Reshape((D[i] + 3, P[i + 1] * S[i + 1])),
                        nn.Conv1d(D[i] + 3, M[i], 1),
                        Gain((M[i], 1), factor=GAIN),
                    )
                    for i in range(self.levels - 1)
                    if M[i] > 0
                },
                f"_{i_final}": nn.Sequential(
                    Reshape((D[i_final], 1)),
                    nn.Conv1d(D[i_final], M[i_final], 1, groups=groups_h_final),
                    Interleave(groups=groups_h_final),
                    Gain((M[i_final], 1), factor=GAIN),
                ),
            }
        )

        self.h_s = nn.ModuleDict(
            {
                **{
                    f"_{i}": nn.Sequential(
                        Gain((M[i], 1), factor=1 / GAIN),
                        nn.Conv1d(M[i], D[i] + 3, 1),
                    )
                    for i in range(self.levels - 1)
                    if M[i] > 0
                },
                f"_{i_final}": nn.Sequential(
                    Gain((M[i_final], 1), factor=1 / GAIN),
                    nn.Conv1d(M[i_final], D[i_final], 1, groups=groups_h_final),
                    Interleave(groups=groups_h_final),
                ),
            }
        )

        self.up = nn.ModuleDict(
            {
                "_0": nn.Sequential(
                    nn.Conv1d(E[1] + D[0] + 3 * bool(M[0]), E[1], 1),
                    # nn.BatchNorm1d(E[1]),
                    nn.ReLU(inplace=True),
                    nn.Conv1d(E[1], E[0], 1),
                    Reshape((E[0], P[0])),
                    Transpose(-2, -1),
                ),
                "_1": UpsampleBlock(D, E, M, P, S, i=1, extra_in_ch=3, groups=(1, 4)),
                "_2": UpsampleBlock(D, E, M, P, S, i=2, extra_in_ch=3, groups=(1, 4)),
                "_3": UpsampleBlock(D, E, M, P, S, i=3, extra_in_ch=0, groups=(1, 4)),
            }
        )
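
        # Each level i with M[i] > 0 latent channels gets its own factorized
        # entropy bottleneck for coding y_i; levels with M[i] == 0 transmit
        # no features and are reconstructed from the upsampling path alone.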
        self.latent_codec = nn.ModuleDict(
            {
                f"_{i}": EntropyBottleneckLatentCodec(channels=M[i], tail_mass=1e-4)
                for i in range(self.levels)
                if M[i] > 0
            }
        )

    def forward(self, input):
        xyz, norm = self._get_inputs(input)
        y_out_, u_, uu_ = self._compress(xyz, norm, mode="forward")
        x_hat, y_hat_, v_ = self._decompress(y_out_, mode="forward")

        return {
            "x_hat": x_hat,
            "likelihoods": {
                f"y_{i}": y_out_[i]["likelihoods"]["y"]
                for i in range(self.levels)
                if "likelihoods" in y_out_[i]
            },
            # Additional outputs:
            "debug_outputs": {
                **{f"u_{i}": v for i, v in u_.items() if v is not None},
                **{f"uu_{i}": v for i, v in uu_.items()},
                **{f"y_hat_{i}": v for i, v in y_hat_.items()},
                **{f"v_{i}": v for i, v in v_.items() if v.numel() > 0},
            },
        }

    def compress(self, input):
        xyz, norm = self._get_inputs(input)
        y_out_, _, _ = self._compress(xyz, norm, mode="compress")

        return {
            # "strings": {f"y_{i}": y_out_[i]["strings"] for i in range(self.levels)},
            # Flatten nested structure into list[list[str]]:
            "strings": [
                ss for level in range(self.levels) for ss in y_out_[level]["strings"]
            ],
            "shape": {f"y_{i}": y_out_[i]["shape"] for i in range(self.levels)},
        }

    def decompress(self, strings, shape):
        y_inputs_ = {
            i: {
                "strings": [strings[i]],
                "shape": shape[f"y_{i}"],
            }
            for i in range(self.levels)
        }
        x_hat, _, _ = self._decompress(y_inputs_, mode="decompress")
        return {
            "x_hat": x_hat,
        }

    def _get_inputs(self, input):
        points = input["pos"].transpose(-2, -1)
        if self.normal_channel:
            xyz = points[:, :3, :]
            norm = points[:, 3:, :]
        else:
            xyz = points
            norm = None
        return xyz, norm

    def _compress(self, xyz, norm, *, mode):
        lc_func = {"forward": lambda lc: lc, "compress": lambda lc: lc.compress}[mode]
        B, _, _ = xyz.shape
        xyz_ = {0: xyz}
        u_ = {0: norm}
        uu_ = {}
        y_ = {}
        y_out_ = {}

        for i in range(1, self.levels):
            down_out_i = self.down[f"_{i}"](xyz_[i - 1], u_[i - 1])
            xyz_[i] = down_out_i["new_xyz"]
            u_[i] = down_out_i["new_features"]
            uu_[i - 1] = down_out_i["grouped_features"]

        uu_[self.levels - 1] = u_[self.levels - 1][:, :, None, :]

        for i in reversed(range(0, self.levels)):
            if self.M[i] == 0:
                y_out_[i] = {"strings": [[b""] * B], "shape": ()}
                continue
            y_[i] = self.h_a[f"_{i}"](uu_[i])
            # NOTE: Reshape 1D -> 2D since latent codecs expect 2D inputs.
            y_out_[i] = lc_func(self.latent_codec[f"_{i}"])(y_[i][..., None])

        return y_out_, u_, uu_

    def _decompress(self, y_inputs_, *, mode):
        y_hat_ = {}
        y_out_ = {}
        uu_hat_ = {}
        v_ = {}

        for i in reversed(range(0, self.levels)):
            if self.M[i] == 0:
                continue
            if mode == "forward":
                y_out_[i] = y_inputs_[i]
            elif mode == "decompress":
                y_out_[i] = self.latent_codec[f"_{i}"].decompress(
                    y_inputs_[i]["strings"], shape=y_inputs_[i]["shape"]
                )
            # NOTE: Reshape 2D -> 1D since latent codecs return 2D outputs.
            y_hat_[i] = y_out_[i]["y_hat"].squeeze(-1)
            uu_hat_[i] = self.h_s[f"_{i}"](y_hat_[i])

        B, _, *tail = uu_hat_[self.levels - 1].shape
        v_[self.levels] = uu_hat_[self.levels - 1].new_zeros((B, 0, *tail))

        for i in reversed(range(0, self.levels)):
            v_[i] = self.up[f"_{i}"](
                v_[i + 1]
                if self.M[i] == 0
                else torch.cat([v_[i + 1], uu_hat_[i]], dim=1)
            )

        x_hat = v_[0]
        return x_hat, y_hat_, v_
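

# A minimal usage sketch (an illustration added for this listing, not part of
# the library API): it runs a random point cloud through the three entry
# points. The input format (a dict with a "pos" tensor of shape [B, N, 3])
# follows `_get_inputs` above; printed shapes assume the default P[0] = 1024.
if __name__ == "__main__":
    model = PointNet2SsgReconstructionPccModel().eval()
    model.update()  # Build the entropy coder CDF tables before coding.
    x = {"pos": torch.rand(1, 1024, 3)}

    # Training-style forward pass: reconstruction plus per-level likelihoods.
    out = model(x)
    print(out["x_hat"].shape)  # torch.Size([1, 1024, 3])
    print(sorted(out["likelihoods"]))  # ['y_2', 'y_3'] with the default M

    # Entropy-coded round trip through compress/decompress.
    with torch.no_grad():
        enc = model.compress(x)
        dec = model.decompress(enc["strings"], enc["shape"])
    print(dec["x_hat"].shape)  # torch.Size([1, 1024, 3])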