# Source code for compressai.datasets.pointcloud.modelnet

# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import os.path
import re
import shutil

from pathlib import Path

import numpy as np

try:
    from pyntcloud import PyntCloud
except ImportError:
    pass  # NOTE: Optional dependency.

from compressai.datasets.cache import CacheDataset
from compressai.datasets.utils import download_url, hash_file
from compressai.registry import register_dataset


@register_dataset("ModelNetDataset")
class ModelNetDataset(CacheDataset):
    """ModelNet dataset.

    This dataset of 3D CAD models of objects was introduced by
    [Wu2015]_, consisting of 10 or 40 classes, with 4899 and 12311
    aligned items, respectively.
    Each 3D model is represented in the OFF file format by a triangle
    mesh (i.e. faces) and has a single label (e.g. airplane).
    To convert the triangle meshes to point clouds, one may use a mesh
    sampling method (e.g. ``SamplePoints``).

    See also: [PapersWithCode_ModelNet]_.

    Args:
        root: Directory holding the extracted raw dataset (the ``.off``
            files). May be ``None`` if ``cache_root`` is given.
        cache_root: Base directory for the cache. Defaults to
            ``f"{root}_cache"``. The directory actually handed to
            ``CacheDataset`` is ``cache_root / split_name``.
        split: Name of the split directory searched for ``.off`` files
            (``**/{split}/*.off`` below ``root``).
        split_name: Name of the cache sub-directory. Defaults to
            ``split``.
        name: Dataset variant; a key of ``URLS`` / ``HASHES`` /
            ``LABEL_LIST`` (``"10"`` or ``"40"``).
        pre_transform: Forwarded unchanged to ``CacheDataset``.
        transform: Forwarded unchanged to ``CacheDataset``.
        download: If true and ``root`` is given, call :meth:`download`,
            which fetches the dataset unless ``root`` already exists.

    Raises:
        ValueError: If both ``root`` and ``cache_root`` are ``None``.

    References:

        .. [Wu2015] `"3D ShapeNets: A deep representation for volumetric
            shapes," <https://arxiv.org/abs/1406.5670>`_, by Zhirong Wu,
            Shuran Song, Aditya Khosla, Fisher Yu, Linguang Zhang,
            Xiaoou Tang, and Jianxiong Xiao, CVPR 2015.

        .. [PapersWithCode_ModelNet] `PapersWithCode: ModelNet
            <https://paperswithcode.com/dataset/modelnet>`_
    """

    # fmt: off
    LABEL_LIST = {
        "10": [
            "bathtub", "bed", "chair", "desk", "dresser", "monitor",
            "night_stand", "sofa", "table", "toilet",
        ],
        "40": [
            "airplane", "bathtub", "bed", "bench", "bookshelf", "bottle",
            "bowl", "car", "chair", "cone", "cup", "curtain", "desk", "door",
            "dresser", "flower_pot", "glass_box", "guitar", "keyboard", "lamp",
            "laptop", "mantel", "monitor", "night_stand", "person", "piano",
            "plant", "radio", "range_hood", "sink", "sofa", "stairs", "stool",
            "table", "tent", "toilet", "tv_stand", "vase", "wardrobe", "xbox",
        ],
    }
    # fmt: on

    # Per-variant mapping from label string to its position in LABEL_LIST.
    LABEL_STR_TO_LABEL_INDEX = {
        "10": {label: idx for idx, label in enumerate(LABEL_LIST["10"])},
        "40": {label: idx for idx, label in enumerate(LABEL_LIST["40"])},
    }

    # Download location of the zip archive for each variant.
    URLS = {
        "10": "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
        "40": "http://modelnet.cs.princeton.edu/ModelNet40.zip",
    }

    # Expected SHA-256 digest of each archive in URLS.
    HASHES = {
        "10": "9d8679435fc07d1d26f13009878db164a7aa8ea5e7ea3c8880e42794b7307d51",
        "40": "42dc3e656932e387f554e25a4eb2cc0e1a1bd3ab54606e2a9eae444c60e536ac",
    }

    # Matches ".../<label>/<split>/<label>_<index>.off" (forward slashes).
    _PATH_PATTERN = re.compile(
        r"^.*?/?"
        r"(?P<label_str>[a-zA-Z_]+)/"
        r"(?P<split>[a-zA-Z_]+)/"
        r"(?P<label_str_again>[a-zA-Z_]+)_(?P<file_index>\d+)\.off$"
    )

    def __init__(
        self,
        root=None,
        cache_root=None,
        split="train",
        split_name=None,
        name="40",
        pre_transform=None,
        transform=None,
        download=True,
    ):
        if cache_root is None:
            # The default cache location is derived from root, so at least
            # one of the two must be provided. Raised explicitly (instead
            # of asserted) so the check survives ``python -O``.
            if root is None:
                raise ValueError(
                    "At least one of 'root' or 'cache_root' must be specified."
                )
            cache_root = f"{str(root).rstrip('/')}_cache"

        self.root = Path(root) if root else None
        self.cache_root = Path(cache_root)
        self.split = split
        self.split_name = split if split_name is None else split_name
        self.name = name

        if download and self.root:
            self.download()

        super().__init__(
            cache_root=self.cache_root / self.split_name,
            pre_transform=pre_transform,
            transform=transform,
        )

        # Defined by CacheDataset; presumably builds the cache from
        # _get_items/_load_item if it is missing -- confirm in the base class.
        self._ensure_cache()

    def download(self, force=False):
        """Download and extract the dataset archive into ``self.root``.

        Does nothing when ``self.root`` already exists, unless ``force``
        is true. The archive is downloaded to and unpacked in
        ``self.root.parent / "tmp"``, then the extracted
        ``ModelNet{name}`` directory is moved to ``self.root``.

        Args:
            force: Download even if ``self.root`` exists; also passed to
                ``download_url`` as ``overwrite``.

        Raises:
            RuntimeError: If the SHA-256 digest of the downloaded archive
                does not match the expected value in ``HASHES``.
        """
        if not force and self.root.exists():
            return

        tmpdir = self.root.parent / "tmp"
        os.makedirs(tmpdir, exist_ok=True)
        filepath = download_url(self.URLS[self.name], tmpdir, overwrite=force)

        # Integrity check must not be an ``assert``: asserts (and the hash
        # computation inside them) are removed entirely under ``python -O``.
        expected_hash = self.HASHES[self.name]
        actual_hash = hash_file(filepath, method="sha256")
        if actual_hash != expected_hash:
            raise RuntimeError(
                f"Hash mismatch for {filepath}: "
                f"expected {expected_hash}, got {actual_hash}"
            )

        shutil.unpack_archive(filepath, tmpdir)

        # NOTE(review): if self.root already exists (force=True),
        # shutil.move places the directory *inside* it rather than
        # replacing it -- confirm this is intended.
        # str() keeps this working on Python versions where shutil.move
        # does not accept path-like sources.
        shutil.move(str(tmpdir / f"ModelNet{self.name}"), str(self.root))

    def _get_items(self):
        """Return the sorted paths of all ``.off`` files of the split."""
        return sorted(self.root.glob(f"**/{self.split}/*.off"))

    def _load_item(self, path):
        """Load one ``.off`` mesh file into a dict of arrays.

        Args:
            path: Path of the ``.off`` file; must follow the layout
                expected by :meth:`_parse_path`.

        Returns:
            dict with keys ``"file_index"`` (int32 array of length 1),
            ``"label"`` (uint8 array of length 1), ``"pos"`` (values of
            the cloud's ``points`` table) and ``"face"`` (transposed
            values of the cloud's ``mesh`` table).

        Raises:
            ImportError: If the optional ``pyntcloud`` dependency is not
                installed.
        """
        label_index, file_index = self._parse_path(path)

        # pyntcloud is imported optionally at module level; turn the
        # otherwise confusing NameError into an actionable ImportError.
        try:
            cloud_cls = PyntCloud
        except NameError as err:
            raise ImportError(
                "ModelNetDataset requires the optional dependency 'pyntcloud' "
                "to load .off files."
            ) from err

        cloud = cloud_cls.from_file(str(path))

        return {
            "file_index": np.array([file_index], dtype=np.int32),
            "label": np.array([label_index], dtype=np.uint8),
            # presumably one row per vertex -- confirm with pyntcloud docs
            "pos": cloud.points.values,
            # transposed so that faces run along the last axis
            "face": cloud.mesh.values.T,
        }

    def _parse_path(self, path):
        """Extract the label index and file index from an item path.

        The path must end in ``<label>/<split>/<label>_<index>.off``.

        Args:
            path: Path (``str`` or ``Path``) of an ``.off`` file.

        Returns:
            Tuple ``(label_index, file_index)`` where ``label_index`` is
            the position of ``<label>`` in ``LABEL_LIST[self.name]`` and
            ``file_index`` is ``<index>`` as an ``int``.

        Raises:
            ValueError: If the path does not match the expected layout,
                its split differs from ``self.split``, or the directory
                label and file-name label disagree.
            KeyError: If the label is not known for ``self.name``.
        """
        # The pattern uses "/" separators; normalise the platform
        # separator so that e.g. Windows paths are parsed too. This is a
        # no-op where os.sep is already "/".
        normalized = str(path).replace(os.sep, "/")
        match = self._PATH_PATTERN.match(normalized)

        if match is None:
            raise ValueError(f"Could not parse path: {path}")

        if match.group("split") != self.split:
            raise ValueError(
                f"Path {path} belongs to split {match.group('split')!r}, "
                f"expected {self.split!r}"
            )

        label_str = match.group("label_str")
        if label_str != match.group("label_str_again"):
            raise ValueError(
                f"Directory label and file label disagree in path: {path}"
            )

        label_index = self.LABEL_STR_TO_LABEL_INDEX[self.name][label_str]
        file_index = int(match.group("file_index"))
        return label_index, file_index