# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Code adapted from https://github.com/yunhe20/D-PCC
from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck
from compressai.latent_codecs import EntropyBottleneckLatentCodec
from compressai.layers.pointcloud.hrtzxf2022 import (
DownsampleLayer,
EdgeConv,
RefineLayer,
UpsampleLayer,
UpsampleNumLayer,
nearby_distance_sum,
)
from compressai.layers.pointcloud.utils import select_xyzs_and_feats
from compressai.models import CompressionModel
from compressai.registry import register_model
# Public API of this module: only the registered model class is exported.
__all__ = ["DensityPreservingReconstructionPccModel"]
# NOTE: a "[docs]" Sphinx source-link marker (page-extraction residue, not code) stood here.
@register_model("hrtzxf2022-pcc-rec")
class DensityPreservingReconstructionPccModel(CompressionModel):
    """Density-preserving deep point cloud compression.

    Model introduced by [He2022pcc]_.

    The model runs a point-wise ``pre_conv`` on the input, an ``Encoder``
    made of downsampling stages, two latent codecs (one for the latent
    features, one for the latent point positions), and a ``Decoder`` made
    of upsampling/refinement stages.

    Args:
        downsample_rate: One rate per encoder stage; its length sets the
            number of encoder and decoder stages.
        candidate_upsample_rate: One value per decoder stage, forwarded to
            ``UpsampleLayer`` and ``UpsampleNumLayer``.
        in_dim: Number of input channels fed to ``pre_conv``. When
            ``compress_normal`` is true the input holds positions *and*
            normals, so this presumably needs to be 6 -- confirm with the
            training configuration.
        feat_dim: Channel count of the per-point features.
        hidden_dim: Hidden channel count used by the sub-networks.
        k: Forwarded to the encoder/decoder layers and to the xyz codec
            (presumably a neighborhood size -- see those layers).
        ngroups: Number of groups for the ``GroupNorm`` layers.
        sub_point_conv_mode: Forwarded to the decoder layers.
        compress_normal: If true, ``input["normal"]`` is concatenated to
            ``input["pos"]`` channel-wise and fed through the model.
        latent_xyzs_codec: Optional dict of extra keyword arguments for
            ``XyzsLatentCodec`` (e.g. ``mode``, ``conv_mode``).
        **kwargs: Accepted for configuration compatibility and ignored.

    References:

        .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression"
            <https://arxiv.org/abs/2204.12684>`_, by Yun He, Xinlin Ren,
            Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu,
            CVPR 2022.
    """

    def __init__(
        self,
        downsample_rate=(1 / 3, 1 / 3, 1 / 3),
        candidate_upsample_rate=(8, 8, 8),
        in_dim=3,
        feat_dim=8,
        hidden_dim=64,
        k=16,
        ngroups=1,
        sub_point_conv_mode="mlp",
        compress_normal=False,
        latent_xyzs_codec=None,
        **kwargs,
    ):
        super().__init__()

        self.compress_normal = compress_normal

        # Point-wise (kernel size 1) embedding of the raw input channels.
        self.pre_conv = nn.Sequential(
            nn.Conv1d(in_dim, hidden_dim, 1),
            nn.GroupNorm(ngroups, hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, feat_dim, 1),
        )

        self.encoder = Encoder(
            downsample_rate,
            feat_dim,
            hidden_dim,
            k,
            ngroups,
        )

        self.decoder = Decoder(
            downsample_rate,
            candidate_upsample_rate,
            feat_dim,
            hidden_dim,
            k,
            sub_point_conv_mode,
            compress_normal,
        )

        self.latent_codec = nn.ModuleDict(
            {
                "feat": EntropyBottleneckLatentCodec(channels=feat_dim),
                "xyz": XyzsLatentCodec(
                    feat_dim, hidden_dim, k, ngroups, **(latent_xyzs_codec or {})
                ),
            }
        )

    def _prepare_input(self, input):
        """Split the input dict into channel-first tensors.

        ``input["pos"]`` (and ``input["normal"]`` when normals are
        compressed) are expected in point-major layout ``(b, n, 3)``.

        Returns:
            Tuple ``(xyzs, gt_normals, feats)`` where ``xyzs`` is
            ``(b, 3, n)``, ``gt_normals`` holds channels 3..5 (empty along
            the channel axis when normals are not compressed) and ``feats``
            is the full channel-first input.
        """
        input_data = [input["pos"]]
        if self.compress_normal:
            input_data.append(input["normal"])
        # The tensors are still (b, n, c) here, so normals must be appended
        # along the last (channel) axis. Concatenating along dim=1 would
        # stack them as extra points and leave the 3:6 channel slice below
        # empty.
        input_data = torch.cat(input_data, dim=-1).permute(0, 2, 1).contiguous()
        xyzs = input_data[:, :3].contiguous()
        # ``compress_normal`` acts as 0/1 here: the slice is 3:3 or 3:6.
        gt_normals = input_data[:, 3 : 3 + 3 * self.compress_normal].contiguous()
        feats = input_data
        return xyzs, gt_normals, feats

    def forward(self, input):
        """Run the full encode / entropy-model / decode pipeline.

        Returns:
            Dict with the reconstruction (``"x_hat"``, point-major), the
            per-stage decoder outputs, the encoder-side ground-truth
            quantities, and the ``"likelihoods"`` of both latents.
        """
        # xyzs: (b, 3, n)
        xyzs, gt_normals, feats = self._prepare_input(input)

        feats = self.pre_conv(feats)

        gt_xyzs_, gt_dnums_, gt_mdis_, latent_xyzs, latent_feats = self.encoder(
            xyzs, feats
        )

        gt_latent_xyzs = latent_xyzs

        # NOTE: Temporarily reshape to (b, c, m, 1) for compatibility.
        latent_feats = latent_feats.unsqueeze(-1)
        latent_feats_out = self.latent_codec["feat"](latent_feats)
        latent_feats_hat = latent_feats_out["y_hat"].squeeze(-1)

        latent_xyzs_out = self.latent_codec["xyz"](latent_xyzs)
        latent_xyzs_hat = latent_xyzs_out["y_hat"]

        xyzs_hat_, unums_hat_, mdis_hat_, feats_hat = self.decoder(
            latent_xyzs_hat, latent_feats_hat
        )

        # Permute final xyzs_hat back to (b, n, c)
        xyzs_hat = xyzs_hat_[-1].permute(0, 2, 1).contiguous()

        return {
            "x_hat": xyzs_hat,
            "xyz_hat_": xyzs_hat_,
            "latent_xyz_hat": latent_xyzs_hat,
            "feat_hat": feats_hat,
            "upsample_num_hat_": unums_hat_,
            "mean_distance_hat_": mdis_hat_,
            "gt_xyz_": gt_xyzs_,
            "gt_latent_xyz": gt_latent_xyzs,
            "gt_normal": gt_normals,
            "gt_downsample_num_": gt_dnums_,
            "gt_mean_distance_": gt_mdis_,
            "likelihoods": {
                "latent_feat": latent_feats_out["likelihoods"]["y"],
                "latent_xyz": latent_xyzs_out["likelihoods"]["y"],
            },
        }

    def compress(self, input):
        """Encode ``input`` into byte strings.

        Returns:
            Dict with ``"strings"`` and ``"shape"``, each a two-element
            list ordered ``[features, xyz]``.
        """
        xyzs, _, feats = self._prepare_input(input)

        feats = self.pre_conv(feats)

        _, _, _, latent_xyzs, latent_feats = self.encoder(xyzs, feats)

        # Same (b, c, m, 1) layout as in ``forward``.
        latent_feats = latent_feats.unsqueeze(-1)
        latent_feats_out = self.latent_codec["feat"].compress(latent_feats)

        latent_xyzs_out = self.latent_codec["xyz"].compress(latent_xyzs)

        return {
            "strings": [
                latent_feats_out["strings"],
                latent_xyzs_out["strings"],
            ],
            "shape": [
                latent_feats_out["shape"],
                latent_xyzs_out["shape"],
            ],
        }

    def decompress(self, strings, shape):
        """Decode the output of :meth:`compress`.

        Args:
            strings: Two-element list ordered ``[features, xyz]``.
            shape: Matching two-element sequence of latent shapes.

        Returns:
            Dict with ``"x_hat"`` (point-major reconstruction) and
            ``"feat_hat"`` (final decoder features).
        """
        assert isinstance(strings, list) and len(strings) == 2

        latent_feats_out = self.latent_codec["feat"].decompress(strings[0], shape[0])
        latent_feats_hat = latent_feats_out["y_hat"].squeeze(-1)

        latent_xyzs_out = self.latent_codec["xyz"].decompress(strings[1], shape[1])
        latent_xyzs_hat = latent_xyzs_out["y_hat"]

        xyzs_hat_, _, _, feats_hat = self.decoder(latent_xyzs_hat, latent_feats_hat)

        # Permute final xyzs_hat back to (b, n, c)
        xyzs_hat = xyzs_hat_[-1].permute(0, 2, 1).contiguous()

        return {
            "x_hat": xyzs_hat,
            "feat_hat": feats_hat,
        }
class XyzsLatentCodec(nn.Module):
    """Codec for the latent point positions, shaped ``(b, 3, n)``.

    Two modes are implemented:

    * ``"learned"``: an analysis transform maps the 3 coordinates to
      ``dim`` channels, an ``EntropyBottleneck`` codes them, and a
      synthesis transform maps them back to 3 channels.
    * ``"float16"``: coordinates are stored verbatim as big-endian
      half-precision floats.

    Any other ``mode`` is accepted by the constructor, but ``forward``,
    ``compress`` and ``decompress`` raise ``ValueError`` for it.

    Args:
        dim: Channel count of the learned latent.
        hidden_dim: Hidden channel count of the transforms.
        k: Forwarded to ``EdgeConv`` (only used with
            ``conv_mode="edge_conv"``).
        ngroups: ``GroupNorm`` groups (only used with ``conv_mode="mlp"``).
        mode: ``"learned"`` or ``"float16"``.
        conv_mode: ``"mlp"`` or ``"edge_conv"``; only read in learned mode.

    Raises:
        ValueError: In learned mode, if ``conv_mode`` is unknown.
    """

    def __init__(self, dim, hidden_dim, k, ngroups, mode="learned", conv_mode="mlp"):
        super().__init__()
        self.mode = mode
        if mode == "learned":
            if conv_mode == "edge_conv":
                self.analysis = EdgeConv(3, dim, hidden_dim, k)
                self.synthesis = EdgeConv(dim, 3, hidden_dim, k)
            elif conv_mode == "mlp":
                self.analysis = nn.Sequential(
                    nn.Conv1d(3, hidden_dim, 1),
                    nn.GroupNorm(ngroups, hidden_dim),
                    nn.ReLU(inplace=True),
                    nn.Conv1d(hidden_dim, dim, 1),
                )
                self.synthesis = nn.Sequential(
                    nn.Conv1d(dim, hidden_dim, 1),
                    nn.GroupNorm(ngroups, hidden_dim),
                    nn.ReLU(inplace=True),
                    nn.Conv1d(hidden_dim, 3, 1),
                )
            else:
                raise ValueError(f"Unknown conv_mode: {conv_mode}")
            self.entropy_bottleneck = EntropyBottleneck(dim)
        else:
            # Empty parameter whose only job is to track the module's device
            # (read in ``decompress``).
            self.placeholder = nn.Parameter(torch.empty(0))

    def forward(self, latent_xyzs):
        """Return ``{"likelihoods": {"y": ...}, "y_hat": ...}``.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        if self.mode == "learned":
            z = self.analysis(latent_xyzs)
            z_hat, z_likelihoods = self.entropy_bottleneck(z)
            latent_xyzs_hat = self.synthesis(z_hat)
        elif self.mode == "float16":
            # Constant likelihood of 2**-16 per element; presumably so that a
            # -log2 rate estimate charges 16 bits per stored value.
            z_likelihoods = latent_xyzs.new_full(latent_xyzs.shape, 2**-16)
            # Round-trip through half precision to mimic the stored values.
            latent_xyzs_hat = latent_xyzs.to(torch.float16).float()
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
        return {"likelihoods": {"y": z_likelihoods}, "y_hat": latent_xyzs_hat}

    def compress(self, latent_xyzs):
        """Encode ``latent_xyzs`` to byte strings.

        Returns:
            Dict with ``"strings"`` (a one-element list holding the
            per-sample strings), ``"shape"`` (the latent's dimensions after
            batch and channel) and ``"y_hat"`` (the reconstruction).

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        if self.mode == "learned":
            z = self.analysis(latent_xyzs)
            shape = z.shape[2:]
            z_strings = self.entropy_bottleneck.compress(z)
            z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
            latent_xyzs_hat = self.synthesis(z_hat)
        elif self.mode == "float16":
            z = latent_xyzs
            shape = z.shape[2:]
            z_hat = latent_xyzs.to(torch.float16)
            # One big-endian float16 buffer per batch element; each buffer
            # holds all channels of that element. ``detach`` keeps this
            # working when the input is part of an autograd graph.
            z_strings = [
                np.ascontiguousarray(x, dtype=">f2").tobytes()
                for x in z_hat.detach().cpu().numpy()
            ]
            latent_xyzs_hat = z_hat.float()
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
        return {"strings": [z_strings], "shape": shape, "y_hat": latent_xyzs_hat}

    def decompress(self, strings, shape):
        """Decode the ``"strings"`` / ``"shape"`` produced by :meth:`compress`.

        Returns:
            Dict with ``"y_hat"``, the reconstructed latent positions.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        [z_strings] = strings
        if self.mode == "learned":
            z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
            latent_xyzs_hat = self.synthesis(z_hat)
        elif self.mode == "float16":
            # ``shape`` excludes the channel axis (it is ``z.shape[2:]``),
            # while each string contains every channel, so the channel axis
            # has to be recovered when reshaping.
            z_hat = [
                np.frombuffer(s, dtype=">f2").reshape(-1, *shape) for s in z_strings
            ]
            # Convert to native byte order: torch.from_numpy rejects arrays
            # whose byte order differs from the machine's.
            z_hat = np.stack(z_hat).astype(np.float16)
            z_hat = torch.from_numpy(z_hat).to(self.placeholder.device)
            latent_xyzs_hat = z_hat.float()
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
        return {"y_hat": latent_xyzs_hat}
class Encoder(nn.Module):
    """Stack of ``DownsampleLayer`` stages, one per entry of ``downsample_rate``."""

    def __init__(self, downsample_rate, dim, hidden_dim, k, ngroups):
        super().__init__()
        self.downsample_layers = nn.ModuleList(
            [
                DownsampleLayer(rate, dim, hidden_dim, k, ngroups)
                for rate in downsample_rate
            ]
        )

    def forward(self, xyzs, feats):
        """Downsample stage by stage, recording per-stage ground truth.

        Args:
            xyzs: Point positions, ``(b, 3, n)``.
            feats: Point features, ``(b, c, n)``.

        Returns:
            Tuple ``(gt_xyzs_, gt_dnums_, gt_mdis_, latent_xyzs,
            latent_feats)``: the positions fed into each stage, each
            stage's downsample counts and mean distances, and the positions
            and features left after the last stage.
        """
        stage_inputs = []
        stage_dnums = []
        stage_mdis = []

        for stage in self.downsample_layers:
            # Remember what this stage saw before it shrinks the point set.
            stage_inputs.append(xyzs)
            xyzs, feats, dnum, mdis = stage(xyzs, feats)
            stage_dnums.append(dnum)
            stage_mdis.append(mdis)

        # Whatever remains after the final stage is the latent representation.
        return stage_inputs, stage_dnums, stage_mdis, xyzs, feats
class Decoder(nn.Module):
    """Upsample / select / refine stages mirroring the encoder.

    One ``UpsampleLayer``, ``UpsampleNumLayer`` and ``RefineLayer`` is
    built per entry of ``downsample_rate``.
    """

    def __init__(
        self,
        downsample_rate,
        candidate_upsample_rate,
        dim,
        hidden_dim,
        k,
        sub_point_conv_mode,
        compress_normal,
    ):
        super().__init__()

        self.k = k
        self.compress_normal = compress_normal
        self.num_layers = len(downsample_rate)
        self.downsample_rate = downsample_rate

        stage_ids = range(self.num_layers)
        final_stage = self.num_layers - 1

        # The three lists are built one after another (not interleaved) so
        # that sub-modules are created in the same order as before.
        upsamplers = [
            UpsampleLayer(
                dim,
                hidden_dim,
                k,
                sub_point_conv_mode,
                candidate_upsample_rate[stage],
            )
            for stage in stage_ids
        ]
        self.upsample_layers = nn.ModuleList(upsamplers)

        num_predictors = [
            UpsampleNumLayer(dim, hidden_dim, candidate_upsample_rate[stage])
            for stage in stage_ids
        ]
        self.upsample_num_layers = nn.ModuleList(num_predictors)

        # The last positional argument is only truthy for the final stage,
        # and only when normals are compressed.
        refiners = [
            RefineLayer(
                dim,
                hidden_dim,
                k,
                sub_point_conv_mode,
                compress_normal and stage == final_stage,
            )
            for stage in stage_ids
        ]
        self.refine_layers = nn.ModuleList(refiners)

    def forward(self, xyzs, feats):
        """Progressively upsample the latent point set.

        Args:
            xyzs: Latent positions, ``(b, 3, n)``.
            feats: Latent features, ``(b, c, n)``.

        Returns:
            Tuple ``(xyzs_hat_, unums_hat_, mdis_hat_, feats)``: per-stage
            positions, per-stage predicted upsample counts, per-stage mean
            distances, and the features after the last stage.
        """
        start_xyzs = xyzs.clone()
        per_stage_xyzs = []
        per_stage_unums = []

        stages = zip(self.upsample_layers, self.upsample_num_layers, self.refine_layers)

        for stage, (upsample, predict_num, refine) in enumerate(stages):
            # Every one of the current "n" points spawns a fixed number "u"
            # of candidates:
            #   cand_xyzs:  (b, 3, n u)
            #   cand_feats: (b, c, n u)
            cand_xyzs, cand_feats = upsample(xyzs, feats)

            # Predicted local density around each point group: (b, n).
            num_hat = predict_num(feats)

            # Keep num_hat[..., j] candidates from the j-th group and gather
            # them into one point set of size m_j per batch element. When
            # the batch has more than one element the sizes can disagree;
            # the point sets are then resampled (dropping or duplicating
            # points) to a common size m. Dropping distinct points is
            # acceptable since lowering local density in some regions is
            # part of the goal.
            #
            # Decoder stages undo the encoder stages in reverse order, hence
            # the mirrored index into downsample_rate.
            mirrored = self.num_layers - stage - 1
            xyzs, feats = select_xyzs_and_feats(
                cand_xyzs,
                cand_feats,
                num_hat,
                upsample_rate=1 / self.downsample_rate[mirrored],
            )

            # Refine the selected points: xyzs (b, 3, m), feats (b, c, m).
            xyzs, feats = refine(xyzs, feats)

            per_stage_xyzs.append(xyzs)
            per_stage_unums.append(num_hat)

        # Mean distance between each stage's centroids and the points
        # upsampled from them.
        per_stage_mdis = self.get_pred_mdis(
            [start_xyzs, *per_stage_xyzs], per_stage_unums
        )

        return per_stage_xyzs, per_stage_unums, per_stage_mdis, feats

    def get_pred_mdis(self, xyzs_hat_, unums_hat_):
        """Return one mean-distance tensor per consecutive pair in ``xyzs_hat_``.

        For every pair the summed distance reported by
        ``nearby_distance_sum`` is divided by the matching predicted
        upsample count.
        """
        results = []
        # zip stops at the shortest input, so pairing the full list with its
        # one-step shift visits exactly the consecutive (prev, curr) pairs.
        for centroids, upsampled, unums in zip(xyzs_hat_, xyzs_hat_[1:], unums_hat_):
            dist_sum, _, _, _ = nearby_distance_sum(centroids, upsampled, self.k)
            results.append(dist_sum / unums)
        return results