mpcompress.losses
BalancedBinaryCrossEntropyLoss
BalancedBinaryCrossEntropyLoss(pos_weight=None, ignore_index=255)
Balanced binary cross entropy loss with ignore regions.
Used for edge detection and binary segmentation tasks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pos_weight` | | Weight for positive samples; if `None`, it is computed from the data | `None` |
| `ignore_index` | | Label index to ignore | `255` |
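A minimal sketch of how balanced binary cross-entropy with an ignore label is commonly computed. It assumes HED-style weighting: when `pos_weight` is `None`, the weight `w` is the fraction of negative pixels among the non-ignored ones, so the (usually rare) positive pixels get weight `w` and negatives get `1 - w`. It works on flat Python lists of logits and labels rather than tensors. The function name `balanced_bce_with_logits` and the mean reduction are assumptions, not this library's API.

```python
import math
from typing import Optional, Sequence


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def balanced_bce_with_logits(
    logits: Sequence[float],
    labels: Sequence[float],
    pos_weight: Optional[float] = None,
    ignore_index: int = 255,
) -> float:
    """Hypothetical sketch: HED-style balanced BCE, averaged over valid pixels."""
    # Drop pixels whose label equals ignore_index.
    valid = [(z, y) for z, y in zip(logits, labels) if y != ignore_index]
    if not valid:
        return 0.0
    # Assumed default weighting: w = fraction of negatives among valid pixels.
    if pos_weight is None:
        w = sum(1 for _, y in valid if y == 0) / len(valid)
    else:
        w = pos_weight
    total = 0.0
    for z, y in valid:
        log_p = log_sigmoid(z)  # log(sigmoid(z))
        log_1mp = log_p - z     # log(1 - sigmoid(z)) == log(sigmoid(-z))
        total -= w * y * log_p + (1.0 - w) * (1.0 - y) * log_1mp
    return total / len(valid)


# One positive, three negatives, one ignored pixel (label 255).
print(balanced_bce_with_logits([0.0, 0.0, 0.0, 0.0, 3.0], [1, 0, 0, 0, 255]))
```

With three negatives and one positive, `w = 0.75`, so the single positive pixel carries as much total weight as the three negatives combined.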
CrossEntropyLoss
CrossEntropyLoss(ignore_index=255, class_weight=None, balanced=False)
Cross entropy loss with ignore regions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ignore_index` | | Label index to ignore | `255` |
| `class_weight` | | Optional per-class weights for balancing | `None` |
| `balanced` | | Whether to use balanced weighting | `False` |
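The sketch below shows per-pixel cross entropy that skips `ignore_index` and applies optional class weights. It follows the weighted-mean convention of PyTorch's `torch.nn.functional.cross_entropy` with `reduction='mean'` (weighted sum divided by the sum of the weights of the non-ignored targets). It operates on plain lists, the function name is hypothetical, and returning `0.0` when every pixel is ignored is a choice of this sketch. How the `balanced` flag derives its weights is not documented here, so it is left out.

```python
import math
from typing import Optional, Sequence


def cross_entropy_ignore(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    ignore_index: int = 255,
    class_weight: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical sketch: weighted mean cross entropy over non-ignored pixels."""
    num, den = 0.0, 0.0
    for scores, y in zip(logits, labels):
        if y == ignore_index:
            continue  # ignored pixels contribute neither loss nor weight
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        weight = 1.0 if class_weight is None else class_weight[y]
        num += weight * (log_sum_exp - scores[y])  # weight * -log softmax[y]
        den += weight
    return num / den if den > 0 else 0.0


# Two 2-class pixels plus one ignored pixel; class 1 is up-weighted 3x.
print(cross_entropy_ignore([[2.0, 0.0], [0.0, 0.0], [5.0, 1.0]], [0, 1, 255], class_weight=[1.0, 3.0]))
```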
L1Loss
L1Loss(normalize=False, ignore_index=0, ignore_invalid_area=True)
L1 loss with ignore regions.
Used for surface normal estimation and depth estimation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `normalize` | | Whether to normalize predictions (for surface normals) | `False` |
| `ignore_index` | | Label value to ignore | `0` |
| `ignore_invalid_area` | | Whether to ignore invalid regions | `True` |
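A sketch of masked L1 loss over per-pixel vectors. It assumes that a pixel is invalid when every channel of its label equals `ignore_index` (for example, an all-zero normal or a zero depth), that `normalize=True` rescales each predicted vector to unit length, and that the summed absolute error is divided by the number of valid pixels. The function name `masked_l1` and these conventions are assumptions.

```python
import math
from typing import Sequence


def masked_l1(
    pred: Sequence[Sequence[float]],
    label: Sequence[Sequence[float]],
    normalize: bool = False,
    ignore_index: float = 0,
    ignore_invalid_area: bool = True,
) -> float:
    """Hypothetical sketch: L1 loss averaged over valid pixels."""
    total, n_valid = 0.0, 0
    for p, t in zip(pred, label):
        # Assumed validity rule: skip pixels whose label channels all equal ignore_index.
        if ignore_invalid_area and all(c == ignore_index for c in t):
            continue
        if normalize:
            # Rescale the prediction to unit length, e.g. for surface normals.
            norm = math.sqrt(sum(c * c for c in p)) or 1.0
            p = [c / norm for c in p]
        total += sum(abs(a - b) for a, b in zip(p, t))
        n_valid += 1
    return total / max(n_valid, 1)


# The second pixel has an all-zero label, so it is ignored.
print(masked_l1([[2.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], normalize=True))
```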
MLoRECodingLoss
MLoRECodingLoss(p, tasks, loss_weights=None)
Combined loss for MLoRE coding training.
Combines the multi-task losses with the compression losses: the bits-per-pixel (bpp) rate loss and the MSE reconstruction loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `p` | | Configuration dict | required |
| `tasks` | | List of task names | required |
| `loss_weights` | | Dict of loss weights, including: task weights (e.g., `'semseg': 1.0`); `'bpp_loss'`, the weight for the bits-per-pixel loss; `'mse_loss'`, the weight for the reconstruction MSE loss; `'load_balancing'`, the weight for the routing CV loss | `None` |
forward
forward(pred, gt, tasks=None)
Compute combined coding loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pred` | | Dict with the task predictions plus `'bpp_loss'` and `'mse_loss'` | required |
| `gt` | | Dict of ground truth per task | required |
| `tasks` | | List of tasks to compute the loss for | `None` |
Returns:
| Type | Description |
|---|---|
| `dict` | Dict with the `'total'` loss and all component losses |
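A sketch of how the combined coding loss could be assembled from the documented pieces: per-task losses scaled by their task weights, plus `bpp_loss` and `mse_loss` read from `pred` and scaled by `loss_weights['bpp_loss']` and `loss_weights['mse_loss']`. Passing the per-task losses in precomputed, and reading the routing loss from a hypothetical `pred['cv_loss']` key scaled by `loss_weights['load_balancing']`, are assumptions made for illustration.

```python
from typing import Dict, Mapping


def mlore_coding_total(
    task_losses: Mapping[str, float],
    pred: Mapping[str, float],
    loss_weights: Mapping[str, float],
) -> Dict[str, float]:
    """Hypothetical sketch: weighted sum of task, rate, and reconstruction losses."""
    out: Dict[str, float] = dict(task_losses)
    total = sum(loss_weights[t] * v for t, v in task_losses.items())
    # Compression terms that pred is documented to carry.
    for key in ("bpp_loss", "mse_loss"):
        out[key] = pred[key]
        total += loss_weights[key] * pred[key]
    # Hypothetical routing (load-balancing) term, added only when present.
    if "cv_loss" in pred and "load_balancing" in loss_weights:
        out["load_balancing"] = pred["cv_loss"]
        total += loss_weights["load_balancing"] * pred["cv_loss"]
    out["total"] = total
    return out


weights = {"semseg": 1.0, "depth": 2.0, "bpp_loss": 4.0, "mse_loss": 10.0}
print(mlore_coding_total({"semseg": 2.0, "depth": 1.0}, {"bpp_loss": 0.5, "mse_loss": 0.1}, weights))
```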
MPC12Loss
MPC12Loss(rlmbda=1.0)
MPC12 loss function for rate-distortion optimization.
This loss function combines:
- Bits-per-pixel (BPP) rate loss computed from the likelihoods
- VQGAN feature reconstruction loss
- Patch-token reconstruction loss
- CLS-token reconstruction loss
The total loss is `rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss`.
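As a worked sketch of this formula, the rate term can be computed CompressAI-style as the total information content of the latents in bits, `-sum(log2(likelihood))`, divided by the number of pixels `N * H * W`. The function name, its inputs (likelihoods as plain lists per latent, the three reconstruction terms as precomputed floats), and this bpp definition are assumptions; the library may compute the rate differently.

```python
import math
from typing import Mapping, Sequence, Tuple


def mpc12_total_loss(
    likelihoods: Mapping[str, Sequence[float]],
    x_shape: Tuple[int, int, int, int],
    cls_token_loss: float,
    patch_tokens_loss: float,
    h_vqgan_loss: float,
    rlmbda: float = 1.0,
) -> Tuple[float, float]:
    """Hypothetical sketch: returns (total_loss, bpp_loss)."""
    n, _, h, w = x_shape
    num_pixels = n * h * w
    # Bits needed to code every latent symbol, spread over the image pixels.
    bits = sum(-math.log2(p) for values in likelihoods.values() for p in values)
    bpp_loss = bits / num_pixels
    total = rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss
    return total, bpp_loss


# 1 + 1 + 2 + 2 = 6 bits over 1 * 2 * 3 = 6 pixels gives bpp_loss = 1.0.
total, bpp = mpc12_total_loss({"y": [0.5, 0.5, 0.25, 0.25]}, (1, 3, 2, 3), 0.1, 0.2, 0.3, rlmbda=2.0)
print(total, bpp)
```

Raising `rlmbda` scales only the rate term, which is why larger values push training toward lower bitrates.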
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rlmbda` | `float` | Rate-distortion trade-off parameter. Higher values weight the bpp (rate) term more heavily, favoring a lower bitrate. | `1.0` |
forward
forward(output, x, x_shape=None, global_step=None)
Forward pass to compute MPC12 loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `output` | `dict` | Model output dictionary containing: | required |
| `x` | `Tensor` | Input tensor of shape `[N, C, H, W]`, used to infer the shape if `x_shape` is `None`; may be `None` if `x_shape` is given | required |
| `x_shape` | `tuple` | Shape of the input tensor, `(N, C, H, W)`; required if `x` is `None` | `None` |
| `global_step` | `int` | Current training step, used for lambda scheduling | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `loss` | `Tensor` | The total loss tensor |
| `monitor` | `dict` | Dictionary of monitoring metrics, including the loss value and any additional metrics from `output["monitor"]` if present |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If both `x` and `x_shape` are `None`. |
get_rlmbda
get_rlmbda(global_step=None)
MPC2Loss
MPC2Loss(rlmbda=1.0)
MPC2 loss function for rate-distortion optimization.
This loss function combines:
- Bits-per-pixel (BPP) rate loss computed from the likelihoods
- Patch-token reconstruction loss
- CLS-token reconstruction loss
The total loss is `rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rlmbda` | `float` | Rate-distortion trade-off parameter. Higher values weight the bpp (rate) term more heavily, favoring a lower bitrate. | `1.0` |
forward
forward(output, x, x_shape=None, global_step=None)
Forward pass to compute MPC2 loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `output` | `dict` | Model output dictionary containing: | required |
| `x` | `Tensor` | Input tensor of shape `[N, C, H, W]`, used to infer the shape if `x_shape` is `None`; may be `None` if `x_shape` is given | required |
| `x_shape` | `tuple` | Shape of the input tensor, `(N, C, H, W)`; required if `x` is `None` | `None` |
| `global_step` | `int` | Current training step, used for lambda scheduling | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `loss` | `Tensor` | The total loss tensor |
| `monitor` | `dict` | Dictionary of monitoring metrics, including the loss value and any additional metrics from `output["monitor"]` if present |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If both `x` and `x_shape` are `None`. |
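The control flow documented for `forward` (infer the shape from `x` unless `x_shape` is given, raise `ValueError` when both are `None`, and merge `output["monitor"]` into the returned metrics) can be sketched as follows. The output keys `likelihoods`, `cls_token_loss` and `patch_tokens_loss` and the metric keys `loss` and `bpp_loss` are hypothetical, since the documented key list is incomplete; `global_step` scheduling is omitted, and a `SimpleNamespace` stands in for a tensor's `.shape`.

```python
import math
from types import SimpleNamespace
from typing import Any, Dict, Optional, Sequence, Tuple


def resolve_shape(x: Any, x_shape: Optional[Sequence[int]]) -> Tuple[int, ...]:
    """Prefer an explicit x_shape; otherwise read x.shape."""
    if x_shape is not None:
        return tuple(x_shape)
    if x is None:
        raise ValueError("Either x or x_shape must be provided.")
    return tuple(x.shape)


def mpc2_forward(
    output: Dict[str, Any],
    x: Any,
    x_shape: Optional[Sequence[int]] = None,
    rlmbda: float = 1.0,
) -> Tuple[float, Dict[str, Any]]:
    """Hypothetical sketch of MPC2Loss.forward without lambda scheduling."""
    n, _, h, w = resolve_shape(x, x_shape)
    bits = sum(-math.log2(p) for v in output["likelihoods"].values() for p in v)
    bpp_loss = bits / (n * h * w)
    loss = rlmbda * bpp_loss + output["cls_token_loss"] + output["patch_tokens_loss"]
    monitor = {"loss": loss, "bpp_loss": bpp_loss}
    monitor.update(output.get("monitor", {}))  # extra metrics, if the model reports any
    return loss, monitor


fake_x = SimpleNamespace(shape=(1, 3, 1, 2))  # stands in for an [N, C, H, W] tensor
out = {"likelihoods": {"y": [0.5, 0.5]}, "cls_token_loss": 0.25, "patch_tokens_loss": 0.5, "monitor": {"psnr": 30.0}}
print(mpc2_forward(out, fake_x))
```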
get_rlmbda
get_rlmbda(global_step=None)
MultiTaskLoss
MultiTaskLoss(p, tasks, loss_ft, loss_weights)
Multi-task loss aggregation.
Combines losses from multiple tasks with configurable weights, plus optional load balancing loss for routing mechanisms.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `p` | | Configuration dict | required |
| `tasks` | | List of task names | required |
| `loss_ft` | | `ModuleDict` mapping task names to loss functions | required |
| `loss_weights` | | Dict mapping task names to loss weights | required |
forward
forward(pred, gt, tasks=None)
Compute multi-task loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pred` | | Dict of predictions per task | required |
| `gt` | | Dict of ground truth per task | required |
| `tasks` | | List of tasks to compute the loss for; if `None`, all tasks are used | `None` |
Returns:
| Type | Description |
|---|---|
| `dict` | Dict with the `'total'` loss and per-task losses |
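The aggregation can be sketched as a weighted sum over the selected tasks, defaulting to every task in `loss_ft` when `tasks` is `None`. Plain callables stand in for the `ModuleDict` entries, the optional routing load-balancing term is omitted, and the function name is hypothetical.

```python
from typing import Callable, Dict, Iterable, Mapping, Optional

LossFn = Callable[[float, float], float]


def multi_task_loss(
    pred: Mapping[str, float],
    gt: Mapping[str, float],
    loss_ft: Mapping[str, LossFn],
    loss_weights: Mapping[str, float],
    tasks: Optional[Iterable[str]] = None,
) -> Dict[str, float]:
    """Hypothetical sketch: per-task losses plus their weighted 'total'."""
    selected = list(loss_ft) if tasks is None else list(tasks)
    out = {t: loss_ft[t](pred[t], gt[t]) for t in selected}
    out["total"] = sum(loss_weights[t] * out[t] for t in selected)
    return out


def abs_err(p: float, g: float) -> float:
    """Toy per-task loss used only for this illustration."""
    return abs(p - g)


losses = {"semseg": abs_err, "depth": abs_err}
print(multi_task_loss({"semseg": 1.0, "depth": 3.0}, {"semseg": 0.0, "depth": 1.0}, losses, {"semseg": 1.0, "depth": 0.5}))
```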
SiLogLoss
SiLogLoss(lambd=0.5)
Scale-invariant logarithmic loss for depth estimation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
lambd
|
Balance factor (default: 0.5) |
0.5
|
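For reference, a common form of the scale-invariant log loss, based on the scale-invariant error of Eigen et al. (2014) and often used with a square root in depth-estimation code, takes `d = log(target) - log(pred)` over valid pixels and returns `sqrt(mean(d**2) - lambd * mean(d)**2)`. Whether this library applies the square root, and how it masks invalid pixels, are assumptions. The sketch works on flat lists of depths.

```python
import math
from typing import Sequence


def silog_loss(pred: Sequence[float], target: Sequence[float], lambd: float = 0.5) -> float:
    """Hypothetical sketch: sqrt(mean(d^2) - lambd * mean(d)^2), d = log(target) - log(pred)."""
    # Assumed validity rule: keep only pixels with a positive target depth.
    d = [math.log(t) - math.log(p) for p, t in zip(pred, target) if t > 0]
    if not d:
        return 0.0
    mean_d = sum(d) / len(d)
    mean_d2 = sum(v * v for v in d) / len(d)
    # max(...) guards against tiny negative values from rounding when lambd == 1.
    return math.sqrt(max(mean_d2 - lambd * mean_d * mean_d, 0.0))


# Predicting exactly twice the true depth everywhere is a pure scale error.
print(silog_loss([2.0, 4.0, 6.0], [1.0, 2.0, 3.0], lambd=1.0))
print(silog_loss([2.0, 4.0, 6.0], [1.0, 2.0, 3.0], lambd=0.5))
```

With `lambd=1.0` a global scale error costs (almost) nothing, while `lambd=0.0` reduces to the root-mean-square error in log space; `0.5` sits between the two.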
SimpleLoss
SimpleLoss(**kwargs)
Simple loss function that extracts loss directly from model output.
This loss function is a wrapper: it retrieves the loss value directly from the model output dictionary and returns it together with monitoring metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `dict` | Additional keyword arguments (currently unused) | `{}` |
forward
forward(output, x)
Forward pass to compute loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `output` | `dict` | Model output dictionary containing: | required |
| `x` | `Tensor` | Input tensor (not used; kept for interface compatibility) | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `loss` | `Tensor` | The loss tensor |
| `monitor` | `dict` | Dictionary of monitoring metrics, including the loss value and any additional metrics from `output["monitor"]` if present |
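The wrapper pattern described above reduces to a dictionary lookup. This sketch assumes the model stores its loss under a hypothetical `"loss"` key and that extra metrics under `"monitor"` are merged into the returned dictionary; the class name `SimpleLossSketch` is illustrative only.

```python
from typing import Any, Dict, Tuple


class SimpleLossSketch:
    """Hypothetical stand-in for SimpleLoss: no computation, only extraction."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs  # accepted for interface compatibility, unused

    def __call__(self, output: Dict[str, Any], x: Any) -> Tuple[Any, Dict[str, Any]]:
        loss = output["loss"]  # hypothetical key name
        monitor = {"loss": float(loss)}
        monitor.update(output.get("monitor", {}))  # merge model-reported metrics
        return loss, monitor


criterion = SimpleLossSketch()
loss, monitor = criterion({"loss": 0.42, "monitor": {"bpp": 0.1}}, x=None)
print(loss, monitor)
```

Keeping the unused `x` argument lets this loss be swapped for the rate-distortion losses above without changing the training loop.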
get_task_loss
get_task_loss(task, p=None, **kwargs)
Factory function to get loss function for a specific task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `task` | | Task name | required |
| `p` | | Configuration dict (optional) | `None` |
| `**kwargs` | | Additional arguments for loss construction | `{}` |
Returns:
| Type | Description |
|---|---|
| | Loss module |