Source code for compressai.models.vbr

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.entropy_models import EntropyBottleneck, EntropyBottleneckVbr
from compressai.models.google import (
    JointAutoregressiveHierarchicalPriors,
    MeanScaleHyperprior,
    ScaleHyperprior,
)
from compressai.ops import LowerBound, quantize_ste
from compressai.registry import register_model

from .base import get_scale_table

# from .utils import update_registered_buffers

eps = 1e-9
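
# NOTE (illustrative, not part of the library module): the STE quantizer used
# throughout the vbr code paths below, `quantize_ste` from compressai.ops, performs
# hard rounding in the forward pass with an identity (straight-through) gradient in
# the backward pass; conceptually it is equivalent to
#     y_hat = (torch.round(y) - y).detach() + y
# The exact implementation lives in compressai/ops; the line above is only a sketch.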


[docs] @register_model("bmshj2018-hyperprior-vbr") class ScaleHyperpriorVbr(ScaleHyperprior): r"""Variable bitrate (vbr) version of bmshj2018-hyperprior (see compressai/models/google.py) with variable bitrate components detailed in: Fatih Kamisli, Fabien Racape and Hyomin Choi `"Variable-Rate Learned Image Compression with Multi-Objective Optimization and Quantization-Reconstruction Offsets`" <https://arxiv.org/abs/2402.18930>`_, Data Compression Conference (DCC), 2024. """ def __init__(self, N, M, vr_entbttlnck=False, **kwargs): super().__init__(N=N, M=M, **kwargs) # lambdas to use during training self.lmbda = [0.0018, 0.0035, 0.0067, 0.0130, 0.025, 0.0483, 0.0932, 0.18] self.levels = len(self.lmbda) # gain: inverse of quantization step size self.Gain = torch.nn.Parameter( torch.tensor( [0.10000, 0.13944, 0.19293, 0.26874, 0.37268, 0.51801, 0.71957, 1.00000] ), requires_grad=True, ) # 3-layer NN to get quant offset from Gain and stdev i.e. scales_hat Nds = 12 self.QuantABCD = nn.Sequential( nn.Linear(1 + 1, Nds), nn.ReLU(), nn.Linear(Nds, Nds), nn.ReLU(), nn.Linear(Nds, 1), ) # flag to indicate whether to use or not to use quantization offsets self.no_quantoffset = False # use also variable rate hyper prior z (entropy_bottleneck) self.vr_entbttlnck = vr_entbttlnck if self.vr_entbttlnck: self.entropy_bottleneck = EntropyBottleneckVbr(N) # 3-layer NN to get quant step size for hyper prior from Gain Ndsz = 10 self.gayn2zqstep = nn.Sequential( nn.Linear(1, Ndsz), nn.ReLU(), nn.Linear(Ndsz, Ndsz), nn.ReLU(), nn.Linear(Ndsz, 1), nn.Softplus(), ) self.lower_bound_zqstep = LowerBound(0.5) def _raise_stage_error(self, stage): raise ValueError(f"Invalid stage (stage={stage}) parameter for this model.") def _get_scale(self, stage, s, inputscale): s = max(0, min(s, len(self.Gain) - 1)) # clips to correct range if self.training: if stage > 1: scale = self.Gain[s].detach() # scale = torch.max(self.Gain[s], torch.tensor(1e-4)) + eps # train scale else: scale = self.Gain[s].detach() else: if inputscale == 0: scale = self.Gain[s].detach() else: scale = inputscale return scale def forward(self, x, stage: int = 2, s: int = 1, inputscale=0): r"""stage: 1 -> non-vbr (old) code path; vbr modules not used; operates like corresponding google model. use for initial training. 2 -> vbr code path; vbr modules now used. use for post training with e.g. MOO. 
""" scale = self._get_scale(stage, s, inputscale) rescale = 1.0 / scale.clone().detach() if stage == 1: y = self.g_a(x) z = self.h_a(torch.abs(y)) z_hat, z_likelihoods = self.entropy_bottleneck(z) scales_hat = self.h_s(z_hat) y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat) x_hat = self.g_s(y_hat) elif stage == 2: # vbr code path with STE based quantization proxy y = self.g_a(x) z = self.h_a(torch.abs(y)) if not self.vr_entbttlnck: z_hat, z_likelihoods = self.entropy_bottleneck(z) z_offset = self.entropy_bottleneck._get_medians() z_tmp = z - z_offset z_hat = quantize_ste(z_tmp) + z_offset else: z_qstep = self.gayn2zqstep(1.0 / scale.clone().view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) z_hat, z_likelihoods = self.entropy_bottleneck( z, qs=z_qstep[0], training=None, ste=False ) # ste=True) scales_hat = self.h_s(z_hat) if self.no_quantoffset: y_hat = quantize_ste(y * scale, "ste") * rescale _, y_likelihoods = self.gaussian_conditional( y * scale, scales_hat * scale ) else: y_ch_means = 0 y_zm = y - y_ch_means y_zm_sc = y_zm * scale signs = torch.sign(y_zm_sc).detach() q_abs = quantize_ste(torch.abs(y_zm_sc)) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=4), scale.detach().expand(q_stdev.unsqueeze(dim=4).shape), ), dim=4, ) q_offsets = (-1) * (self.QuantABCD.forward(stdev_and_gain)).squeeze( dim=4 ) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ! ) y_hat = signs * (q_abs + q_offsets) y_hat = y_hat * rescale + y_ch_means _, y_likelihoods = self.gaussian_conditional( y * scale, scales_hat * scale ) x_hat = self.g_s(y_hat) else: self._raise_stage_error(self, stage) return { "x_hat": x_hat, "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, } # def load_state_dict(self, state_dict): # update_registered_buffers( # self.gaussian_conditional, # "gaussian_conditional", # ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], # state_dict, # ) # super().load_state_dict(state_dict) @classmethod def from_state_dict(cls, state_dict, vr_entbttlnck=False): """Return a new model instance from `state_dict`.""" N = state_dict["g_a.0.weight"].size(0) M = state_dict["g_a.6.weight"].size(0) net = cls(N, M, vr_entbttlnck) if "QuantOffset" in state_dict.keys(): del state_dict["QuantOffset"] net.load_state_dict(state_dict) return net def update(self, scale_table=None, force=False, scale=None): if scale_table is None: scale_table = get_scale_table() updated = self.gaussian_conditional.update_scale_table(scale_table, force=force) # update vr EntropyBottleneck with given scale, i.e. 
quantization step size if isinstance(self.entropy_bottleneck, EntropyBottleneckVbr): sc = scale if sc is None: rv = self.entropy_bottleneck.update(force=force) else: z_qstep = self.gayn2zqstep(1.0 / sc.view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) rv = self.entropy_bottleneck.update_variable(force=force, qs=z_qstep) elif isinstance(self.entropy_bottleneck, EntropyBottleneck): rv = self.entropy_bottleneck.update(force=force) updated |= rv return updated def compress(self, x, stage: int = 2, s: int = 1, inputscale=0): if inputscale != 0: scale = inputscale else: assert s in range( 0, self.levels ), f"s should in range(0, {self.levels}), but get s:{s}" scale = torch.abs(self.Gain[s]) y = self.g_a(x) z = self.h_a(torch.abs(y)) if stage == 1 or (stage == 2 and not self.vr_entbttlnck): z_strings = self.entropy_bottleneck.compress(z) z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) elif stage == 2: # support vr EntropyBottleneck z_qstep = self.gayn2zqstep(1.0 / scale.view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) z_strings = self.entropy_bottleneck.compress(z, qs=z_qstep[0]) z_hat = self.entropy_bottleneck.decompress( z_strings, z.size()[-2:], qs=z_qstep[0] ) else: self._raise_stage_error(self, stage) scales_hat = self.h_s(z_hat) if stage == 1: indexes = self.gaussian_conditional.build_indexes(scales_hat) y_strings = self.gaussian_conditional.compress(y, indexes) elif stage == 2: indexes = self.gaussian_conditional.build_indexes(scales_hat * scale) y_strings = self.gaussian_conditional.compress(y * scale, indexes) return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} def decompress(self, strings, shape, stage: int = 2, s: int = 1, inputscale=0): assert isinstance(strings, list) and len(strings) == 2 if inputscale != 0: scale = inputscale else: assert s in range( 0, self.levels ), f"s should in range(0, {self.levels}), but get s:{s}" scale = torch.abs(self.Gain[s]) rescale = torch.tensor(1.0) / scale if stage == 1 or (stage == 2 and not self.vr_entbttlnck): z_hat = self.entropy_bottleneck.decompress(strings[1], shape) elif stage == 2: # support vr EntropyBottleneck z_qstep = self.gayn2zqstep(1.0 / scale.view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) z_hat = self.entropy_bottleneck.decompress(strings[1], shape, qs=z_qstep[0]) else: self._raise_stage_error(self, stage) scales_hat = self.h_s(z_hat) if stage == 1: indexes = self.gaussian_conditional.build_indexes(scales_hat) y_hat = self.gaussian_conditional.decompress( strings[0], indexes, z_hat.dtype ) elif stage == 2: indexes = self.gaussian_conditional.build_indexes(scales_hat * scale) if self.no_quantoffset: y_hat = ( self.gaussian_conditional.decompress(strings[0], indexes) * rescale ) else: q_val = self.gaussian_conditional.decompress(strings[0], indexes) q_abs, signs = q_val.abs(), torch.sign(q_val) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=4), scale.detach().expand(q_stdev.unsqueeze(dim=4).shape), ), dim=4, ) q_offsets = (-1) * (self.QuantABCD.forward(stdev_and_gain)).squeeze( dim=4 ) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ) y_hat = signs * (q_abs + q_offsets) y_ch_means = 0 y_hat = y_hat * rescale + y_ch_means x_hat = self.g_s(y_hat).clamp_(0, 1) return {"x_hat": x_hat}
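
# --- Usage sketch (illustrative, not part of the library module) ---
# A minimal round trip with ScaleHyperpriorVbr, assuming a 256x256 RGB input and
# the N/M values used elsewhere in this file. The rate is picked with the index
# `s` into `self.Gain` (0 = lowest rate, self.levels - 1 = highest). The CDF
# tables must be built with `update()` before entropy coding.
#
#   model = ScaleHyperpriorVbr(N=192, M=320, vr_entbttlnck=False).eval()
#   x = torch.rand(1, 3, 256, 256)
#
#   # forward pass on the vbr code path (stage=2) at rate index s=3
#   out = model(x, stage=2, s=3)
#   y_bits = -out["likelihoods"]["y"].log2().sum()  # estimated bits for y
#
#   # actual entropy coding
#   model.update(force=True)
#   enc = model.compress(x, stage=2, s=3)
#   dec = model.decompress(enc["strings"], enc["shape"], stage=2, s=3)
#   x_hat = dec["x_hat"]
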
[docs] @register_model("mbt2018-mean-vbr") # class MeanScaleHyperpriorVbr(ScaleHyperpriorVbr): class MeanScaleHyperpriorVbr(ScaleHyperpriorVbr, MeanScaleHyperprior): r"""Variable bitrate (vbr) version of mbt2018-mean (see compressai/models/google.py) with variable bitrate components detailed in: Fatih Kamisli, Fabien Racape and Hyomin Choi `"Variable-Rate Learned Image Compression with Multi-Objective Optimization and Quantization-Reconstruction Offsets`" <https://arxiv.org/abs/2402.18930>`_, Data Compression Conference (DCC), 2024. """ def __init__(self, N=192, M=320, vr_entbttlnck=False, **kwargs): super().__init__(N, M, vr_entbttlnck=vr_entbttlnck, **kwargs) def forward(self, x, stage: int = 2, s: int = 1, inputscale=0): r"""stage: 1 -> non-vbr (old) code path; vbr modules not used; operates like corresponding google model. use for initial training. 2 -> vbr code path; vbr modules now used. use for post training with e.g. MOO. """ scale = self._get_scale(stage, s, inputscale) rescale = 1.0 / scale.clone().detach() if ( stage == 1 ): # NOTE: modifications to support q_offsets in noise case are not well tested y = self.g_a(x) z = self.h_a(y) z_hat, z_likelihoods = self.entropy_bottleneck(z) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) y_hat, y_likelihoods = self.gaussian_conditional( y, scales_hat, means=means_hat ) x_hat = self.g_s(y_hat) elif stage == 2: # vbr code path with STE based quantization proxy y = self.g_a(x) z = self.h_a(y) if not self.vr_entbttlnck: z_hat, z_likelihoods = self.entropy_bottleneck(z) z_offset = self.entropy_bottleneck._get_medians() z_tmp = z - z_offset z_hat = quantize_ste(z_tmp) + z_offset else: z_qstep = self.gayn2zqstep(1.0 / scale.clone().view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) z_hat, z_likelihoods = self.entropy_bottleneck( z, qs=z_qstep[0], training=None, ste=False ) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) if self.no_quantoffset: y_hat = ( self.quantizer.quantize((y - means_hat) * scale, "ste") * rescale + means_hat ) _, y_likelihoods = self.gaussian_conditional( y * scale, scales_hat * scale, means=means_hat * scale ) else: y_zm = y - means_hat y_zm_sc = y_zm * scale signs = torch.sign(y_zm_sc).detach() q_abs = quantize_ste(torch.abs(y_zm_sc)) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=4), scale.detach().expand(q_stdev.unsqueeze(dim=4).shape), ), dim=4, ) q_offsets = (-1) * (self.QuantABCD.forward(stdev_and_gain)).squeeze( dim=4 ) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ! 
) y_hat = signs * (q_abs + q_offsets) y_hat = y_hat * rescale + means_hat _, y_likelihoods = self.gaussian_conditional( y * scale, scales_hat * scale, means=means_hat * scale ) x_hat = self.g_s(y_hat) else: self._raise_stage_error(self, stage) return { "x_hat": x_hat, "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, } def compress(self, x, stage: int = 2, s: int = 1, inputscale=0): if inputscale != 0: scale = inputscale else: assert s in range( 0, self.levels ), f"s should in range(0, {self.levels}), but get s:{s}" scale = torch.abs(self.Gain[s]) y = self.g_a(x) z = self.h_a(y) if stage == 1 or (stage == 2 and not self.vr_entbttlnck): z_strings = self.entropy_bottleneck.compress(z) z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) elif stage == 2: # support vr EntropyBottleneck z_qstep = self.gayn2zqstep(1.0 / scale.view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) z_strings = self.entropy_bottleneck.compress(z, qs=z_qstep[0]) z_hat = self.entropy_bottleneck.decompress( z_strings, z.size()[-2:], qs=z_qstep[0] ) else: self._raise_stage_error(self, stage) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) if stage == 1: indexes = self.gaussian_conditional.build_indexes(scales_hat) y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat) elif stage == 2: indexes = self.gaussian_conditional.build_indexes(scales_hat * scale) y_strings = self.gaussian_conditional.compress( y * scale, indexes, means=means_hat * scale ) return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} def decompress(self, strings, shape, stage: int = 2, s: int = 1, inputscale=0): assert isinstance(strings, list) and len(strings) == 2 if inputscale != 0: scale = inputscale else: assert s in range( 0, self.levels ), f"s should in range(0, {self.levels}), but get s:{s}" scale = torch.abs(self.Gain[s]) rescale = torch.tensor(1.0) / scale if stage == 1 or (stage == 2 and not self.vr_entbttlnck): z_hat = self.entropy_bottleneck.decompress(strings[1], shape) elif stage == 2: # support vr EntropyBottleneck z_qstep = self.gayn2zqstep(1.0 / scale.view(1)) z_qstep = self.lower_bound_zqstep(z_qstep) z_hat = self.entropy_bottleneck.decompress(strings[1], shape, qs=z_qstep[0]) else: self._raise_stage_error(self, stage) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) if stage == 1: indexes = self.gaussian_conditional.build_indexes(scales_hat) y_hat = self.gaussian_conditional.decompress( strings[0], indexes, means=means_hat ) elif stage == 2: indexes = self.gaussian_conditional.build_indexes(scales_hat * scale) if self.no_quantoffset: y_hat = ( self.gaussian_conditional.decompress( strings[0], indexes, means=means_hat * scale ) * rescale ) else: q_val = self.gaussian_conditional.decompress(strings[0], indexes) q_abs, signs = q_val.abs(), torch.sign(q_val) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=4), scale.detach().expand(q_stdev.unsqueeze(dim=4).shape), ), dim=4, ) q_offsets = (-1) * (self.QuantABCD.forward(stdev_and_gain)).squeeze( dim=4 ) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ) y_hat = signs * (q_abs + q_offsets) y_hat = y_hat * rescale + means_hat x_hat = self.g_s(y_hat).clamp_(0, 1) return {"x_hat": x_hat}
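
# --- Rate-control sketch (illustrative, not part of the library module) ---
# Besides the discrete rate indices `s`, an arbitrary positive gain can be passed
# through `inputscale` (a scalar tensor). One possible way to reach rates between
# two trained points is to interpolate the Gain entries in the log domain; this is
# only a suggestion under that assumption, not something this module does for you.
#
#   model = MeanScaleHyperpriorVbr(N=192, M=320).eval()
#   x = torch.rand(1, 3, 256, 256)
#
#   s, alpha = 2, 0.5  # halfway (in log domain) between Gain[2] and Gain[3]
#   gain = model.Gain[s] ** (1 - alpha) * model.Gain[s + 1] ** alpha
#   with torch.no_grad():
#       out = model(x, stage=2, inputscale=gain)
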
[docs] @register_model("mbt2018-vbr") class JointAutoregressiveHierarchicalPriorsVbr( ScaleHyperpriorVbr, JointAutoregressiveHierarchicalPriors ): r"""Variable bitrate (vbr) version of mbt2018 (see compressai/models/google.py) with variable bitrate components detailed in: Fatih Kamisli, Fabien Racape and Hyomin Choi `"Variable-Rate Learned Image Compression with Multi-Objective Optimization and Quantization-Reconstruction Offsets`" <https://arxiv.org/abs/2402.18930>`_, Data Compression Conference (DCC), 2024. """ def __init__(self, N=192, M=320, **kwargs): super().__init__(N, M, vr_entbttlnck=False, **kwargs) self.ste_recursive = True # flag to indicate whether this will be used or not self.scl2ctx = True # this generates from scale (i.e. quant step) a tensor that will be added # to context_prediction output so that it changes depending on scale in order to # have better entropy parameters for any quantization stepsize self.scale_to_context = nn.Linear(1, 2 * M) def forward(self, x, stage: int = 2, s: int = 1, inputscale=0): r"""stage: 1 -> non-vbr (old) code path; vbr modules not used; operates like corresponding google model. use for initial training. 2 -> vbr code path; vbr modules now used. use for post training with e.g. MOO. """ scale = self._get_scale(stage, s, inputscale) rescale = 1.0 / scale.clone().detach() if stage == 1: y = self.g_a(x) z = self.h_a(y) z_hat, z_likelihoods = self.entropy_bottleneck(z) params = self.h_s(z_hat) y_hat = self.gaussian_conditional.quantize( y, "noise" if self.training else "dequantize" ) ctx_params = self.context_prediction(y_hat) gaussian_params = self.entropy_parameters( torch.cat((params, ctx_params), dim=1) ) scales_hat, means_hat = gaussian_params.chunk(2, 1) _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) x_hat = self.g_s(y_hat) elif stage == 2: # vbr code path with STE based quantization proxy y = self.g_a(x) z = self.h_a(y) _, z_likelihoods = self.entropy_bottleneck(z) z_offset = self.entropy_bottleneck._get_medians() z_tmp = z - z_offset z_hat = quantize_ste(z_tmp) + z_offset params = self.h_s(z_hat) if self.ste_recursive: kernel_size = 5 # context prediction kernel size padding = (kernel_size - 1) // 2 y_hat = F.pad(y, (padding, padding, padding, padding)) # all modifications to perform quantization with quant offset are inside the _stequantization() y_hat, y_likelihoods = self._stequantization( y_hat, params, y.size(2), y.size(3), kernel_size, padding, scale, rescale, ) else: raise ValueError("ste_recurseive=False is not supported.") x_hat = self.g_s(y_hat) else: self._raise_stage_error(self, stage) return { "x_hat": x_hat, "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, } def _stequantization( self, y_hat, params, height, width, kernel_size, padding, scale, rescale ): y_likelihoods = torch.zeros([y_hat.size(0), y_hat.size(1), height, width]).to( scale.device ) # get ctx_scl to condition entropy model also on scale if self.scl2ctx: ctx_scl = self.scale_to_context.forward(scale.view(1, 1)).view(1, -1, 1, 1) # TODO: profile the calls to the bindings... masked_weight = self.context_prediction.weight * self.context_prediction.mask for h in range(height): for w in range(width): y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size].clone() ctx_p = F.conv2d( y_crop, masked_weight, bias=self.context_prediction.bias, ) # adjust ctx_p to be conditioned on also scale so that entropy params are also conditioned on scale ... 
if self.scl2ctx: p = params[:, :, h : h + 1, w : w + 1] gaussian_params = self.entropy_parameters( torch.cat((p, ctx_p + ctx_scl), dim=1) ) else: # 1x1 conv for the entropy parameters prediction network, so # we only keep the elements in the "center" p = params[:, :, h : h + 1, w : w + 1] gaussian_params = self.entropy_parameters( torch.cat((p, ctx_p), dim=1) ) gaussian_params = gaussian_params.squeeze(3).squeeze(2) scales_hat, means_hat = gaussian_params.chunk(2, 1) y_crop = y_crop[:, :, padding, padding] _, y_likelihoods[:, :, h : h + 1, w : w + 1] = ( self.gaussian_conditional( ((y_crop - means_hat) * scale).unsqueeze(2).unsqueeze(3), (scales_hat * scale).unsqueeze(2).unsqueeze(3), ) ) if self.no_quantoffset: y_q = ( self.quantizer.quantize( (y_crop - means_hat.detach()) * scale, "ste" ) * rescale + means_hat.detach() ) else: y_zm = y_crop - means_hat y_zm_sc = y_zm * scale signs = torch.sign(y_zm_sc).detach() q_abs = quantize_ste(torch.abs(y_zm_sc)) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=2), scale.detach().expand(q_stdev.unsqueeze(dim=2).shape), ), dim=2, ) q_offsets = (-1) * (self.QuantABCD.forward(stdev_and_gain)).squeeze( dim=2 ) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ) y_q = signs * (q_abs + q_offsets) y_q = y_q * rescale + means_hat y_hat[:, :, h + padding, w + padding] = y_q y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding)) return y_hat, y_likelihoods @classmethod def from_state_dict(cls, state_dict): """Return a new model instance from `state_dict`.""" N = state_dict["g_a.0.weight"].size(0) M = state_dict["g_a.6.weight"].size(0) net = cls(N, M) net.load_state_dict(state_dict) return net def compress(self, x, stage: int = 2, s: int = 1, inputscale=0): if next(self.parameters()).device != torch.device("cpu"): warnings.warn( "Inference on GPU is not recommended for the autoregressive " "models (the entropy coder is run sequentially on CPU).", stacklevel=2, ) if inputscale != 0: scale = inputscale else: assert s in range( 0, self.levels ), f"s should in range(0, {self.levels}), but get s:{s}" scale = torch.abs(self.Gain[s]) rescale = torch.tensor(1.0) / scale y = self.g_a(x) z = self.h_a(y) z_strings = self.entropy_bottleneck.compress(z) z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) params = self.h_s(z_hat) s = 4 # scaling factor between z and y kernel_size = 5 # context prediction kernel size padding = (kernel_size - 1) // 2 y_height = z_hat.size(2) * s y_width = z_hat.size(3) * s y_hat = F.pad(y, (padding, padding, padding, padding)) y_strings = [] for i in range(y.size(0)): string = self._compress_ar( y_hat[i : i + 1], params[i : i + 1], y_height, y_width, kernel_size, padding, scale, rescale, stage, ) y_strings.append(string) return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} def _compress_ar( self, y_hat, params, height, width, kernel_size, padding, scale, rescale, stage ): cdf = self.gaussian_conditional.quantized_cdf.tolist() cdf_lengths = self.gaussian_conditional.cdf_length.tolist() offsets = self.gaussian_conditional.offset.tolist() encoder = BufferedRansEncoder() symbols_list = [] indexes_list = [] # get ctx_scl to condition entropy model also on scale if self.scl2ctx and stage == 2: ctx_scl = self.scale_to_context.forward(scale.view(1, 1)).view(1, -1, 1, 1) # Warning, this is slow... # TODO: profile the calls to the bindings... 
masked_weight = self.context_prediction.weight * self.context_prediction.mask for h in range(height): for w in range(width): y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size] ctx_p = F.conv2d( y_crop, masked_weight, bias=self.context_prediction.bias, ) # adjust ctx_p to be conditioned on also scale so that entropy params are also conditioned on scale ... if self.scl2ctx and stage == 2: ctx_p = ctx_p + ctx_scl # 1x1 conv for the entropy parameters prediction network, so # we only keep the elements in the "center" p = params[:, :, h : h + 1, w : w + 1] gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1)) gaussian_params = gaussian_params.squeeze(3).squeeze(2) scales_hat, means_hat = gaussian_params.chunk(2, 1) indexes = self.gaussian_conditional.build_indexes(scales_hat * scale) y_crop = y_crop[:, :, padding, padding] if stage == 1: y_q = self.gaussian_conditional.quantize( y_crop, "symbols", means_hat ) y_hat[:, :, h + padding, w + padding] = y_q + means_hat elif stage == 2: if self.no_quantoffset or ( self.no_quantoffset is False and self.ste_recursive is False ): y_q = self.gaussian_conditional.quantize( (y_crop - means_hat) * scale, "symbols" ) y_hat[:, :, h + padding, w + padding] = ( y_q ) * rescale + means_hat else: y_zm = y_crop - means_hat.detach() y_zm_sc = y_zm * scale signs = torch.sign(y_zm_sc).detach() q_abs = quantize_ste(torch.abs(y_zm_sc)) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=2), scale.detach().expand(q_stdev.unsqueeze(dim=2).shape), ), dim=2, ) q_offsets = (-1) * ( self.QuantABCD.forward(stdev_and_gain) ).squeeze(dim=2) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ) y_q = (signs * (q_abs + 0)).int() y_hat[:, :, h + padding, w + padding] = ( signs * (q_abs + q_offsets) ) * rescale + means_hat symbols_list.extend(y_q.squeeze().tolist()) indexes_list.extend(indexes.squeeze().tolist()) encoder.encode_with_indexes( symbols_list, indexes_list, cdf, cdf_lengths, offsets ) string = encoder.flush() return string def decompress(self, strings, shape, stage: int = 2, s: int = 1, inputscale=0): assert isinstance(strings, list) and len(strings) == 2 if next(self.parameters()).device != torch.device("cpu"): warnings.warn( "Inference on GPU is not recommended for the autoregressive " "models (the entropy coder is run sequentially on CPU).", stacklevel=2, ) if inputscale != 0: scale = inputscale else: assert s in range( 0, self.levels ), f"s should in range(0, {self.levels}), but get s:{s}" scale = torch.abs(self.Gain[s]) rescale = torch.tensor(1.0) / scale # FIXME: we don't respect the default entropy coder and directly call the # range ANS decoder z_hat = self.entropy_bottleneck.decompress(strings[1], shape) params = self.h_s(z_hat) s = 4 # scaling factor between z and y kernel_size = 5 # context prediction kernel size padding = (kernel_size - 1) // 2 y_height = z_hat.size(2) * s y_width = z_hat.size(3) * s # initialize y_hat to zeros, and pad it so we can directly work with # sub-tensors of size (N, C, kernel size, kernel_size) y_hat = torch.zeros( (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding), device=z_hat.device, ) for i, y_string in enumerate(strings[0]): self._decompress_ar( y_string, y_hat[i : i + 1], params[i : i + 1], y_height, y_width, kernel_size, padding, scale, rescale, stage, ) y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding)) x_hat = self.g_s(y_hat).clamp_(0, 1) return 
{"x_hat": x_hat} def _decompress_ar( # noqa: C901 self, y_string, y_hat, params, height, width, kernel_size, padding, scale, rescale, stage, ): cdf = self.gaussian_conditional.quantized_cdf.tolist() cdf_lengths = self.gaussian_conditional.cdf_length.tolist() offsets = self.gaussian_conditional.offset.tolist() decoder = RansDecoder() decoder.set_stream(y_string) if stage == 2: if self.no_quantoffset is False and self.ste_recursive is False: y_rec = torch.zeros_like(y_hat) # get ctx_scl to condition entropy model also on scale if self.scl2ctx and stage == 2: ctx_scl = self.scale_to_context.forward(scale.view(1, 1)).view(1, -1, 1, 1) # Warning: this is slow due to the auto-regressive nature of the # decoding... See more recent publication where they use an # auto-regressive module on chunks of channels for faster decoding... for h in range(height): for w in range(width): # only perform the 5x5 convolution on a cropped tensor # centered in (h, w) y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size] ctx_p = F.conv2d( y_crop, self.context_prediction.weight, bias=self.context_prediction.bias, ) # adjust ctx_p to be conditioned on also scale so that entropy params are also conditioned on scale ... if self.scl2ctx and stage == 2: ctx_p = ctx_p + ctx_scl # 1x1 conv for the entropy parameters prediction network, so # we only keep the elements in the "center" p = params[:, :, h : h + 1, w : w + 1] gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1)) scales_hat, means_hat = gaussian_params.chunk(2, 1) indexes = self.gaussian_conditional.build_indexes(scales_hat * scale) rv = decoder.decode_stream( indexes.squeeze().tolist(), cdf, cdf_lengths, offsets ) rv = ( torch.Tensor(rv).reshape(1, -1, 1, 1).to(scales_hat.device) ) # TODO: move rv to gpu ? if stage == 1: rv = self.gaussian_conditional.dequantize(rv, means_hat) hp = h + padding wp = w + padding y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv elif stage == 2: if self.no_quantoffset: rv = ( self.gaussian_conditional.dequantize(rv) * rescale + means_hat ) hp = h + padding wp = w + padding y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv else: q_val = self.gaussian_conditional.dequantize(rv) q_abs, signs = q_val.abs(), torch.sign(q_val) q_stdev = self.gaussian_conditional.lower_bound_scale( scales_hat * scale ) stdev_and_gain = torch.cat( ( q_stdev.unsqueeze(dim=4), scale.detach().expand(q_stdev.unsqueeze(dim=4).shape), ), dim=4, ) q_offsets = (-1) * ( self.QuantABCD.forward(stdev_and_gain) ).squeeze(dim=4) q_offsets[q_abs < 0.0001] = ( 0 # must use zero offset for locations quantized to zero ) rv_out = (signs * (q_abs + q_offsets)) * rescale + means_hat hp = h + padding wp = w + padding if self.ste_recursive is False: y_hat[:, :, hp : hp + 1, wp : wp + 1] = ( self.gaussian_conditional.dequantize(rv) * rescale + means_hat ) # find also reco without quantoffset for likelihood y_rec[:, :, hp : hp + 1, wp : wp + 1] = rv_out else: y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv_out # NOTE: reconstruction with quantoffset will be used for only image reconstruction and reconstruction without quantoffset for likelihood if stage == 2: if self.no_quantoffset is False and self.ste_recursive is False: y_hat = y_rec