# Source code for compressai_trainer.run.plot_rd

# Copyright (c) 2021-2023, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
RD curves plotter.

See :ref:`tools-plot_rd` for more information.
"""

from __future__ import annotations

import argparse
import sys
import textwrap
from typing import Iterable

import aim
import pandas as pd

from compressai_trainer.plot import plot_rd
from compressai_trainer.utils.aim.query import get_runs_dataframe, run_hashes_by_query
from compressai_trainer.utils.compressai.results import compressai_results_dataframe
from compressai_trainer.utils.optimal import optimal_dataframe
from compressai_trainer.utils.utils import format_dataframe

# Title of the generated RD figure (passed to plot_rd as the layout title).
TITLE = "Performance evaluation on Kodak - PSNR (RGB)"

# Reference result files loaded via compressai_results_dataframe and
# concatenated ahead of the queried runs in create_dataframe.
COMPRESSAI_CODECS = [
    "image/kodak/compressai-bmshj2018-factorized_mse_cuda.json",
    "image/kodak/compressai-bmshj2018-hyperprior_mse_cuda.json",
    "image/kodak/compressai-mbt2018-mean_mse_cuda.json",
    "image/kodak/compressai-mbt2018_mse_cuda.json",
    "image/kodak/compressai-cheng2020-anchor_mse_cuda.json",
    "image/kodak/vtm.json",
]

# Run hyperparameters requested from the tracker for every queried run.
HOVER_HPARAMS = ["criterion.lmbda"]

# Run metrics requested from the tracker in addition to the plotted x/y.
HOVER_METRICS = ["loss"]

# Columns shown on hover; they are also moved to the front of the combined
# dataframe by _reorder_dataframe_columns.
HOVER_DATA = [
    "run_hash",
    "name",
    "model.name",
    "experiment",
    "epoch",
    *HOVER_HPARAMS,
    *HOVER_METRICS,
]


def _reference_dataframes():
    """Load one results dataframe per entry of ``COMPRESSAI_CODECS``, in order."""
    return list(map(compressai_results_dataframe, COMPRESSAI_CODECS))


def create_dataframe(repo, args):
    """Build the combined RD dataframe for plotting.

    Reference codec results come first, followed by one dataframe per
    (query, curves, optimal) group taken from ``args``.

    Args:
        repo: Aim repository that the queries are run against.
        args: Parsed CLI namespace providing ``x``, ``y``, ``query``,
            ``curves`` and ``optimal``.

    Returns:
        A single concatenated dataframe whose leading columns are ``HOVER_DATA``.
    """
    # Reference results are loaded before any repository query, as before.
    frames = _reference_dataframes()
    for query, curves, optimal in zip(args.query, args.curves, args.optimal):
        frames.append(
            _create_dataframe(repo, args.x, args.y, query, curves, optimal)
        )
    return _reorder_dataframe_columns(pd.concat(frames))
def _create_dataframe(repo, x, y, query, curves, optimal):
    """Fetch, format and filter the runs matched by a single query.

    Args:
        repo: Aim repository to query.
        x: Name of the x-axis metric.
        y: Name of the y-axis metric.
        query: Aim query string used to select runs.
        curves: List of curve specs; each has ``"x"``/``"y"`` entries that are
            either a metric name or an iterable of metric names.
        optimal: Method name forwarded to ``optimal_dataframe``.

    Returns:
        Dataframe sorted by ``name``, ``x``, ``y`` with a fresh 0..n-1 index.
    """
    run_hashes = run_hashes_by_query(repo, query)

    # Every metric any curve refers to must be fetched, plus the axes and the
    # hover metrics.
    wanted = {x, y, *HOVER_METRICS}
    wanted.update(_needed_metrics(curves, "x"))
    wanted.update(_needed_metrics(curves, "y"))

    frame = get_runs_dataframe(
        run_hashes=run_hashes,
        repo=repo,
        metrics=sorted(wanted),
        hparams=HOVER_HPARAMS,
        epoch="best",
    )
    frame = format_dataframe(frame, x, y, curves, skip_nan=True)
    frame = optimal_dataframe(frame, x=x, y=y, method=optimal, groupby="name")
    return frame.sort_values(["name", x, y]).reset_index(drop=True)


def _needed_metrics(xs, key) -> Iterable[str]:
    """Yield the metric name(s) stored under ``key`` in each curve spec.

    A string value yields itself; any other value is iterated and each element
    is yielded.
    """
    for spec in xs:
        value = spec[key]
        if isinstance(value, str):
            yield value
        else:
            yield from value


def _reorder_dataframe_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` with the ``HOVER_DATA`` columns first, others after in original order."""
    leading = set(HOVER_DATA)
    trailing = [col for col in df.columns if col not in leading]
    return df[[*HOVER_DATA, *trailing]]
def plot_dataframe(df: pd.DataFrame, args):
    """Print ``df``, optionally export it, and render the RD figure.

    Args:
        df: Combined dataframe, e.g. from ``create_dataframe``.
        args: Parsed CLI namespace. Uses ``x``, ``y``, ``out_csv`` (write CSV
            when truthy), ``out_html`` (write HTML when truthy) and ``show``
            (open the figure when true).
    """
    print(df)

    if args.out_csv:
        df.to_csv(args.out_csv, index=False)

    fig = plot_rd(
        df,
        scatter_kwargs={"x": args.x, "y": args.y, "hover_data": HOVER_DATA},
        layout_kwargs={"title": TITLE},
    )

    if args.out_html:
        # Imported lazily so plotly.offline is only needed when exporting HTML.
        from plotly.offline import plot as write_html_plot

        write_html_plot(fig, auto_open=False, filename=args.out_html)

    if args.show:
        fig.show()
def wrap(s):
    """Re-wrap every line of ``s`` with :func:`textwrap.wrap` (default width).

    Existing line breaks are kept. Blank lines stay blank, because
    ``textwrap.wrap("")`` returns an empty list, which joins to ``""``.
    """
    return "\n".join("\n".join(textwrap.wrap(x)) for x in s.splitlines())


def build_args(argv):
    """Parse the RD plotter's command-line arguments.

    Each ``--query`` opens a query group. The i-th ``--curves`` and i-th
    ``--optimal`` belong to the i-th query. Groups that omit them receive
    defaults, so ``query``, ``curves`` and ``optimal`` end up the same length.

    Args:
        argv: Argument list, excluding the program name.

    Returns:
        argparse.Namespace where ``curves`` is a list (one per query) of lists
        of curve dicts, each merged over the defaults
        ``{"name": "{experiment}", "x": args.x, "y": args.y}``.

    Raises:
        SystemExit: via ``parser.error`` for invalid arguments. This includes an
            ``--optimal`` value outside the documented choices, and more
            ``--curves``/``--optimal`` groups than queries (extra groups would
            otherwise be silently dropped by the ``zip`` in create_dataframe).
    """
    # Named help_text (not ``help``) to avoid shadowing the builtin.
    help_text = {
        "description": wrap(
            "Plot.\n"
            "\n"
            "Queries experiment tracker (Aim) repository for relevant metrics and plots. "
            "Users may specify what to plot using groups of --query, --curves, and --optimal. "
            "If desired, one may plot multiple query groups within the same plot.\n"
        ),
        "show": "Show figure in browser.",
        "query": (
            "Query selector for relevant runs to aggregate plotting data from.\n"
            "\n"
            "Default: '' (i.e. uses all runs).\n"
            "\n"
            "Examples:\n"
            " - 'run.hash == \"e4e6d4d5e5c59c69f3bd7be2\"'\n"
            " - 'run.model.name == \"bmshj2018-factorized\"'\n"
            " - 'run.experiment.startswith(\"some-prefix-\")'\n"
            " - 'run.created_at >= datetime(1970, 1, 1)'\n"
            " - 'run.criterion.lmbda < 0.02 and run.hp.M == 3 * 2**6'\n"
        ),
        "curves": (
            wrap(
                "For the current query, specify a grouping and format for the curves. "
                "One may specify multiple such groupings for a given query within a list. "
                'Each unique "name" produces a unique curve. '
                'If a key (e.g. "name", "x", "y") is not specified, its default value is used.\n'
                "\n"
                'For "name", one may specify a hparam as by key via "{hparam}".\n'
            )
            + (
                "\n"
                "\n"
                'Default: [{"name": "{experiment}", "x": args.x, "y": args.y}].\n'
                "\n"
                "Examples:\n"
                " - Show both model name and experiment name:\n"
                ' [{"name": "{model.name} {experiment}"}]\n'
                " - Group by hp.M:\n"
                ' [{"name": "{experiment} (M={hp.M})"}]\n'
                " - Multiple metrics as separate curves:\n"
                " [\n"
                ' {"name": "{experiment} (RGB-PSNR)", "y": "psnr_rgb"},\n'
                ' {"name": "{experiment} (YUV-PSNR)", "y": "psnr_yuv"},\n'
                " ]\n"
                " - Multi-rate models (e.g. G-VAE):\n"
                " [{\n"
                ' "name": "{experiment} {run.hash}",\n'
                ' "x": ["bpp_0", "bpp_1", "bpp_2", "bpp_3"],\n'
                ' "y": ["psnr_0", "psnr_1", "psnr_2", "psnr_3"],\n'
                " }]\n"
            )
        ),
        "optimal": (
            "Show only optimal points on curve for respective query.\n"
            "Choices: none, pareto, convex.\n"
            "Default: none.\n"
        ),
    }

    parser = argparse.ArgumentParser(
        description=help_text["description"],
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--aim_repo", type=str, required=True)
    parser.add_argument("--out_html", type=str, default="plot_result.html")
    parser.add_argument("--out_csv", type=str, default="plot_result.csv")
    parser.add_argument("--show", action="store_true", help=help_text["show"])
    parser.add_argument("--x", "-x", type=str, default="bpp")
    parser.add_argument("--y", "-y", type=str, default="psnr")
    parser.add_argument(
        "--query", "-q", action="append", default=[], help=help_text["query"]
    )
    parser.add_argument(
        "--curves",
        "-c",
        action="append",
        default=[],
        help=help_text["curves"],
    )
    parser.add_argument(
        "--optimal",
        action="append",
        default=[],
        # Enforce the choices the help text already documents.
        choices=["none", "pareto", "convex"],
        help=help_text["optimal"],
    )
    args = parser.parse_args(argv)

    # No --query means a single group that matches all runs.
    if len(args.query) == 0:
        args.query = [""]

    num_queries = len(args.query)
    for option, values in (("--curves", args.curves), ("--optimal", args.optimal)):
        if len(values) > num_queries:
            parser.error(
                f"{option} given {len(values)} times but only "
                f"{num_queries} query group(s) exist"
            )

    curves_default = {"name": "{experiment}", "x": args.x, "y": args.y}
    # WARNING: eval executes arbitrary Python taken from the command line. This is
    # tolerated only because the user running the tool supplies argv; never pass
    # untrusted input here.
    user_curves = [eval(x) for x in args.curves]
    args.curves = [[{**curves_default, **x} for x in xs] for xs in user_curves]
    # Fresh list/dict per padded query so groups never share mutable objects.
    args.curves += [
        [dict(curves_default)] for _ in range(num_queries - len(args.curves))
    ]
    args.optimal += ["none"] * (num_queries - len(args.optimal))

    return args
def _main(argv):
    """Parse ``argv``, open the Aim repository, then build and plot the RD data."""
    args = build_args(argv)
    df = create_dataframe(aim.Repo(args.aim_repo), args)
    plot_dataframe(df, args)
def main():
    """Command-line entry point: run the plotter on this process's arguments."""
    argv = sys.argv[1:]
    _main(argv)


if __name__ == "__main__":
    main()