mpcompress.losses

MPC12Loss

MPC12Loss(rlmbda=1.0)

MPC12 loss function for rate-distortion optimization.

This loss function combines:

  • Bits per pixel (BPP) loss from likelihoods
  • VQGAN feature reconstruction loss
  • PATCH tokens reconstruction loss
  • CLS token reconstruction loss

The total loss is: rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss
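The combination above can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes the BPP term follows the common learned-compression convention (total negative log2-likelihood divided by the pixel count), and the helper names `bpp_from_likelihoods` and `total_mpc12_loss` are hypothetical.

```python
import math

import torch


def bpp_from_likelihoods(likelihoods, num_pixels):
    """Estimate bits per pixel as sum(-log2(p)) over all tensors / num_pixels.

    Assumption: this mirrors the usual rate estimate in learned compression;
    the actual MPC12Loss may compute it differently.
    """
    total_bits = sum(
        torch.log(p).sum() / (-math.log(2)) for p in likelihoods.values()
    )
    return total_bits / num_pixels


def total_mpc12_loss(bpp_loss, cls_token_loss, patch_tokens_loss,
                     h_vqgan_loss, rlmbda=1.0):
    """Weighted sum from the formula above; only the BPP term is scaled."""
    return rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss


# Toy example: 8 symbols with likelihood 0.5 cost 1 bit each.
likelihoods = {"y": torch.full((1, 2, 2, 2), 0.5)}
bpp = bpp_from_likelihoods(likelihoods, num_pixels=16)  # 8 bits / 16 pixels
```

Because only `bpp_loss` is multiplied by `rlmbda`, the three reconstruction terms always enter the total with unit weight.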

Parameters:

Name Type Description Default
rlmbda float

Rate-distortion trade-off parameter that weights the BPP term. Higher values penalize bitrate more strongly, favoring stronger compression. Defaults to 1.0.

1.0

forward

forward(output, x, x_shape=None, global_step=None)

Forward pass to compute MPC12 loss.

Parameters:

Name Type Description Default
output dict

Model output dictionary containing:

  • "likelihoods" (dict): Dictionary of likelihood tensors for BPP calculation
  • "h_vqgan_hat" (torch.Tensor): Reconstructed VQGAN features
  • "h_vqgan" (torch.Tensor): Target VQGAN features
  • "h_dino_hat" (torch.Tensor): Reconstructed DINO features [B, N+1, D]
  • "h_dino" (torch.Tensor): Target DINO features [B, N+1, D]
  • "monitor" (dict, optional): Additional monitoring metrics
required
x Tensor

Input tensor [B, C, H, W]. Used to infer shape if x_shape is None.

required
x_shape tuple

Shape of the input tensor (B, C, H, W). Required if x is None.

None
global_step int

Current training step, reserved for lambda scheduling.

None

Returns:

Name Type Description
loss Tensor

The total loss tensor

monitor dict

Dictionary of monitoring metrics, including the loss value and any additional metrics from output["monitor"] if present.

Raises:

Type Description
ValueError

If both x and x_shape are None.
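The interplay between x and x_shape described above can be sketched as below. The helper `resolve_num_pixels` is hypothetical, and the pixel count of B * H * W is an assumption based on the usual BPP normalization; only the precedence of x_shape over x and the ValueError come from the parameter descriptions.

```python
def resolve_num_pixels(x=None, x_shape=None):
    """Return the pixel count used to normalize the rate term.

    x_shape takes precedence; x is consulted only when x_shape is None.
    x may be any object exposing a 4-element .shape, such as a torch.Tensor.
    """
    if x_shape is None:
        if x is None:
            raise ValueError("Either x or x_shape must be provided.")
        x_shape = tuple(x.shape)
    batch, _channels, height, width = x_shape
    # Assumption: channels are excluded from the pixel count.
    return batch * height * width


num_pixels = resolve_num_pixels(x_shape=(2, 3, 4, 5))  # 2 * 4 * 5
```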

get_rlmbda

get_rlmbda(global_step=None)

Get the rate-distortion trade-off parameter.

Parameters:

Name Type Description Default
global_step int

Current training step. Currently unused; reserved for scheduled lambda values.

None

Returns:

Name Type Description
rlmbda float

The rate-distortion trade-off parameter.
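Since global_step is currently unused, any schedule is speculative. As one possible shape for a future schedule, the sketch below ramps the weight linearly during a warmup phase. The function `scheduled_rlmbda` and its `warmup_steps` parameter are hypothetical and are not part of mpcompress.

```python
def scheduled_rlmbda(base_rlmbda=1.0, global_step=None, warmup_steps=1000):
    """Hypothetical linear warmup of the rate weight.

    Returns base_rlmbda unchanged when no step is given, matching the
    documented behavior of ignoring global_step.
    """
    if global_step is None or warmup_steps <= 0:
        return base_rlmbda
    # Fraction of the warmup completed, clamped to 1.0 once warmup ends.
    progress = min(global_step / warmup_steps, 1.0)
    return base_rlmbda * progress


halfway = scheduled_rlmbda(base_rlmbda=2.0, global_step=500, warmup_steps=1000)
```

With this schedule the rate term is weak early in training and reaches its full weight once `global_step` passes `warmup_steps`.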

MPC2Loss

MPC2Loss(rlmbda=1.0)

MPC2 loss function for rate-distortion optimization.

This loss function combines:

  • Bits per pixel (BPP) loss from likelihoods
  • PATCH tokens reconstruction loss
  • CLS token reconstruction loss

The total loss is: rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss
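To illustrate how the CLS and patch terms might be derived from DINO features of shape [B, N+1, D], here is a minimal sketch. It assumes the CLS token sits at index 0 of the token dimension and that both terms use mean squared error; neither detail is confirmed by this page, and `dino_token_losses` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def dino_token_losses(h_dino_hat, h_dino):
    """Split [B, N+1, D] features into CLS and patch parts and compare each.

    Assumptions: CLS token at index 0, MSE as the reconstruction metric.
    """
    cls_token_loss = F.mse_loss(h_dino_hat[:, 0], h_dino[:, 0])
    patch_tokens_loss = F.mse_loss(h_dino_hat[:, 1:], h_dino[:, 1:])
    return cls_token_loss, patch_tokens_loss


# Toy features: B=2, N=4 patch tokens plus one CLS token, D=8.
h_dino = torch.zeros(2, 5, 8)
h_dino_hat = h_dino.clone()
h_dino_hat[:, 0] += 1.0  # perturb only the CLS token

cls_loss, patch_loss = dino_token_losses(h_dino_hat, h_dino)
bpp_loss = torch.tensor(0.5)  # placeholder rate term
total = 1.0 * bpp_loss + cls_loss + patch_loss  # rlmbda = 1.0
```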

Parameters:

Name Type Description Default
rlmbda float

Rate-distortion trade-off parameter that weights the BPP term. Higher values penalize bitrate more strongly, favoring stronger compression. Defaults to 1.0.

1.0

forward

forward(output, x, x_shape=None, global_step=None)

Forward pass to compute MPC2 loss.

Parameters:

Name Type Description Default
output dict

Model output dictionary containing:

  • "likelihoods" (dict): Dictionary of likelihood tensors for BPP calculation
  • "h_dino_hat" (torch.Tensor): Reconstructed DINO features [B, N+1, D]
  • "h_dino" (torch.Tensor): Target DINO features [B, N+1, D]
  • "monitor" (dict, optional): Additional monitoring metrics
required
x Tensor

Input tensor [B, C, H, W]. Used to infer shape if x_shape is None.

required
x_shape tuple

Shape of the input tensor (B, C, H, W). Required if x is None.

None
global_step int

Current training step, reserved for lambda scheduling.

None

Returns:

Name Type Description
loss Tensor

The total loss tensor

monitor dict

Dictionary of monitoring metrics, including the loss value and any additional metrics from output["monitor"] if present.

Raises:

Type Description
ValueError

If both x and x_shape are None.

get_rlmbda

get_rlmbda(global_step=None)

Get the rate-distortion trade-off parameter.

Parameters:

Name Type Description Default
global_step int

Current training step. Currently unused; reserved for scheduled lambda values.

None

Returns:

Name Type Description
rlmbda float

The rate-distortion trade-off parameter.

SimpleLoss

SimpleLoss(**kwargs)

Simple loss function that extracts loss directly from model output.

It is a thin wrapper for models that compute their own loss: it reads the loss value from the model output dictionary and returns it together with the monitoring metrics.

Parameters:

Name Type Description Default
**kwargs dict

Additional keyword arguments (currently unused).

{}

forward

forward(output, x)

Forward pass to compute loss.

Parameters:

Name Type Description Default
output dict

Model output dictionary containing:

  • "loss" (torch.Tensor): Pre-computed loss tensor
  • "monitor" (dict, optional): Additional monitoring metrics
required
x Tensor

Input tensor (not used, kept for interface compatibility)

required

Returns:

Name Type Description
loss Tensor

The loss tensor

monitor dict

Dictionary of monitoring metrics, including the loss value and any additional metrics from output["monitor"] if present.
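The pass-through behavior described for SimpleLoss can be approximated with a small stand-in module. This is a sketch under stated assumptions, not the mpcompress source: the class name `SimpleLossSketch` is hypothetical, and storing the loss under the `"loss"` key of the monitor dictionary is a guess at the key name.

```python
import torch
import torch.nn as nn


class SimpleLossSketch(nn.Module):
    """Stand-in that forwards a pre-computed loss from the model output."""

    def __init__(self, **kwargs):
        super().__init__()  # kwargs accepted but unused, as documented

    def forward(self, output, x):
        loss = output["loss"]
        # Assumption: the scalar loss is logged under the key "loss".
        monitor = {"loss": loss.item()}
        # Merge any extra metrics the model supplied.
        monitor.update(output.get("monitor", {}))
        return loss, monitor


criterion = SimpleLossSketch()
output = {"loss": torch.tensor(0.25), "monitor": {"psnr": 30.0}}
loss, monitor = criterion(output, x=None)  # x is ignored
```

Keeping the unused `x` argument lets the same training loop call this criterion and the MPC losses interchangeably.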