# Copyright (c) 2021-2024, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
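# Image compression model "zoo": constructors for the CompressAI image models
# listed in ``__all__`` below, with optional download of pre-trained weights.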
from torch.hub import load_state_dict_from_url
from compressai.models import (
Cheng2020Anchor,
Cheng2020Attention,
FactorizedPrior,
FactorizedPriorReLU,
JointAutoregressiveHierarchicalPriors,
MeanScaleHyperprior,
ScaleHyperprior,
)
from .pretrained import load_pretrained
__all__ = [
"bmshj2018_factorized",
"bmshj2018_factorized_relu",
"bmshj2018_hyperprior",
"mbt2018",
"mbt2018_mean",
"cheng2020_anchor",
"cheng2020_attn",
]
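# Maps each architecture name to its model class from compressai.models.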
model_architectures = {
"bmshj2018-factorized": FactorizedPrior,
"bmshj2018_factorized_relu": FactorizedPriorReLU,
"bmshj2018-hyperprior": ScaleHyperprior,
"mbt2018-mean": MeanScaleHyperprior,
"mbt2018": JointAutoregressiveHierarchicalPriors,
"cheng2020-anchor": Cheng2020Anchor,
"cheng2020-attn": Cheng2020Attention,
}
root_url = "https://compressai.s3.amazonaws.com/models/v1"
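# Pre-trained weight URLs, indexed as architecture -> metric ("mse" or
# "ms-ssim") -> quality level. Not every architecture provides weights for
# every metric/quality combination.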
model_urls = {
"bmshj2018-factorized": {
"mse": {
1: f"{root_url}/bmshj2018-factorized-prior-1-446d5c7f.pth.tar",
2: f"{root_url}/bmshj2018-factorized-prior-2-87279a02.pth.tar",
3: f"{root_url}/bmshj2018-factorized-prior-3-5c6f152b.pth.tar",
4: f"{root_url}/bmshj2018-factorized-prior-4-1ed4405a.pth.tar",
5: f"{root_url}/bmshj2018-factorized-prior-5-866ba797.pth.tar",
6: f"{root_url}/bmshj2018-factorized-prior-6-9b02ea3a.pth.tar",
7: f"{root_url}/bmshj2018-factorized-prior-7-6dfd6734.pth.tar",
8: f"{root_url}/bmshj2018-factorized-prior-8-5232faa3.pth.tar",
},
"ms-ssim": {
1: f"{root_url}/bmshj2018-factorized-ms-ssim-1-9781d705.pth.tar",
2: f"{root_url}/bmshj2018-factorized-ms-ssim-2-4a584386.pth.tar",
3: f"{root_url}/bmshj2018-factorized-ms-ssim-3-5352f123.pth.tar",
4: f"{root_url}/bmshj2018-factorized-ms-ssim-4-4f91b847.pth.tar",
5: f"{root_url}/bmshj2018-factorized-ms-ssim-5-b3a88897.pth.tar",
6: f"{root_url}/bmshj2018-factorized-ms-ssim-6-ee028763.pth.tar",
7: f"{root_url}/bmshj2018-factorized-ms-ssim-7-8c265a29.pth.tar",
8: f"{root_url}/bmshj2018-factorized-ms-ssim-8-8811bd14.pth.tar",
},
},
"bmshj2018-hyperprior": {
"mse": {
1: f"{root_url}/bmshj2018-hyperprior-1-7eb97409.pth.tar",
2: f"{root_url}/bmshj2018-hyperprior-2-93677231.pth.tar",
3: f"{root_url}/bmshj2018-hyperprior-3-6d87be32.pth.tar",
4: f"{root_url}/bmshj2018-hyperprior-4-de1b779c.pth.tar",
5: f"{root_url}/bmshj2018-hyperprior-5-f8b614e1.pth.tar",
6: f"{root_url}/bmshj2018-hyperprior-6-1ab9c41e.pth.tar",
7: f"{root_url}/bmshj2018-hyperprior-7-3804dcbd.pth.tar",
8: f"{root_url}/bmshj2018-hyperprior-8-a583f0cf.pth.tar",
},
"ms-ssim": {
1: f"{root_url}/bmshj2018-hyperprior-ms-ssim-1-5cf249be.pth.tar",
2: f"{root_url}/bmshj2018-hyperprior-ms-ssim-2-1ff60d1f.pth.tar",
3: f"{root_url}/bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar",
4: f"{root_url}/bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar",
5: f"{root_url}/bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar",
6: f"{root_url}/bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar",
7: f"{root_url}/bmshj2018-hyperprior-ms-ssim-7-8747d3bc.pth.tar",
8: f"{root_url}/bmshj2018-hyperprior-ms-ssim-8-cc15b5f3.pth.tar",
},
},
"mbt2018-mean": {
"mse": {
1: f"{root_url}/mbt2018-mean-1-e522738d.pth.tar",
2: f"{root_url}/mbt2018-mean-2-e54a039d.pth.tar",
3: f"{root_url}/mbt2018-mean-3-723404a8.pth.tar",
4: f"{root_url}/mbt2018-mean-4-6dba02a3.pth.tar",
5: f"{root_url}/mbt2018-mean-5-d504e8eb.pth.tar",
6: f"{root_url}/mbt2018-mean-6-a19628ab.pth.tar",
7: f"{root_url}/mbt2018-mean-7-d5d441d1.pth.tar",
8: f"{root_url}/mbt2018-mean-8-8089ae3e.pth.tar",
},
"ms-ssim": {
1: f"{root_url}/mbt2018-mean-ms-ssim-1-5bf9c0b6.pth.tar",
2: f"{root_url}/mbt2018-mean-ms-ssim-2-e2a1bf3f.pth.tar",
3: f"{root_url}/mbt2018-mean-ms-ssim-3-640ce819.pth.tar",
4: f"{root_url}/mbt2018-mean-ms-ssim-4-12626c13.pth.tar",
5: f"{root_url}/mbt2018-mean-ms-ssim-5-1be7f059.pth.tar",
6: f"{root_url}/mbt2018-mean-ms-ssim-6-b83bf379.pth.tar",
7: f"{root_url}/mbt2018-mean-ms-ssim-7-ddf9644c.pth.tar",
8: f"{root_url}/mbt2018-mean-ms-ssim-8-0cc7b94f.pth.tar",
},
},
"mbt2018": {
"mse": {
1: f"{root_url}/mbt2018-1-3f36cd77.pth.tar",
2: f"{root_url}/mbt2018-2-43b70cdd.pth.tar",
3: f"{root_url}/mbt2018-3-22901978.pth.tar",
4: f"{root_url}/mbt2018-4-456e2af9.pth.tar",
5: f"{root_url}/mbt2018-5-b4a046dd.pth.tar",
6: f"{root_url}/mbt2018-6-7052e5ea.pth.tar",
7: f"{root_url}/mbt2018-7-8ba2bf82.pth.tar",
8: f"{root_url}/mbt2018-8-dd0097aa.pth.tar",
},
"ms-ssim": {
1: f"{root_url}/mbt2018-ms-ssim-1-2878436b.pth.tar",
2: f"{root_url}/mbt2018-ms-ssim-2-c41cb208.pth.tar",
3: f"{root_url}/mbt2018-ms-ssim-3-d0dd64e8.pth.tar",
4: f"{root_url}/mbt2018-ms-ssim-4-a120e037.pth.tar",
5: f"{root_url}/mbt2018-ms-ssim-5-9b30e3b7.pth.tar",
6: f"{root_url}/mbt2018-ms-ssim-6-f8b3626f.pth.tar",
7: f"{root_url}/mbt2018-ms-ssim-7-16e6ff50.pth.tar",
8: f"{root_url}/mbt2018-ms-ssim-8-0cb49d43.pth.tar",
},
},
"cheng2020-anchor": {
"mse": {
1: f"{root_url}/cheng2020-anchor-1-dad2ebff.pth.tar",
2: f"{root_url}/cheng2020-anchor-2-a29008eb.pth.tar",
3: f"{root_url}/cheng2020-anchor-3-e49be189.pth.tar",
4: f"{root_url}/cheng2020-anchor-4-98b0b468.pth.tar",
5: f"{root_url}/cheng2020-anchor-5-23852949.pth.tar",
6: f"{root_url}/cheng2020-anchor-6-4c052b1a.pth.tar",
},
"ms-ssim": {
1: f"{root_url}/cheng2020_anchor-ms-ssim-1-20f521db.pth.tar",
2: f"{root_url}/cheng2020_anchor-ms-ssim-2-c7ff5812.pth.tar",
3: f"{root_url}/cheng2020_anchor-ms-ssim-3-c23e22d5.pth.tar",
4: f"{root_url}/cheng2020_anchor-ms-ssim-4-0e658304.pth.tar",
5: f"{root_url}/cheng2020_anchor-ms-ssim-5-c0a95e77.pth.tar",
6: f"{root_url}/cheng2020_anchor-ms-ssim-6-f2dc1913.pth.tar",
},
},
"cheng2020-attn": {
"mse": {
1: f"{root_url}/cheng2020_attn-mse-1-465f2b64.pth.tar",
2: f"{root_url}/cheng2020_attn-mse-2-e0805385.pth.tar",
3: f"{root_url}/cheng2020_attn-mse-3-2d07bbdf.pth.tar",
4: f"{root_url}/cheng2020_attn-mse-4-f7b0ccf2.pth.tar",
5: f"{root_url}/cheng2020_attn-mse-5-26c8920e.pth.tar",
6: f"{root_url}/cheng2020_attn-mse-6-730501f2.pth.tar",
},
"ms-ssim": {
1: f"{root_url}/cheng2020_attn-ms-ssim-1-c5381d91.pth.tar",
2: f"{root_url}/cheng2020_attn-ms-ssim-2-5dad201d.pth.tar",
3: f"{root_url}/cheng2020_attn-ms-ssim-3-5c9be841.pth.tar",
4: f"{root_url}/cheng2020_attn-ms-ssim-4-8b2f647e.pth.tar",
5: f"{root_url}/cheng2020_attn-ms-ssim-5-5ca1f34c.pth.tar",
6: f"{root_url}/cheng2020_attn-ms-ssim-6-216423ec.pth.tar",
},
},
}
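# Constructor arguments per architecture and quality level: channel counts
# (N, M) for the bmshj2018/mbt2018 families and (N,) for the cheng2020 models.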
cfgs = {
"bmshj2018-factorized": {
1: (128, 192),
2: (128, 192),
3: (128, 192),
4: (128, 192),
5: (128, 192),
6: (192, 320),
7: (192, 320),
8: (192, 320),
},
"bmshj2018-factorized-relu": {
1: (128, 192),
2: (128, 192),
3: (128, 192),
4: (128, 192),
5: (128, 192),
6: (192, 320),
7: (192, 320),
8: (192, 320),
},
"bmshj2018-hyperprior": {
1: (128, 192),
2: (128, 192),
3: (128, 192),
4: (128, 192),
5: (128, 192),
6: (192, 320),
7: (192, 320),
8: (192, 320),
},
"mbt2018-mean": {
1: (128, 192),
2: (128, 192),
3: (128, 192),
4: (128, 192),
5: (192, 320),
6: (192, 320),
7: (192, 320),
8: (192, 320),
},
"mbt2018": {
1: (192, 192),
2: (192, 192),
3: (192, 192),
4: (192, 192),
5: (192, 320),
6: (192, 320),
7: (192, 320),
8: (192, 320),
},
"cheng2020-anchor": {
1: (128,),
2: (128,),
3: (128,),
4: (192,),
5: (192,),
6: (192,),
},
"cheng2020-attn": {
1: (128,),
2: (128,),
3: (128,),
4: (192,),
5: (192,),
6: (192,),
},
}
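# Shared helper: validates the architecture and quality, then either downloads
# and loads pre-trained weights or instantiates a freshly initialized model
# with the channel configuration from ``cfgs``.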
def _load_model(
architecture, metric, quality, pretrained=False, progress=True, **kwargs
):
if architecture not in model_architectures:
raise ValueError(f'Invalid architecture name "{architecture}"')
if quality not in cfgs[architecture]:
raise ValueError(f'Invalid quality value "{quality}"')
if pretrained:
if (
architecture not in model_urls
or metric not in model_urls[architecture]
or quality not in model_urls[architecture][metric]
):
raise RuntimeError("Pre-trained model not yet available")
url = model_urls[architecture][metric][quality]
state_dict = load_state_dict_from_url(url, progress=progress)
state_dict = load_pretrained(state_dict)
model = model_architectures[architecture].from_state_dict(state_dict)
return model
model = model_architectures[architecture](*cfgs[architecture][quality], **kwargs)
return model
def bmshj2018_factorized(
quality, metric="mse", pretrained=False, progress=True, **kwargs
):
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int. Conf. on Learning Representations
(ICLR), 2018.
Args:
        quality (int): Quality level (1: lowest, 8: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 8:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 8')
return _load_model(
"bmshj2018-factorized", metric, quality, pretrained, progress, **kwargs
)
def bmshj2018_factorized_relu(
quality, metric="mse", pretrained=False, progress=True, **kwargs
):
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int. Conf. on Learning Representations
(ICLR), 2018.
    GDN activations are replaced by ReLU activations.
Args:
        quality (int): Quality level (1: lowest, 8: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 8:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 8')
return _load_model(
"bmshj2018-factorized", metric, quality, pretrained, progress, **kwargs
)
def bmshj2018_hyperprior(
quality, metric="mse", pretrained=False, progress=True, **kwargs
):
r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
<https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations
(ICLR), 2018.
Args:
        quality (int): Quality level (1: lowest, 8: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 8:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 8')
return _load_model(
"bmshj2018-hyperprior", metric, quality, pretrained, progress, **kwargs
)
def mbt2018_mean(quality, metric="mse", pretrained=False, progress=True, **kwargs):
r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D.
Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).
Args:
        quality (int): Quality level (1: lowest, 8: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 8:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 8')
return _load_model("mbt2018-mean", metric, quality, pretrained, progress, **kwargs)
def mbt2018(quality, metric="mse", pretrained=False, progress=True, **kwargs):
r"""Joint Autoregressive Hierarchical Priors model from D.
Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
Priors for Learned Image Compression" <https://arxiv.org/abs/1809.02736>`_,
Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).
Args:
        quality (int): Quality level (1: lowest, 8: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 8:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 8')
return _load_model("mbt2018", metric, quality, pretrained, progress, **kwargs)
def cheng2020_anchor(quality, metric="mse", pretrained=False, progress=True, **kwargs):
r"""Anchor model variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Args:
        quality (int): Quality level (1: lowest, 6: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 6:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 6')
return _load_model(
"cheng2020-anchor", metric, quality, pretrained, progress, **kwargs
)
def cheng2020_attn(quality, metric="mse", pretrained=False, progress=True, **kwargs):
r"""Self-attention model variant from `"Learned Image Compression with
Discretized Gaussian Mixture Likelihoods and Attention Modules"
<https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
Takeuchi, Jiro Katto.
Args:
        quality (int): Quality level (1: lowest, 6: highest)
metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
pretrained (bool): If True, returns a pre-trained model
progress (bool): If True, displays a progress bar of the download to stderr
"""
if metric not in ("mse", "ms-ssim"):
raise ValueError(f'Invalid metric "{metric}"')
if quality < 1 or quality > 6:
        raise ValueError(f'Invalid quality "{quality}", must be between 1 and 6')
return _load_model(
"cheng2020-attn", metric, quality, pretrained, progress, **kwargs
)
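# A minimal usage sketch (illustrative, not part of the original module): build
# a pre-trained model from this zoo and run it on a dummy image. The forward
# pass is assumed to return a dict containing an "x_hat" reconstruction, as in
# the CompressAI model API; adjust if your version differs.
if __name__ == "__main__":
    import torch

    # Quality 2, MSE-optimized weights are downloaded on first use.
    net = bmshj2018_factorized(quality=2, metric="mse", pretrained=True).eval()
    x = torch.rand(1, 3, 256, 256)  # dummy RGB image with values in [0, 1]
    with torch.no_grad():
        out = net(x)
    print(out["x_hat"].shape)  # reconstructed image, same shape as the input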