# Copyright (c) 2021-2023, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import math
import string
from collections import defaultdict
from typing import Any, TypeVar
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# Generic key (K) and value (V) type variables used by the dict/list
# conversion helpers ``dl_to_ld`` and ``ld_to_dl``.
K, V = TypeVar("K"), TypeVar("V")
def np_img_to_tensor(x: np.ndarray) -> torch.Tensor:
    """Converts a channels-last image array into a float tensor in ``[0, 1]``.

    The last axis of ``x`` is moved to position ``-3`` (``(..., H, W, C)`` ->
    ``(..., C, H, W)``), the data is cast to ``float32`` and divided by 255.

    Args:
        x: Image array with channels on the last axis. Values are assumed to
            lie in ``[0, 255]`` (e.g. a ``uint8`` image). Arrays with
            negative strides (such as ``img[..., ::-1]``) are accepted.

    Returns:
        ``float32`` tensor with the channel axis at position ``-3``.
    """
    # torch.from_numpy rejects arrays with negative strides (e.g. produced by a
    # BGR->RGB flip ``img[..., ::-1]``). ascontiguousarray copies only when the
    # array is not already C-contiguous, so the common case stays copy-free.
    x = np.ascontiguousarray(x)
    return torch.from_numpy(x).moveaxis(-1, -3).to(torch.float32) / 255
def tensor_to_np_img(x: torch.Tensor) -> np.ndarray:
    """Converts a channels-first float image tensor into a ``uint8`` array.

    Values are scaled by 255, rounded to the nearest integer (ties to even,
    as ``torch.round`` does), clipped to ``[0, 255]`` and cast to ``uint8``.
    Axis ``-3`` is then moved to the last position and the result is copied
    to a CPU NumPy array.

    Args:
        x: Image tensor with channels at axis ``-3``, values nominally in
            ``[0, 1]``. Out-of-range values are clipped.

    Returns:
        ``uint8`` array with the channel axis last.
    """
    # Round before casting: ``.to(torch.uint8)`` truncates toward zero, which
    # biases every pixel downward and turns float error such as 254.9999
    # (from a ``/ 255 * 255`` round trip) into 254 instead of 255.
    scaled = (x * 255).round().clip(0, 255)
    return scaled.to(torch.uint8).moveaxis(-3, -1).cpu().numpy()
def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1):
    """Returns tuples for padding and unpadding.

    NOTE: This is also available in ``compressai.ops`` as of v1.2.4.

    Args:
        in_h: Input height.
        in_w: Input width.
        out_h: Output height. Defaults to ``in_h`` rounded up to the next
            multiple of ``min_div``.
        out_w: Output width. Defaults to ``in_w`` rounded up to the next
            multiple of ``min_div``.
        min_div: Length that output dimensions should be divisible by.
            Must be positive.

    Returns:
        ``(pad, unpad)``, each a ``(left, right, top, bottom)`` tuple (the
        layout ``torch.nn.functional.pad`` uses for the last two dimensions).
        ``unpad`` is the element-wise negation of ``pad``.

    Raises:
        ValueError: If ``min_div`` is not positive, or if the output height
            or width is not divisible by ``min_div``.
    """
    # A non-positive divisor would raise ZeroDivisionError (min_div == 0) or
    # silently produce meaningless sizes (min_div < 0) in the rounding below.
    if min_div < 1:
        raise ValueError(f"min_div must be positive, got min_div={min_div}.")

    if out_h is None:
        out_h = (in_h + min_div - 1) // min_div * min_div
    if out_w is None:
        out_w = (in_w + min_div - 1) // min_div * min_div

    if out_h % min_div != 0 or out_w % min_div != 0:
        raise ValueError(
            f"Padded output height and width are not divisible by min_div={min_div}."
        )

    # left/top take the floor of half the size difference; right/bottom take
    # the remainder, so an odd difference puts the extra pixel on that side.
    left = (out_w - in_w) // 2
    right = out_w - in_w - left
    top = (out_h - in_h) // 2
    bottom = out_h - in_h - top

    pad = (left, right, top, bottom)
    unpad = (-left, -right, -top, -bottom)

    return pad, unpad
def num_parameters(net: nn.Module, predicate=lambda x: x.requires_grad) -> int:
    """Counts the scalar parameters of ``net`` selected by ``predicate``.

    Parameters are de-duplicated by ``data_ptr()``, so distinct ``Parameter``
    objects backed by the same memory (e.g. tied weights) are counted once.
    When several share a pointer, the last one yielded by ``net.parameters()``
    is the one whose ``numel()`` is counted.

    Args:
        net: Module whose ``parameters()`` are inspected.
        predicate: Filter applied to each parameter. By default only
            parameters with ``requires_grad`` set are counted.

    Returns:
        Total element count over the selected, de-duplicated parameters.
    """
    by_address = {}
    for param in net.parameters():
        if not predicate(param):
            continue
        # Later parameters with the same address overwrite earlier ones.
        by_address[param.data_ptr()] = param

    total = 0
    for param in by_address.values():
        total += param.numel()
    return total
def _is_nan(x):
return x is None or math.isnan(x)
def _coerce_list(x):
if isinstance(x, list):
return x
return [x]
def flatten_values(x, value_type=object):
    """Yields the leaf values of an arbitrarily nested container.

    Lists, tuples and sets are traversed element by element and dicts by
    their values. Anything else is a leaf.

    Args:
        x: A value, or a nested container of values.
        value_type: Type (or tuple of types, as accepted by ``isinstance``)
            that every leaf, at any nesting depth, must be an instance of.
            Defaults to ``object``, which accepts any leaf.

    Yields:
        Each leaf value, in iteration order.

    Raises:
        ValueError: If a leaf is not an instance of ``value_type``. As this
            is a generator, the error surfaces lazily during iteration.
    """
    if isinstance(x, (list, tuple, set)):
        children = x
    elif isinstance(x, dict):
        children = x.values()
    elif isinstance(x, value_type):
        yield x
        return
    else:
        raise ValueError(f"Unexpected type {type(x)}")

    for child in children:
        # Forward value_type so nested leaves are validated too, not only a
        # top-level scalar.
        yield from flatten_values(child, value_type)
def dl_to_ld(dl: dict[K, list[V]]) -> list[dict[K, V]]:
    """Converts a dict of lists into a list of dicts.

    The ``i``-th output dict maps each key ``k`` to ``dl[k][i]`` for every
    key whose list has more than ``i`` elements, so the lists may differ in
    length. The result is as long as the longest list (empty if ``dl`` is
    empty). Keys appear in each output dict in ``dl``'s key order.

    Example:
        >>> dl_to_ld({"a": [1, 2], "b": [3]})
        [{'a': 1, 'b': 3}, {'a': 2}]
    """
    longest = max((len(values) for values in dl.values()), default=0)
    rows = [{} for _ in range(longest)]
    for key, values in dl.items():
        # rows is at least as long as values, so zip never drops a value.
        for row, value in zip(rows, values):
            row[key] = value
    return rows
def ld_to_dl(ld: list[dict[K, V]]) -> dict[K, list[V]]:
    """Converts a list of dicts into a dict of lists.

    Values are appended per key in the order the dicts appear in ``ld``. A
    key absent from some dicts simply gets a shorter list. The result is a
    ``collections.defaultdict(list)``, so indexing an absent key returns (and
    inserts) an empty list instead of raising ``KeyError``.
    """
    grouped = defaultdict(list)
    pairs = ((key, value) for row in ld for key, value in row.items())
    for key, value in pairs:
        grouped[key].append(value)
    return grouped