Source code for compressai_vision.evaluators.base_evaluator
# Copyright (c) 2022-2024, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import logging
from pathlib import Path
import torch.nn as nn
[docs]class BaseEvaluator(nn.Module):
def __init__(
self,
datacatalog_name,
dataset_name,
dataset,
output_dir="./vision_output/",
criteria=None,
):
self._logger = logging.getLogger(self.__class__.__name__)
self.datacatalog_name = datacatalog_name
self.dataset_name = dataset_name
self.output_dir = output_dir
self.criteria = criteria
self.output_file_name = (
f"{self.__class__.__name__}_on_{datacatalog_name}_{dataset_name}"
)
path = Path(self.output_dir)
if (not path.is_dir()) and not path.exists():
self._logger.info(f"creating output folder: {path}")
path.mkdir(parents=True, exist_ok=True)
[docs] def set_annotation_info(self, dataset):
self.annotation_path = dataset.annotation_path
self.seqinfo_path = dataset.seqinfo_path
self.thing_classes = dataset.thing_classes
self.thing_id_mapping = dataset.thing_dataset_id_to_contiguous_id
[docs] @staticmethod
def get_jde_eval_info_name(name):
return f"{name}_info_to_eval.h5"
[docs] @staticmethod
def get_coco_eval_info_name(name):
return "coco_instances_results.json"
[docs] def reset(self):
raise NotImplementedError
[docs] def digest(self, gt, pred):
raise NotImplementedError
[docs] def results(self, save_path: str = None):
raise NotImplementedError
[docs] def write_results(self, out, path: str = None):
if path is None:
path = f"{self.output_dir}"
path = Path(path)
if not path.is_dir():
self._logger.info(f"creating output folder: {path}")
path.mkdir(parents=True, exist_ok=True)
with open(f"{path}/{self.output_file_name}.json", "w", encoding="utf-8") as f:
json.dump(out, f, ensure_ascii=False, indent=4)