# Source code for compressai_vision.model_wrappers.detectron2

# Copyright (c) 2022-2024, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from enum import Enum
from pathlib import Path
from typing import Dict, List

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.structures import ImageList

from compressai_vision.registry import register_vision_model

from .base_wrapper import BaseWrapper

__all__ = [
    "faster_rcnn_X_101_32x8d_FPN_3x",
    "mask_rcnn_X_101_32x8d_FPN_3x",
    "faster_rcnn_R_50_FPN_3x",
    "mask_rcnn_R_50_FPN_3x",
]

# Directory containing this module. Model files configured with the
# "default" path prefix are resolved relative to the directory two levels up.
thisdir = Path(__file__).parent
root_path = thisdir / ".." / ".."


class Split_Points(Enum):
    """Split points supported by the R-CNN wrapper; each value is its string id."""

    FeaturePyramidNetwork = "fpn"
    C2 = "c2"
    Res2 = "r2"

    def __str__(self):
        # Render as the bare id (e.g. "fpn") instead of "Split_Points.<name>".
        return f"{self.value}"


class Rcnn_R_50_X_101_FPN(BaseWrapper):
    """Detectron2 R-CNN wrapper (R-50 / X-101 FPN configs) for split inference.

    The network is cut at the split point selected by ``kwargs["splits"]``
    (one of :class:`Split_Points`). ``input_to_features`` runs the part before
    the split (NN-part1) and returns the intermediate tensors;
    ``features_to_output`` runs the remaining part (NN-part2) from those
    tensors to detectron2's post-processed predictions.

    Keyword Args:
        model_path_prefix: directory holding ``cfg`` and ``weights``. The
            value ``"default"`` (also used when the key is omitted) means
            ``root_path``.
        cfg: detectron2 config file, relative to the prefix.
        weights: checkpoint file, relative to the prefix.
        splits: split point id (``"fpn"``, ``"c2"`` or ``"r2"``), compared
            case-insensitively.
    """

    class _ImageSizes:
        """Minimal stand-in for a detectron2 ``ImageList``.

        NN-part2 starts from feature tensors, so no image tensor exists; only
        ``image_sizes`` is supplied to the proposal generator and ROI heads.
        """

        def __init__(self, img_size: list):
            self.image_sizes = img_size

    def __init__(self, device: str, **kwargs):
        super().__init__(device)

        self.supported_split_points = Split_Points

        # Validate the split point before building the model and loading the
        # checkpoint, so a bad configuration fails fast.
        assert "splits" in kwargs, "Split layer ids must be provided"
        self.split_id = str(kwargs["splits"]).lower()
        split_layers = {
            self.SPLIT_FPN: ["p2", "p3", "p4", "p5"],
            self.SPLIT_C2: ["c2", "c3", "c4", "c5"],
            self.SPLIT_R2: ["r2"],
        }
        if self.split_id not in split_layers:
            raise NotImplementedError(f"Not supported split point {self.split_id}")
        self.split_layer_list = split_layers[self.split_id]
        # Tag names used to re-label incoming features before NN-part2.
        self.features_at_splits = dict.fromkeys(self.split_layer_list)

        self._cfg = get_cfg()
        self._cfg.MODEL.DEVICE = device
        path_prefix = kwargs.get("model_path_prefix", "default")
        if path_prefix == "default":
            path_prefix = f"{root_path}"
        self._cfg.merge_from_file(f"{path_prefix}/{kwargs['cfg']}")

        self.model = build_model(self._cfg).to(device).eval()

        self.backbone = self.model.backbone
        self.top_block = self.model.backbone.top_block
        self.proposal_generator = self.model.proposal_generator
        self.roi_heads = self.model.roi_heads
        self.postprocess = self.model._postprocess
        DetectionCheckpointer(self.model).load(f"{path_prefix}/{kwargs['weights']}")

        self.model_info = {"cfg": kwargs["cfg"], "weights": kwargs["weights"]}

        assert self.top_block is not None
        assert self.proposal_generator is not None

    @property
    def SPLIT_FPN(self):
        return str(self.supported_split_points.FeaturePyramidNetwork)

    @property
    def SPLIT_C2(self):
        return str(self.supported_split_points.C2)

    @property
    def SPLIT_R2(self):
        return str(self.supported_split_points.Res2)

    @property
    def size_divisibility(self):
        return self.backbone.size_divisibility

    def input_resize(self, images: List):
        """Pad/batch ``images`` into an ``ImageList`` honoring the backbone divisibility."""
        return ImageList.from_tensors(images, self.size_divisibility)

    def input_to_features(self, x, device: str) -> Dict:
        """Compute the deep features at the split point, all the way from the input.

        Args:
            x: input batch accepted by the detectron2 model's
                ``preprocess_image`` (presumably a list of dataset dicts).
            device: device the model is moved to before running.

        Returns:
            ``{"data": {tag: tensor}, "input_size": image_sizes}``.

        Raises:
            NotImplementedError: if ``self.split_id`` is not a supported split.
        """
        self.model = self.model.to(device).eval()

        front_parts = {
            self.SPLIT_FPN: self._input_to_feature_pyramid,
            self.SPLIT_C2: self._input_to_c2,
            self.SPLIT_R2: self._input_to_r2,
        }
        front = front_parts.get(self.split_id)
        if front is None:
            self.logger.error(f"Not supported split point {self.split_id}")
            raise NotImplementedError(f"Not supported split point {self.split_id}")
        return front(x)

    def features_to_output(self, x: Dict, device: str):
        """Complete the downstream task from the intermediate deep features.

        Args:
            x: dict with keys ``"data"`` (features keyed by tag),
                ``"org_input_size"`` and ``"input_size"``.
            device: device the model is moved to before running.

        Raises:
            NotImplementedError: if ``self.split_id`` is not a supported split.
        """
        self.model = self.model.to(device).eval()

        back_parts = {
            self.SPLIT_FPN: self._feature_pyramid_to_output,
            self.SPLIT_C2: self._feature_c2_to_output,
            self.SPLIT_R2: self._feature_r2_to_output,
        }
        back = back_parts.get(self.split_id)
        if back is None:
            self.logger.error(f"Not supported split points {self.split_id}")
            raise NotImplementedError(f"Not supported split point {self.split_id}")
        return back(x["data"], x["org_input_size"], x["input_size"])

    @torch.no_grad()
    def _input_to_feature_pyramid(self, x):
        """Compute and return the feature pyramid ['p2', 'p3', 'p4', 'p5'] from the input."""
        imgs = self.model.preprocess_image(x)
        feature_pyramid = self.backbone(imgs.tensor)
        # p6 is recomputed from p5 in NN-part2, so it is not transmitted.
        del feature_pyramid["p6"]

        return {"data": feature_pyramid, "input_size": imgs.image_sizes}

    @torch.no_grad()
    def _input_to_c2(self, x):
        """Compute the FPN lateral-conv outputs tagged ['c2', 'c3', 'c4', 'c5'] from the input."""
        imgs = self.model.preprocess_image(x)

        ref_features = self.backbone.in_features
        bottom_up_features = self.backbone.bottom_up(imgs.tensor)

        # lateral_convs[idx] is applied to in_features[-idx - 1]; inserting at
        # the front therefore leaves `results` in in_features order.
        results = []
        for idx, lateral_conv in enumerate(self.backbone.lateral_convs):
            features = bottom_up_features[ref_features[-idx - 1]]
            results.insert(0, lateral_conv(features))

        assert len(self.split_layer_list) == len(results)
        out = dict(zip(self.split_layer_list, results))

        return {"data": out, "input_size": imgs.image_sizes}

    @torch.no_grad()
    def _input_to_r2(self, x):
        """Compute and return the feature tensor at R2 (stem + res2) from the input."""
        imgs = self.model.preprocess_image(x)

        stem_out = self.backbone.bottom_up.stem(imgs.tensor)
        r2_out = self.backbone.bottom_up.res2(stem_out)

        return {"data": {"r2": r2_out}, "input_size": imgs.image_sizes}

    @torch.no_grad()
    def get_input_size(self, x):
        """Compute the input image size(s) seen by the network after preprocessing."""
        imgs = self.model.preprocess_image(x)
        return imgs.image_sizes

    @torch.no_grad()
    def _heads_to_output(self, features: Dict, org_img_size: Dict, input_img_size: List):
        """Run the proposal generator and ROI heads on ``features`` and post-process.

        ``org_img_size`` is wrapped as a one-element batch for ``_postprocess``,
        which rescales predictions from ``input_img_size`` back to it.
        """
        images = self._ImageSizes(input_img_size)

        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(images, features, proposals, None)

        assert (
            not torch.jit.is_scripting()
        ), "Scripting is not supported for postprocess."
        return self.model._postprocess(results, [org_img_size], input_img_size)

    @torch.no_grad()
    def _feature_pyramid_to_output(
        self, x: Dict, org_img_size: Dict, input_img_size: List
    ):
        """
        Perform the downstream task using the feature pyramid ['p2', 'p3', 'p4', 'p5'].

        Detectron2 source code is referenced for this function, specifically the class "GeneralizedRCNN".
        Parts unnecessary for split inference are removed or modified accordingly.

        Please find the license statement in the downloaded original Detectron2 source code or here:
        https://github.com/facebookresearch/detectron2/blob/main/LICENSE

        """
        # Replacing tag names for interfacing with NN-part2
        x = dict(zip(self.features_at_splits.keys(), x.values()))
        # p6 was dropped by NN-part1; rebuild it from p5 with the top block.
        x.update({"p6": self.top_block(x["p5"])[0]})

        return self._heads_to_output(x, org_img_size, input_img_size)

    @torch.no_grad()
    def _feature_c2_to_output(self, x: Dict, org_img_size: Dict, input_img_size: List):
        """
        Perform the downstream task using the c2 features ['c2', 'c3', 'c4', 'c5'].

        Detectron2 source code is referenced for this function, specifically the class "GeneralizedRCNN".
        Parts unnecessary for split inference are removed or modified accordingly.

        Please find the license statement in the downloaded original Detectron2 source code or here:
        https://github.com/facebookresearch/detectron2/blob/main/LICENSE

        """
        # Replacing tag names for interfacing with NN-part2
        x = dict(zip(self.features_at_splits.keys(), x.values()))
        x = self.backbone.forward_after_c2(x)

        return self._heads_to_output(x, org_img_size, input_img_size)

    @torch.no_grad()
    def _feature_r2_to_output(self, x: Dict, org_img_size: Dict, input_img_size: List):
        """Perform the downstream task starting from the R2 feature tensor."""
        assert "r2" in x

        r2_out = x["r2"]
        r3_out = self.backbone.bottom_up.res3(r2_out)
        r4_out = self.backbone.bottom_up.res4(r3_out)
        r5_out = self.backbone.bottom_up.res5(r4_out)

        bottom_up_features = {
            "res2": r2_out,
            "res3": r3_out,
            "res4": r4_out,
            "res5": r5_out,
        }

        fptensors = self.backbone(bottom_up_features, no_bottom_up=True)

        return self._heads_to_output(fptensors, org_img_size, input_img_size)

    @torch.no_grad()
    def deeper_features_for_accuracy_proxy(self, x: Dict):
        """
        Compute an accuracy proxy at a layer deeper than NN-Part1.

        Not implemented for this wrapper.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "deeper_features_for_accuracy_proxy is not implemented for this model"
        )

    @torch.no_grad()
    def forward(self, x):
        """Complete the downstream task end-to-end, all the way from the input."""
        return self.model([x])

    @property
    def cfg(self):
        return self._cfg


@register_vision_model("faster_rcnn_X_101_32x8d_FPN_3x")
class faster_rcnn_X_101_32x8d_FPN_3x(Rcnn_R_50_X_101_FPN):
    """Registered as "faster_rcnn_X_101_32x8d_FPN_3x".

    The architecture itself comes from the ``cfg`` passed in ``kwargs``.
    """

    def __init__(self, device: str, **kwargs):
        super().__init__(device, **kwargs)
@register_vision_model("mask_rcnn_X_101_32x8d_FPN_3x")
class mask_rcnn_X_101_32x8d_FPN_3x(Rcnn_R_50_X_101_FPN):
    """Registered as "mask_rcnn_X_101_32x8d_FPN_3x".

    The architecture itself comes from the ``cfg`` passed in ``kwargs``.
    """

    def __init__(self, device: str, **kwargs):
        super().__init__(device, **kwargs)
@register_vision_model("faster_rcnn_R_50_FPN_3x")
class faster_rcnn_R_50_FPN_3x(Rcnn_R_50_X_101_FPN):
    """Registered as "faster_rcnn_R_50_FPN_3x".

    The architecture itself comes from the ``cfg`` passed in ``kwargs``.
    """

    def __init__(self, device: str, **kwargs):
        super().__init__(device, **kwargs)
@register_vision_model("mask_rcnn_R_50_FPN_3x")
class mask_rcnn_R_50_FPN_3x(Rcnn_R_50_X_101_FPN):
    """Registered as "mask_rcnn_R_50_FPN_3x".

    The architecture itself comes from the ``cfg`` passed in ``kwargs``.
    """

    def __init__(self, device: str, **kwargs):
        super().__init__(device, **kwargs)