mpcompress.latent_codecs

FeatureScaleHyperprior

FeatureScaleHyperprior(N, M, **kwargs)

Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston: "Variational Image Compression with a Scale Hyperprior" (https://arxiv.org/abs/1802.01436), Int. Conf. on Learning Representations (ICLR), 2018.

          ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
    x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯⋯·───►───┤h_s├─┐
          └───┘    │     └───┘     └───┘        EB        └───┘ │
                   ▼                                            │
                 ┌─┴─┐                                          │
                 │ Q │                                          ▼
                 └─┬─┘                                          │
                   │                                            │
             y_hat ▼                                            │
                   │                                            │
                   ·                                            │
                GC : ◄─────────────────────◄────────────────────┘
                   ·                 scales_hat
                   │
             y_hat ▼
                   │
          ┌───┐    │
x_hat ──◄─┤g_s├────┘
          └───┘

EB = Entropy bottleneck
GC = Gaussian conditional

Parameters:

Name Type Description Default
N int

Number of channels in the main network.

required
M int

Number of channels in the expansion layers (last layer of the encoder and last layer of the hyperprior decoder).

required
**kwargs dict

Additional keyword arguments passed to parent class.

{}

downsampling_factor property

downsampling_factor: int

Compute the downsampling factor of the model.

Returns:

Name Type Description
factor int

Downsampling factor (64 for this architecture).
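
A downsampling factor is commonly used to pick an input size that every stride divides evenly. The following is a minimal, framework-free sketch of that arithmetic. `padded_size` is a hypothetical helper, not part of mpcompress, and whether this model requires inputs padded in this way is an assumption.

```python
from typing import Tuple


def padded_size(height: int, width: int, factor: int = 64) -> Tuple[int, int]:
    """Round a spatial size up to the nearest multiple of `factor`."""
    if factor <= 0:
        raise ValueError("factor must be positive")
    # -(-a // b) is ceiling division for positive integers.
    padded_h = -(-height // factor) * factor
    padded_w = -(-width // factor) * factor
    return padded_h, padded_w


print(padded_size(500, 768))  # (512, 768)
```

In practice the factor would be read from `downsampling_factor` rather than hard-coded.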

compress

compress(x)

Compress input tensor to bitstrings.

Parameters:

Name Type Description Default
x Tensor

Input tensor to compress.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "strings" (list): List of compressed bitstrings [y_strings, z_strings].
  • "shape" (tuple): Spatial shape of the hyper latents (H, W).
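
The size of the returned bitstrings gives the actual coded rate. The sketch below assumes that `y_strings` and `z_strings` are each a list of `bytes` objects (one per batch item); `bits_per_pixel` is a hypothetical helper, and the payloads are dummy data standing in for real output.

```python
from typing import List, Sequence


def bits_per_pixel(strings: Sequence[List[bytes]], num_pixels: int) -> float:
    """Total size of all bitstrings, in bits, divided by the pixel count."""
    if num_pixels <= 0:
        raise ValueError("num_pixels must be positive")
    total_bytes = sum(len(s) for group in strings for s in group)
    return total_bytes * 8 / num_pixels


# Dummy payloads standing in for [y_strings, z_strings] of a batch of one.
strings = [[b"\x00" * 1200], [b"\x00" * 80]]
bpp = bits_per_pixel(strings, num_pixels=256 * 256)
print(bpp)  # 0.15625
```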

decompress

decompress(strings, shape)

Decompress bitstrings to reconstructed tensor.

Parameters:

Name Type Description Default
strings list

List of compressed bitstrings [y_strings, z_strings]. Must contain exactly 2 elements.

required
shape tuple

Spatial shape of the hyper latents (H, W).

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "x_hat" (torch.Tensor): Reconstructed tensor, clamped to [0, 1].

forward

forward(x)

Forward pass through the Scale Hyperprior model.

Parameters:

Name Type Description Default
x Tensor

Input tensor to compress.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "x_hat" (torch.Tensor): Reconstructed tensor.
  • "likelihoods" (dict): Dictionary with keys "y" and "z" containing likelihoods for main latents and hyper latents respectively.
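
During training, the likelihoods are usually turned into a rate estimate by summing -log2 of every value and normalizing by the number of pixels. The sketch below shows that arithmetic on plain Python floats; in real use the values would be tensors, and `estimated_bpp` is a hypothetical helper rather than a library function.

```python
import math
from typing import Dict, Iterable


def estimated_bpp(likelihoods: Dict[str, Iterable[float]], num_pixels: int) -> float:
    """Estimate bits per pixel as sum(-log2(p)) / num_pixels over all latents."""
    total_bits = sum(
        -math.log2(p) for values in likelihoods.values() for p in values
    )
    return total_bits / num_pixels


# Toy likelihoods for the "y" and "z" latents, flattened to lists.
likelihoods = {"y": [0.5, 0.25, 0.5], "z": [0.125]}
rate = estimated_bpp(likelihoods, num_pixels=4)
print(rate)  # 1.75
```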

from_state_dict classmethod

from_state_dict(state_dict)

Create a new model instance from state dictionary.

Parameters:

Name Type Description Default
state_dict dict

State dictionary containing model weights.

required

Returns:

Name Type Description
model FeatureScaleHyperprior

New model instance with loaded weights.

HyperLatentCodecWithCtx

HyperLatentCodecWithCtx(entropy_bottleneck: EntropyBottleneck, h_a: Module, h_s: Module, quantizer: str = 'noise', **kwargs)

Entropy bottleneck codec with surrounding h_a and h_s transforms.

"Hyper" side-information branch introduced in "Variational Image Compression with a Scale Hyperprior" (https://arxiv.org/abs/1802.01436), by J. Balle, D. Minnen, S. Singh, S.J. Hwang, and N. Johnston, International Conference on Learning Representations (ICLR), 2018.

HyperLatentCodec should be used inside HyperpriorLatentCodec to construct a full hyperprior.

       ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
y ──►──┤h_a├──►──┤ Q ├───►───····───►───┤h_s├──►── params
       └───┘     └───┘        EB        └───┘

Parameters:

Name Type Description Default
entropy_bottleneck EntropyBottleneck

Entropy bottleneck module for compressing hyper latents.

required
h_a Module

Analysis transform that maps input to hyper latents.

required
h_s Module

Synthesis transform that maps hyper latents to parameters.

required
quantizer str

Quantization method: "noise" or "ste". Defaults to "noise".

'noise'
**kwargs dict

Additional keyword arguments passed to parent class.

{}
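
The two quantizer options can be illustrated by their forward-pass values, assuming they carry their usual meaning in learned compression: "noise" adds uniform noise as a differentiable stand-in for rounding, and "ste" rounds while passing gradients straight through. This sketch is framework-free, so the gradient behavior of "ste" is not modeled; `quantize` is a hypothetical function, not the codec's own implementation.

```python
import random
from typing import List, Optional


def quantize(
    values: List[float], mode: str = "noise", rng: Optional[random.Random] = None
) -> List[float]:
    """Return forward-pass values for the two quantization surrogates."""
    rng = rng or random.Random()
    if mode == "noise":
        # Training-time surrogate: additive uniform noise of width 1.
        return [v + rng.uniform(-0.5, 0.5) for v in values]
    if mode == "ste":
        # Hard rounding; the straight-through gradient is not modeled here.
        return [float(round(v)) for v in values]
    raise ValueError(f"unknown quantizer mode: {mode!r}")


y = [0.2, 1.7, -2.4]
noisy = quantize(y, "noise", random.Random(0))
hard = quantize(y, "ste")
print(hard)  # [0.0, 2.0, -2.0]
```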

compress

compress(y: Tensor, ctx: Tensor) -> Dict[str, Any]

Compress main latents to bitstrings.

Parameters:

Name Type Description Default
y Tensor

Main latents to compress.

required
ctx Tensor

Context tensor for conditional processing.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "strings" (list): List containing compressed bitstrings [z_strings].
  • "shape" (tuple): Spatial shape of hyper latents (H, W).
  • "params" (torch.Tensor): Parameters generated from hyper latents.

decompress

decompress(strings: List[List[bytes]], shape: Tuple[int, int], ctx: Tensor, **kwargs) -> Dict[str, Any]

Decompress bitstrings to parameters.

Parameters:

Name Type Description Default
strings list[list[bytes]]

List containing compressed bitstrings [z_strings].

required
shape tuple[int, int]

Spatial shape of hyper latents (H, W).

required
ctx Tensor

Context tensor for conditional processing.

required
**kwargs dict

Additional keyword arguments (unused).

{}

Returns:

Name Type Description
output dict

Dictionary containing:

  • "params" (torch.Tensor): Parameters generated from decompressed hyper latents.

forward

forward(y: Tensor, ctx: Tensor) -> Dict[str, Any]

Forward pass through the hyper latent codec.

Parameters:

Name Type Description Default
y Tensor

Main latents to process.

required
ctx Tensor

Context tensor for conditional processing.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "likelihoods" (dict): Dictionary with key "z" containing likelihoods for hyper latents.
  • "params" (torch.Tensor): Parameters generated from hyper latents.

HyperpriorLatentCodecWithCtx

HyperpriorLatentCodecWithCtx(latent_codec: Mapping[str, LatentCodec], **kwargs)

Hyperprior codec built from a latent codec for y, which compresses y using the params produced by the hyper branch.

Hyperprior entropy modeling introduced in "Variational Image Compression with a Scale Hyperprior" (https://arxiv.org/abs/1802.01436), by J. Balle, D. Minnen, S. Singh, S.J. Hwang, and N. Johnston, International Conference on Learning Representations (ICLR), 2018.

         ┌──────────┐
    ┌─►──┤ lc_hyper ├──►─┐
    │    └──────────┘    │
    │                    ▼ params
    │                    │
    │                 ┌──┴───┐
y ──┴───────►─────────┤ lc_y ├───►── y_hat
                      └──────┘

By default, the following codec is constructed:

         ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
    ┌─►──┤h_a├──►──┤ Q ├───►───····───►───┤h_s├──►─┐
    │    └───┘     └───┘        EB        └───┘    │
    │                                              │
    │                  ┌──────────────◄────────────┘
    │                  │            params
    │               ┌──┴──┐
    │               │  EP │
    │               └──┬──┘
    │                  │
    │   ┌───┐  y_hat   ▼
y ──┴─►─┤ Q ├────►────····────►── y_hat
        └───┘          GC

Common configurations of latent codecs include:

  • entropy bottleneck hyper (default) and gaussian conditional y (default)
  • entropy bottleneck hyper (default) and autoregressive y

Parameters:

Name Type Description Default
latent_codec Mapping[str, LatentCodec]

Dictionary of latent codecs containing at least "y" and "hyper" keys:

  • "y": Codec for main latents.
  • "hyper": Codec for hyper latents (side information).

required
**kwargs dict

Additional keyword arguments passed to parent class.

{}

__getitem__

__getitem__(key: str) -> LatentCodec

Get a latent codec by key.

Parameters:

Name Type Description Default
key str

Key to access latent codec (e.g., "y" or "hyper").

required

Returns:

Name Type Description
codec LatentCodec

Requested latent codec.

compress

compress(y: Tensor, ctx: Tensor) -> Dict[str, Any]

Compress main latents to bitstrings.

Parameters:

Name Type Description Default
y Tensor

Main latents to compress.

required
ctx Tensor

Context tensor for conditional processing.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "strings" (list): List of compressed bitstrings, with y_strings followed by z_strings.
  • "shape" (dict): Dictionary with keys "y" and "hyper" containing spatial shapes for main and hyper latents respectively.
  • "y_hat" (torch.Tensor): Reconstructed main latents.

decompress

decompress(strings: List[List[bytes]], shape: Dict[str, Tuple[int, ...]], ctx: Tensor, **kwargs) -> Dict[str, Any]

Decompress bitstrings to reconstructed main latents.

Parameters:

Name Type Description Default
strings list[list[bytes]]

List of compressed bitstrings, with y_strings followed by z_strings. Each y_strings entry must have the same length as z_strings.

required
shape dict[str, tuple[int, ...]]

Dictionary with keys "y" and "hyper" containing spatial shapes for main and hyper latents respectively.

required
ctx Tensor

Context tensor for conditional processing.

required
**kwargs dict

Additional keyword arguments (unused).

{}

Returns:

Name Type Description
output dict

Dictionary containing:

  • "y_hat" (torch.Tensor): Reconstructed main latents.
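
The layout of `strings` (one or more y entries followed by a single z entry, all of equal length) can be checked before decoding. The sketch below is one way to split and validate such a list; `split_strings` is a hypothetical helper, and the byte payloads are placeholders.

```python
from typing import List, Tuple


def split_strings(
    strings: List[List[bytes]],
) -> Tuple[List[List[bytes]], List[bytes]]:
    """Split [*y_strings, z_strings] and check the per-entry lengths match."""
    if len(strings) < 2:
        raise ValueError("expected at least one y entry followed by z_strings")
    *y_strings, z_strings = strings
    if any(len(y) != len(z_strings) for y in y_strings):
        raise ValueError("every y_strings entry must match len(z_strings)")
    return y_strings, z_strings


# Placeholder payloads for a batch of two.
y_strings, z_strings = split_strings([[b"y0", b"y1"], [b"z0", b"z1"]])
```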

forward

forward(y: Tensor, ctx: Tensor) -> Dict[str, Any]

Forward pass through the hyperprior codec.

Parameters:

Name Type Description Default
y Tensor

Main latents to process.

required
ctx Tensor

Context tensor for conditional processing.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "likelihoods" (dict): Dictionary with keys "y" and "z" containing likelihoods for main latents and hyper latents respectively.
  • "y_hat" (torch.Tensor): Reconstructed main latents.

VitUnionLatentCodec

VitUnionLatentCodec(h_dim=384, y_dim=256, z_dim=192, groups=16, num_prefix_tokens=1, **kwargs)

ViT-based latent codec with joint modeling of the cls token and patch tokens.

This codec takes ViT features as input and compresses the 2D patch tokens using a hyperprior + space-channel context model (SCCTX) as in [He2022]. It reconstructs the ViT feature map and re-injects learned register tokens before passing through transformer blocks.

Parameters:

Name Type Description Default
h_dim int

Channel dimension of ViT features.

384
y_dim int

Channel dimension of primary latent representation y.

256
z_dim int

Channel dimension of hyperprior latent representation z.

192
groups int or list[int]

Channel groups for channel-wise context modeling. If int, the channels are evenly split; if list, must sum to y_dim.

16
num_prefix_tokens int

Number of prefix/register tokens in the ViT feature.

1
**kwargs dict

Extra keyword arguments for compatibility (unused).

{}
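
The rule for `groups` can be sketched as follows. This assumes that an int means the number of equally sized groups and that y_dim must then be divisible by it; both points are assumptions, and `resolve_groups` is a hypothetical helper rather than the codec's own code.

```python
from typing import List, Union


def resolve_groups(groups: Union[int, List[int]], y_dim: int = 256) -> List[int]:
    """Return the channel count of each group for channel-wise context modeling."""
    if isinstance(groups, int):
        # Assumption: an int is the number of groups, split evenly.
        if groups <= 0 or y_dim % groups != 0:
            raise ValueError("y_dim must be divisible by the number of groups")
        return [y_dim // groups] * groups
    sizes = list(groups)
    if sum(sizes) != y_dim:
        raise ValueError("group sizes must sum to y_dim")
    return sizes


even = resolve_groups(16)                       # sixteen groups of 16 channels
uneven = resolve_groups([16, 16, 32, 64, 128])  # explicit sizes summing to 256
```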

compress

compress(h, token_res, **kwargs)

Compress ViT features into entropy-coded bitstreams.

Parameters:

Name Type Description Default
h Tensor

ViT output tensor of shape (B, L, C).

required
token_res tuple[int, int]

Spatial token resolution (H, W).

required
**kwargs dict

Unused keyword arguments for API compatibility.

{}

Returns:

Name Type Description
out dict

A dictionary with keys:

  • "strings" (dict): Entropy-coded bitstreams for "y" and "z".
  • "pstate" (dict): Side information needed for decoding, including shapes and token resolution.

decompress

decompress(strings, pstate, **kwargs)

Decompress entropy-coded bitstreams back to ViT features.

Parameters:

Name Type Description Default
strings dict

Bitstreams produced by compress, with keys "y" and "z".

required
pstate dict

Side information produced by compress, including shapes and token resolution.

required
**kwargs dict

Unused keyword arguments for API compatibility.

{}

Returns:

Name Type Description
out dict

A dictionary with key:

  • "h_hat" (torch.Tensor): Reconstructed ViT features.

forward

forward(h, token_res, **kwargs)

Forward pass for end-to-end rate–distortion training.

Parameters:

Name Type Description Default
h Tensor

ViT output tensor of shape (B, L, C) containing prefix tokens and patch tokens.

required
token_res tuple[int, int]

Spatial token resolution (H, W) such that L = num_prefix_tokens + H * W.

required
**kwargs dict

Unused keyword arguments for API compatibility.

{}

Returns:

Name Type Description
out dict

A dictionary with keys:

  • "h_hat" (torch.Tensor): Reconstructed ViT features of shape (B, L, C).
  • "likelihoods" (dict): Per-latent likelihoods with keys "y" and "z".
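
The relation L = num_prefix_tokens + H * W determines how a token sequence splits into prefix tokens and a patch grid. The sketch below shows the bookkeeping on a plain list, assuming prefix tokens come first and patches are stored row by row (the usual ViT convention, not confirmed by this page); `split_tokens` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def split_tokens(
    tokens: Sequence, token_res: Tuple[int, int], num_prefix_tokens: int = 1
) -> Tuple[List, List[List]]:
    """Split a length-L sequence into prefix tokens and an H x W patch grid."""
    height, width = token_res
    expected = num_prefix_tokens + height * width
    if len(tokens) != expected:
        raise ValueError(f"expected {expected} tokens, got {len(tokens)}")
    prefix = list(tokens[:num_prefix_tokens])
    patches = list(tokens[num_prefix_tokens:])
    # Assumption: patch tokens are laid out in row-major order.
    grid = [patches[r * width:(r + 1) * width] for r in range(height)]
    return prefix, grid


tokens = ["cls"] + [f"p{i}" for i in range(6)]
prefix, grid = split_tokens(tokens, token_res=(2, 3))
print(grid)  # [['p0', 'p1', 'p2'], ['p3', 'p4', 'p5']]
```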

VitUnionLatentCodecWithCtx

VitUnionLatentCodecWithCtx(h_dim=384, y_dim=256, z_dim=192, ctx_dim=256, groups=16, **kwargs)

ViT union latent codec conditioned on an external context feature map.

This codec jointly compresses ViT patch tokens and an additional context feature map. The context is injected into both the analysis and synthesis transforms as well as the hyperprior pathway.

Parameters:

Name Type Description Default
h_dim int

Channel dimension of ViT features.

384
y_dim int

Channel dimension of primary latent representation y.

256
z_dim int

Channel dimension of hyperprior latent representation z.

192
ctx_dim int

Channel dimension of the external context feature map.

256
groups int or list[int]

Channel groups for channel-wise context modeling.

16
**kwargs dict

Extra keyword arguments for compatibility (unused).

{}

compress

compress(h, ctx, token_res)

Compress ViT features conditioned on a context feature map.

Parameters:

Name Type Description Default
h Tensor

ViT output tensor of shape (B, L, C).

required
ctx Tensor

Context feature map of shape (B, ctx_dim, H, W).

required
token_res tuple[int, int]

Spatial token resolution (H, W).

required

Returns:

Name Type Description
out dict

A dictionary with keys:

  • "strings" (dict): Bitstreams for "y" and "z".
  • "pstate" (dict): Side information with shapes and token resolution.

decompress

decompress(strings, pstate, ctx, **kwargs)

Decompress context-conditioned bitstreams back to ViT features.

Parameters:

Name Type Description Default
strings dict

Bitstreams with keys "y" and "z".

required
pstate dict

Side information produced by compress.

required
ctx Tensor

Context feature map; the same context used for compression must also be supplied at decoding time.

required
**kwargs dict

Unused keyword arguments for API compatibility.

{}

Returns:

Name Type Description
out dict

A dictionary with keys:

  • "h_hat" (torch.Tensor): Reconstructed ViT features.
  • "h_hat_share" (torch.Tensor): Shared feature map before context decoding.

forward

forward(h, ctx, token_res)

Forward pass with context-conditioned hyperprior.

Parameters:

Name Type Description Default
h Tensor

ViT output tensor of shape (B, L, C).

required
ctx Tensor

Context feature map of shape (B, ctx_dim, H, W).

required
token_res tuple[int, int]

Spatial token resolution (H, W).

required

Returns:

Name Type Description
out dict

A dictionary with keys:

  • "h_hat" (torch.Tensor): Reconstructed ViT features.
  • "h_hat_share" (torch.Tensor): Shared feature map before context decoding.
  • "likelihoods" (dict): Likelihoods for "y" and "z".

VtmCodec

VtmCodec(repo_dir)

VTM (VVC Test Model) codec wrapper for video encoding and decoding.

This class provides an interface to the VTM encoder and decoder executables for compressing and decompressing video data.

Parameters:

Name Type Description Default
repo_dir str

Path to VTM repository directory containing bin/ and cfg/ folders.

required

compress

compress(raw_path, bin_path, width, height, qp: int, bitdepth: int = 8, chroma_format: str = '400')

Compress raw video file using VTM encoder.

Parameters:

Name Type Description Default
raw_path str

Path to input raw YUV file.

required
bin_path str

Path to output compressed bitstream file.

required
width int

Video width in pixels.

required
height int

Video height in pixels.

required
qp int

Quantization parameter (0-51; lower values give higher quality).

required
bitdepth int

Bit depth (8 or 10). Defaults to 8.

8
chroma_format str

Chroma format. Defaults to "400" (grayscale).

'400'
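
Because the encoder reads headerless raw YUV, width, height, bit depth and chroma format together determine how many bytes one frame occupies. The sketch below computes that size for standard planar layouts, assuming samples deeper than 8 bits are stored in two bytes each; `raw_frame_bytes` is a hypothetical helper and is not part of VtmCodec.

```python
# Horizontal and vertical chroma subsampling per format; None means no chroma.
CHROMA_SUBSAMPLING = {"400": None, "420": (2, 2), "422": (2, 1), "444": (1, 1)}


def raw_frame_bytes(
    width: int, height: int, bitdepth: int = 8, chroma_format: str = "400"
) -> int:
    """Size in bytes of one planar raw YUV frame."""
    if chroma_format not in CHROMA_SUBSAMPLING:
        raise ValueError(f"unsupported chroma format: {chroma_format!r}")
    # Assumption: samples above 8 bits are stored as 16-bit words.
    bytes_per_sample = 1 if bitdepth <= 8 else 2
    samples = width * height  # luma plane
    subsampling = CHROMA_SUBSAMPLING[chroma_format]
    if subsampling is not None:
        sub_x, sub_y = subsampling
        samples += 2 * (width // sub_x) * (height // sub_y)  # two chroma planes
    return samples * bytes_per_sample


print(raw_frame_bytes(1920, 1080))             # 2073600
print(raw_frame_bytes(1920, 1080, 10, "420"))  # 6220800
```

Comparing this value with the size of the file at raw_path is a quick sanity check that the declared width, height and bit depth match the data.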

decompress

decompress(bin_path, rec_path, bit_depth=8)

Decompress VTM bitstream to raw video file.

Parameters:

Name Type Description Default
bin_path str

Path to input compressed bitstream file.

required
rec_path str

Path to output reconstructed YUV file.

required
bit_depth int

Bit depth (8 or 10). Defaults to 8.

8

VtmFeatureCodec

VtmFeatureCodec(cfg)

VTM-based feature codec for compressing neural network features.

This codec applies truncation, quantization, packing, VTM encoding/decoding, and post-processing to compress features from various model types (llama3, dinov2, sd3).

Parameters:

Name Type Description Default
cfg dict

Configuration object containing:

  • vtm_path (str): Path to VTM repository directory.
  • trun_flag (bool): Whether to apply truncation.
  • trun_low (float or list[float]): Lower truncation bound(s).
  • trun_high (float or list[float]): Upper truncation bound(s).
  • bit_depth (int): Bit depth for quantization.
  • model_type (str): Model type ("llama3", "dinov2", or "sd3").
required
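
The truncation and quantization steps can be sketched as a clip to [trun_low, trun_high] followed by uniform quantization to 2**bit_depth levels, with the inverse mapping applied after decoding. The exact formula used by this codec is not given here, so the version below is an assumption; both functions are hypothetical and operate on plain lists rather than arrays.

```python
from typing import List


def quantize_features(
    values: List[float], trun_low: float, trun_high: float, bit_depth: int = 8
) -> List[int]:
    """Clip to [trun_low, trun_high], then map uniformly to integer codes."""
    levels = (1 << bit_depth) - 1
    span = trun_high - trun_low
    codes = []
    for v in values:
        clipped = min(max(v, trun_low), trun_high)
        codes.append(int(round((clipped - trun_low) / span * levels)))
    return codes


def dequantize_features(
    codes: List[int], trun_low: float, trun_high: float, bit_depth: int = 8
) -> List[float]:
    """Map integer codes back to the truncated floating-point range."""
    levels = (1 << bit_depth) - 1
    step = (trun_high - trun_low) / levels
    return [c * step + trun_low for c in codes]


codes = quantize_features([-30.0, -20.0, 10.0, 25.0], -20.0, 20.0)
restored = dequantize_features(codes, -20.0, 20.0)
print(codes)  # [0, 0, 191, 255]
```

Values outside the truncation bounds are lost to clipping, while values inside are recovered to within half a quantization step.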

compress

compress(org_feat, qp: int)

Compress features using VTM codec.

Expected feature shape: (N_crop, N_layer, H*W+1, C)

Parameters:

Name Type Description Default
org_feat ndarray

Original features to compress.

required
qp int

Quantization parameter for VTM encoding.

required

Returns:

Name Type Description
output dict

Dictionary containing:

  • "strings" (dict): Compressed bitstring with key "vtm".
  • "pstate" (dict): State information including:
    • "bin_path" (str): Path to compressed bitstream.
    • "pack_shape" (tuple): Shape of packed features.
    • "feat_shape" (tuple): Original feature shape.
    • "bit_depth" (int): Bit depth used.

decompress

decompress(strings, pstate, **kwargs)

Decompress features from VTM bitstream.

Note: model_type, bit_depth, trun_low, trun_high are fixed in self.cfg.

Parameters:

Name Type Description Default
strings dict

Dictionary with key "vtm" containing compressed bitstring.

required
pstate dict

State dictionary containing:

  • "bin_path" (str): Path to compressed bitstream.
  • "pack_shape" (tuple): Shape of packed features.
  • "feat_shape" (tuple): Target feature shape.
  • "bit_depth" (int): Bit depth used.
required
**kwargs dict

Additional keyword arguments (unused).

{}

Returns:

Name Type Description
decoded dict

Dictionary containing:

  • "h_hat" (numpy.ndarray): Decoded features with original shape.

forward_test

forward_test(org_feat, qp: int)

Forward test method for debugging (includes timing measurements).

This method performs a full encode-decode cycle and returns both the compressed representation and the decoded features, along with timing information.

Parameters:

Name Type Description Default
org_feat ndarray

Original features to compress.

required
qp int

Quantization parameter for VTM encoding.

required

Returns:

Name Type Description
coded_unit dict

Dictionary containing:

  • "strings" (dict): Compressed bitstring with key "vtm".
  • "pstate" (dict): State information including paths and shapes.
decoded dict

Dictionary containing:

  • "h_hat" (numpy.ndarray): Decoded features.