
mpcompress.losses

BalancedBinaryCrossEntropyLoss

BalancedBinaryCrossEntropyLoss(pos_weight=None, ignore_index=255)

Balanced binary cross entropy loss with ignore regions.

Used for edge detection and binary segmentation tasks.

Parameters:

  pos_weight (default: None)
      Weight for positive samples. If None, it is computed from the data.
  ignore_index (default: 255)
      Label index to ignore.
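A minimal, framework-free sketch of the balancing idea follows. It assumes the HED-style scheme, which this page does not confirm: when pos_weight is None, the positive weight w is the fraction of negative pixels among the non-ignored ones, and negatives get 1 - w. It works on probabilities in plain Python lists, whereas the library class operates on tensors; balanced_bce is a hypothetical name.

```python
import math


def balanced_bce(probs, labels, pos_weight=None, ignore_index=255, eps=1e-12):
    """Hypothetical sketch of a balanced BCE over flat lists of pixels.

    probs: predicted foreground probabilities in (0, 1).
    labels: 0/1 targets, or ignore_index for pixels to skip.
    """
    # Keep only pixels whose label is not the ignore value.
    valid = [(p, y) for p, y in zip(probs, labels) if y != ignore_index]
    if not valid:
        return 0.0
    if pos_weight is None:
        # Assumed data-driven weight: fraction of negatives among valid pixels.
        num_neg = sum(1 for _, y in valid if y == 0)
        pos_weight = num_neg / len(valid)
    total = 0.0
    for p, y in valid:
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        if y == 1:
            total += -pos_weight * math.log(p)
        else:
            total += -(1.0 - pos_weight) * math.log(1.0 - p)
    return total / len(valid)


# One positive, one negative, one ignored pixel, so w = 0.5.
loss = balanced_bce([0.9, 0.1, 0.5], [1, 0, 255])
```

With a sparse edge map, w approaches 1, so the few positive pixels are not swamped by the many negatives.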

CrossEntropyLoss

CrossEntropyLoss(ignore_index=255, class_weight=None, balanced=False)

Cross entropy loss with ignore regions.

Parameters:

  ignore_index (default: 255)
      Label index to ignore.
  class_weight (default: None)
      Optional per-class weights for balancing.
  balanced (default: False)
      Whether to use balanced weighting.
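The sketch below shows one plausible reduction for the ignore_index and class_weight parameters, modeled on PyTorch's torch.nn.functional.cross_entropy, whose mean reduction divides by the summed weights of the non-ignored targets. The balanced option is not modeled because this page does not describe it; cross_entropy here is a hypothetical plain-Python helper, not the library's code.

```python
import math


def cross_entropy(logits, labels, ignore_index=255, class_weight=None):
    """Hypothetical sketch: weighted softmax cross entropy over flat pixels.

    logits: one list of per-class scores per pixel.
    labels: class index per pixel, or ignore_index to skip the pixel.
    """
    num = 0.0
    den = 0.0
    for scores, y in zip(logits, labels):
        if y == ignore_index:
            continue  # ignored pixels contribute neither loss nor weight
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        nll = log_sum_exp - scores[y]  # equals -log(softmax(scores)[y])
        w = 1.0 if class_weight is None else class_weight[y]
        num += w * nll
        den += w
    return num / den if den > 0 else 0.0
```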

L1Loss

L1Loss(normalize=False, ignore_index=0, ignore_invalid_area=True)

L1 loss with ignore regions.

Used for surface normal estimation and depth estimation.

Parameters:

  normalize (default: False)
      Whether to normalize predictions (used for surface normals).
  ignore_index (default: 0)
      Label value to ignore.
  ignore_invalid_area (default: True)
      Whether to ignore invalid regions.
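A plain-Python sketch of the masking and normalization, under two assumptions this page does not confirm: a pixel is invalid when every label channel equals ignore_index (for example zero depth or an all-zero normal), and normalize rescales each predicted vector to unit length before the L1 distance is taken. ignore_invalid_area is not modeled separately; masked_l1 is a hypothetical name.

```python
import math


def masked_l1(preds, labels, normalize=False, ignore_index=0):
    """Hypothetical sketch of a masked L1 loss over per-pixel vectors.

    preds, labels: one list of channel values per pixel (3 for normals, 1 for depth).
    A pixel is treated as invalid when every label channel equals ignore_index.
    """
    total = 0.0
    count = 0
    for p, y in zip(preds, labels):
        if all(v == ignore_index for v in y):
            continue  # skip invalid pixels (e.g. missing depth or normals)
        if normalize:
            norm = math.sqrt(sum(v * v for v in p)) or 1.0  # guard against zero vectors
            p = [v / norm for v in p]  # rescale the prediction to unit length
        total += sum(abs(a - b) for a, b in zip(p, y))
        count += len(y)
    # Mean absolute error over the channels of valid pixels.
    return total / count if count else 0.0
```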

MLoRECodingLoss

MLoRECodingLoss(p, tasks, loss_weights=None)

Combined loss for MLoRE coding training.

Combines multi-task losses with compression losses (bpp and mse).

Parameters:

  p (required)
      Configuration dict.
  tasks (required)
      List of task names.
  loss_weights (default: None)
      Dict of loss weights, including:
        • Task weights (e.g., 'semseg': 1.0)
        • 'bpp_loss': weight for the bits-per-pixel loss
        • 'mse_loss': weight for the reconstruction MSE loss
        • 'load_balancing': weight for the routing CV loss

forward

forward(pred, gt, tasks=None)

Compute combined coding loss.

Parameters:

  pred (required)
      Dict of per-task predictions plus 'bpp_loss' and 'mse_loss' entries.
  gt (required)
      Dict of ground truth per task.
  tasks (default: None)
      List of tasks to compute the loss for.

Returns:

  Dict with the 'total' loss and all component losses.
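The combination described above can be sketched with scalars standing in for tensors. The task weights and the 'bpp_loss' and 'mse_loss' weights come from loss_weights as documented; the routing CV loss is assumed to be the squared coefficient of variation of per-expert load, a penalty commonly used for mixture-of-experts routing, but this page does not specify its form. coding_total, cv_squared, and the expert_load argument are hypothetical.

```python
def cv_squared(values, eps=1e-10):
    """Squared coefficient of variation: unbiased variance divided by squared mean."""
    if len(values) < 2:
        return 0.0
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / (len(values) - 1)
    return var / (mean ** 2 + eps)


def coding_total(task_losses, pred, loss_weights, expert_load=None):
    """Hypothetical sketch: weighted task losses plus compression terms.

    task_losses: {task: scalar loss}; pred must hold 'bpp_loss' and 'mse_loss'.
    expert_load: per-expert routing load (hypothetical input) for the CV term.
    """
    out = dict(task_losses)
    total = sum(loss_weights[t] * v for t, v in task_losses.items())
    # Compression terms, each scaled by its own weight.
    for key in ("bpp_loss", "mse_loss"):
        out[key] = pred[key]
        total += loss_weights[key] * pred[key]
    # Optional load-balancing penalty on the router.
    if expert_load is not None:
        out["load_balancing"] = cv_squared(expert_load)
        total += loss_weights["load_balancing"] * out["load_balancing"]
    out["total"] = total
    return out
```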

MPC12Loss

MPC12Loss(rlmbda=1.0)

MPC12 loss function for rate-distortion optimization.

This loss function combines:

  • Bits per pixel (BPP) loss from likelihoods
  • VQGAN feature reconstruction loss
  • Patch-token reconstruction loss
  • CLS-token reconstruction loss

The total loss is: rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss

Parameters:

  rlmbda (float, default: 1.0)
      Rate-distortion trade-off parameter. Higher values weight the rate term more heavily, favoring lower bitrates.

forward

forward(output, x, x_shape=None, global_step=None)

Forward pass to compute MPC12 loss.

Parameters:

  output (dict, required)
      Model output dictionary containing:
        • "likelihoods" (dict): dictionary of likelihood tensors for the BPP calculation
        • "h_vqgan_hat" (torch.Tensor): reconstructed VQGAN features
        • "h_vqgan" (torch.Tensor): target VQGAN features
        • "h_dino_hat" (torch.Tensor): reconstructed DINO features [B, N+1, D]
        • "h_dino" (torch.Tensor): target DINO features [B, N+1, D]
        • "monitor" (dict, optional): additional monitoring metrics
  x (Tensor, required)
      Input tensor [B, C, H, W]. Used to infer the shape if x_shape is None.
  x_shape (tuple, default: None)
      Shape of the input tensor (B, C, H, W). Required if x is None.
  global_step (int, default: None)
      Current training step, used for lambda scheduling.

Returns:

  loss (Tensor)
      The total loss tensor.
  monitor (dict)
      Dictionary of monitoring metrics, including the loss value and any additional metrics from output["monitor"] if present.

Raises:

  ValueError
      If both x and x_shape are None.
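The rate term and the documented total can be sketched as follows. The BPP formula (sum of -log2 likelihoods over every latent element, divided by the B * H * W input pixels) is the usual convention in learned image compression, for example in CompressAI's training examples, and is assumed rather than confirmed here; the shape fallback mirrors the documented ValueError. Likelihoods are plain lists, the reconstruction terms are passed in as precomputed numbers, and all function names are hypothetical.

```python
import math


def resolve_shape(x=None, x_shape=None):
    """Return (B, C, H, W), preferring x_shape and falling back to x.shape."""
    if x_shape is None:
        if x is None:
            raise ValueError("Either x or x_shape must be provided")
        x_shape = tuple(x.shape)
    return x_shape


def bpp_from_likelihoods(likelihoods, x_shape):
    """Bits per pixel: total -log2(likelihood) over all latents / (B * H * W)."""
    b, _, h, w = x_shape
    total_bits = sum(-math.log2(p) for probs in likelihoods.values() for p in probs)
    return total_bits / (b * h * w)


def mpc12_total(rlmbda, bpp_loss, cls_token_loss, patch_tokens_loss, h_vqgan_loss):
    """Total loss as documented: rlmbda * bpp_loss + CLS + patch + VQGAN terms."""
    return rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss


shape = resolve_shape(x_shape=(1, 3, 2, 2))
bpp = bpp_from_likelihoods({"y": [0.5, 0.25], "z": [0.5]}, shape)  # 4 bits over 4 pixels
```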

get_rlmbda

get_rlmbda(global_step=None)

Get the rate-distortion trade-off parameter.

Parameters:

  global_step (int, default: None)
      Current training step (currently unused; reserved for scheduled lambda values in the future).

Returns:

  rlmbda (float)
      The rate-distortion trade-off parameter.

MPC2Loss

MPC2Loss(rlmbda=1.0)

MPC2 loss function for rate-distortion optimization.

This loss function combines:

  • Bits per pixel (BPP) loss from likelihoods
  • Patch-token reconstruction loss
  • CLS-token reconstruction loss

The total loss is: rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss

Parameters:

  rlmbda (float, default: 1.0)
      Rate-distortion trade-off parameter. Higher values weight the rate term more heavily, favoring lower bitrates.

forward

forward(output, x, x_shape=None, global_step=None)

Forward pass to compute MPC2 loss.

Parameters:

  output (dict, required)
      Model output dictionary containing:
        • "likelihoods" (dict): dictionary of likelihood tensors for the BPP calculation
        • "h_dino_hat" (torch.Tensor): reconstructed DINO features [B, N+1, D]
        • "h_dino" (torch.Tensor): target DINO features [B, N+1, D]
        • "monitor" (dict, optional): additional monitoring metrics
  x (Tensor, required)
      Input tensor [B, C, H, W]. Used to infer the shape if x_shape is None.
  x_shape (tuple, default: None)
      Shape of the input tensor (B, C, H, W). Required if x is None.
  global_step (int, default: None)
      Current training step, used for lambda scheduling.

Returns:

  loss (Tensor)
      The total loss tensor.
  monitor (dict)
      Dictionary of monitoring metrics, including the loss value and any additional metrics from output["monitor"] if present.

Raises:

  ValueError
      If both x and x_shape are None.
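The [B, N+1, D] DINO shape suggests one CLS token followed by N patch tokens, in the usual ViT layout. The sketch below splits the features that way and, as an assumption since the distance is not documented, uses mean squared error for both terms before applying the documented total. Features are nested lists [B][N+1][D]; the function names are hypothetical.

```python
def mse(pairs):
    """Mean squared error over a list of (prediction_vector, target_vector) pairs."""
    diffs = [(a - b) ** 2 for u, v in pairs for a, b in zip(u, v)]
    return sum(diffs) / len(diffs)


def split_token_losses(h_dino_hat, h_dino):
    """Split [B][N+1][D] features into CLS (index 0) and patch (1..N) MSE terms."""
    cls_pairs = [(hat[0], ref[0]) for hat, ref in zip(h_dino_hat, h_dino)]
    patch_pairs = [
        (t_hat, t_ref)
        for hat, ref in zip(h_dino_hat, h_dino)
        for t_hat, t_ref in zip(hat[1:], ref[1:])
    ]
    return mse(cls_pairs), mse(patch_pairs)


def mpc2_total(rlmbda, bpp_loss, cls_token_loss, patch_tokens_loss):
    """Total loss as documented: rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss."""
    return rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss


# B = 1, N = 2 patch tokens, D = 2.
h_dino = [[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]]
h_dino_hat = [[[2.0, 0.0], [1.0, 1.0], [2.0, 4.0]]]
cls_loss, patch_loss = split_token_losses(h_dino_hat, h_dino)
```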

get_rlmbda

get_rlmbda(global_step=None)

Get the rate-distortion trade-off parameter.

Parameters:

  global_step (int, default: None)
      Current training step (currently unused; reserved for scheduled lambda values in the future).

Returns:

  rlmbda (float)
      The rate-distortion trade-off parameter.

MultiTaskLoss

MultiTaskLoss(p, tasks, loss_ft, loss_weights)

Multi-task loss aggregation.

Combines losses from multiple tasks with configurable weights, plus optional load balancing loss for routing mechanisms.

Parameters:

  p (required)
      Configuration dict.
  tasks (required)
      List of task names.
  loss_ft (required)
      ModuleDict mapping task names to loss functions.
  loss_weights (required)
      Dict mapping task names to loss weights.

forward

forward(pred, gt, tasks=None)

Compute multi-task loss.

Parameters:

  pred (required)
      Dict of predictions per task.
  gt (required)
      Dict of ground truth per task.
  tasks (default: None)
      List of tasks to compute the loss for (None means all tasks).

Returns:

  Dict with the 'total' loss and per-task losses.
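A minimal sketch of the aggregation, with plain callables in place of the ModuleDict entries and scalars in place of tensors. It omits the optional load-balancing term, whose inputs are not documented for this class; multi_task_loss is a hypothetical name.

```python
def multi_task_loss(pred, gt, loss_ft, loss_weights, tasks=None):
    """Hypothetical sketch: weighted sum of per-task losses.

    loss_ft maps task names to callables (a ModuleDict in the library);
    tasks=None means every task in loss_ft.
    """
    tasks = list(loss_ft) if tasks is None else list(tasks)
    # Per-task losses, keyed by task name.
    out = {t: loss_ft[t](pred[t], gt[t]) for t in tasks}
    # Weighted sum over the selected tasks only.
    out["total"] = sum(loss_weights[t] * out[t] for t in tasks)
    return out


loss_ft = {"depth": lambda p, g: abs(p - g), "semseg": lambda p, g: (p - g) ** 2}
pred = {"depth": 3.0, "semseg": 1.0}
gt = {"depth": 1.0, "semseg": 0.0}
weights = {"depth": 0.5, "semseg": 2.0}
out = multi_task_loss(pred, gt, loss_ft, weights)
```

Passing tasks restricts both the per-task entries and the total to that subset.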

SiLogLoss

SiLogLoss(lambd=0.5)

Scale-invariant logarithmic loss for depth estimation.

Parameters:

  lambd (default: 0.5)
      Balance factor.
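This page does not give the formula. A common form is the scale-invariant log error of Eigen et al. (2014), often implemented with a square root: with d = log(pred) - log(target) over valid pixels, the loss is sqrt(mean(d^2) - lambd * mean(d)^2). Under that assumption, lambd = 1 ignores a global scale error entirely and lambd = 0 gives the RMSE of log depths. The sketch treats targets <= 0 as invalid and adds a small eps inside the logarithms; silog_loss is a hypothetical name.

```python
import math


def silog_loss(pred, target, lambd=0.5, eps=1e-8):
    """Sketch of a scale-invariant log loss over flat lists of depths.

    Pixels with non-positive target depth are treated as invalid and skipped.
    """
    d = [math.log(p + eps) - math.log(t + eps) for p, t in zip(pred, target) if t > 0]
    if not d:
        return 0.0
    mean_sq = sum(x * x for x in d) / len(d)
    sq_mean = (sum(d) / len(d)) ** 2
    # max() guards against tiny negative values from floating-point rounding.
    return math.sqrt(max(mean_sq - lambd * sq_mean, 0.0))
```

A prediction that is off by a constant factor everywhere yields a constant d, so with lambd = 1 the loss is (numerically) zero, while the default lambd = 0.5 still penalizes part of the scale error.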

SimpleLoss

SimpleLoss(**kwargs)

Simple loss function that extracts loss directly from model output.

This loss function is a wrapper that simply retrieves the loss value from the model output dictionary and returns it along with monitoring metrics.

Parameters:

  **kwargs (dict, default: {})
      Additional keyword arguments (currently unused).

forward

forward(output, x)

Forward pass to compute loss.

Parameters:

  output (dict, required)
      Model output dictionary containing:
        • "loss" (torch.Tensor): pre-computed loss tensor
        • "monitor" (dict, optional): additional monitoring metrics
  x (Tensor, required)
      Input tensor (not used; kept for interface compatibility).

Returns:

  loss (Tensor)
      The loss tensor.
  monitor (dict)
      Dictionary of monitoring metrics, including the loss value and any additional metrics from output["monitor"] if present.
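The pass-through behavior can be sketched in a few lines. The key under which the loss value is stored in monitor ('loss') is an assumption, since the page only says the dictionary includes the loss value; simple_loss_forward is a hypothetical stand-in for SimpleLoss.forward.

```python
def simple_loss_forward(output, x=None):
    """Hypothetical sketch: return the pre-computed loss plus monitoring metrics.

    x is accepted only for interface compatibility and is ignored.
    """
    loss = output["loss"]
    monitor = {"loss": float(loss)}  # assumed key name for the loss value
    monitor.update(output.get("monitor", {}))  # merge extra metrics if present
    return loss, monitor


loss, monitor = simple_loss_forward({"loss": 0.25, "monitor": {"psnr": 30.0}}, None)
```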

get_task_loss

get_task_loss(task, p=None, **kwargs)

Factory function that returns the loss function for a specific task.

Parameters:

  task (required)
      Task name.
  p (default: None)
      Configuration dict.
  **kwargs (default: {})
      Additional arguments for loss construction.

Returns:

  Loss module.