mpcompress.losses
MPC12Loss
MPC12Loss(rlmbda=1.0)
MPC12 loss function for rate-distortion optimization.
This loss function combines:

- Bits-per-pixel (BPP) rate loss computed from the likelihoods
- VQGAN feature reconstruction loss
- Patch-token reconstruction loss
- CLS-token reconstruction loss

The total loss is: `rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss`
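A minimal sketch of how the documented formula combines its terms. It assumes the common learned-compression convention that BPP is the total `-log2` likelihood of the coded symbols divided by the number of pixels `B * H * W`. The helper names and all numeric values are hypothetical placeholders, and how `MPC12Loss` actually computes each distortion term is not shown here.

```python
import math


def bpp_from_likelihoods(likelihoods, num_pixels):
    """Hypothetical helper: total -log2(p) over all coded symbols, divided by pixel count."""
    # Assumption: BPP = sum(-log2 p) / (B * H * W), a common convention.
    total_bits = sum(-math.log2(p) for p in likelihoods)
    return total_bits / num_pixels


def mpc12_total_loss(rlmbda, bpp_loss, cls_token_loss, patch_tokens_loss, h_vqgan_loss):
    """Mirrors the documented formula: rate weighted by rlmbda, distortion terms unweighted."""
    return rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss + h_vqgan_loss


# Toy input of shape (1, 3, 2, 2): B * H * W = 4 pixels, 4 made-up symbol likelihoods.
likelihoods = [0.5, 0.25, 0.5, 0.125]
bpp = bpp_from_likelihoods(likelihoods, num_pixels=1 * 2 * 2)
loss = mpc12_total_loss(1.0, bpp, cls_token_loss=0.1, patch_tokens_loss=0.2, h_vqgan_loss=0.3)
print(bpp)  # → 1.75
```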
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rlmbda` | | Rate-distortion trade-off parameter: the weight on the BPP term. Higher values penalize bitrate more strongly, favoring lower compression rates. | `1.0` |

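To see why a larger `rlmbda` favors lower rates, the toy sketch below scores three hypothetical (BPP, distortion) operating points with `rlmbda * bpp + distortion`, the shape of the documented total loss, and picks the cheapest. The operating points are invented for illustration only.

```python
# Hypothetical operating points for one image: name -> (bits per pixel, distortion).
candidates = {
    "low_rate": (0.2, 0.9),
    "mid_rate": (0.5, 0.5),
    "high_rate": (1.0, 0.3),
}


def best_point(rlmbda):
    """Return the candidate minimizing rlmbda * bpp + distortion."""
    return min(candidates, key=lambda name: rlmbda * candidates[name][0] + candidates[name][1])


print(best_point(0.1))  # → high_rate  (bits are cheap, so spend them on quality)
print(best_point(1.0))  # → mid_rate
print(best_point(5.0))  # → low_rate   (bits are expensive)
```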
forward
forward(output, x, x_shape=None, global_step=None)
Forward pass: compute the MPC12 loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `output` | | Model output dictionary containing: | required |
| `x` | | Input tensor of shape `(B, C, H, W)`. Used to infer the shape if `x_shape` is `None`. | required |
| `x_shape` | | Shape of the input tensor, `(B, C, H, W)`. Required if `x` is `None`. | `None` |
| `global_step` | | Current training step, intended for lambda scheduling. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `loss` | | The total loss tensor. |
| `monitor` | | Dictionary of monitoring metrics: the loss value plus any additional metrics from `output["monitor"]`, if present. |

Raises:

| Type | Description |
|---|---|
| | If both `x` and `x_shape` are `None`. |

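The shape handling and the monitor dictionary described above can be sketched as follows. This is a hypothetical plain-Python reading of the docs, not the actual implementation: the helper names, the `ValueError` type, the `"loss"` monitor key, and normalizing BPP by `B * H * W` are all assumptions.

```python
def resolve_num_pixels(x=None, x_shape=None):
    """Hypothetical helper: return B * H * W from x_shape, falling back to x.shape."""
    if x_shape is None:
        if x is None:
            # The docs say an error is raised; ValueError is an assumed type.
            raise ValueError("Either x or x_shape must be provided.")
        x_shape = tuple(x.shape)
    b, _, h, w = x_shape  # (B, C, H, W); channels do not count toward pixels
    return b * h * w


def build_monitor(loss_value, output):
    """Hypothetical helper: total loss first, then any metrics from output["monitor"]."""
    monitor = {"loss": loss_value}  # the "loss" key name is an assumption
    monitor.update(output.get("monitor", {}))
    return monitor


print(resolve_num_pixels(x_shape=(2, 3, 64, 64)))  # → 8192
print(build_monitor(1.5, {"monitor": {"bpp": 0.3}}))  # → {'loss': 1.5, 'bpp': 0.3}
```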
get_rlmbda
get_rlmbda(global_step=None)
Get the rate-distortion trade-off parameter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `global_step` | | Current training step. Currently unused; reserved for scheduled lambda values in the future. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `rlmbda` | | The rate-distortion trade-off parameter. |

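`get_rlmbda` currently ignores `global_step`. If a schedule were added later, a linear warm-up is one possible shape. The sketch below is purely hypothetical: the function name, `warmup_steps`, and the warm-up itself are not part of this API, and it uses plain floats rather than a module attribute.

```python
def scheduled_rlmbda(base_rlmbda, global_step=None, warmup_steps=1000):
    """Hypothetical linear warm-up of rlmbda from 0 to base_rlmbda over warmup_steps."""
    if global_step is None:
        # No step information: return the fixed value, as get_rlmbda does today.
        return base_rlmbda
    frac = min(global_step / warmup_steps, 1.0)
    return base_rlmbda * frac


print(scheduled_rlmbda(2.0))        # → 2.0
print(scheduled_rlmbda(2.0, 500))   # → 1.0
print(scheduled_rlmbda(2.0, 5000))  # → 2.0
```

Ramping the rate weight up from zero lets early training focus on reconstruction before bitrate pressure is applied; whether that helps here is untested.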
MPC2Loss
MPC2Loss(rlmbda=1.0)
MPC2 loss function for rate-distortion optimization.
This loss function combines:

- Bits-per-pixel (BPP) rate loss computed from the likelihoods
- Patch-token reconstruction loss
- CLS-token reconstruction loss

The total loss is: `rlmbda * bpp_loss + cls_token_loss + patch_tokens_loss`
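The MPC2 formula is the MPC12 formula without the `h_vqgan_loss` term. A minimal sketch of the shared shape, with a hypothetical helper `rd_loss` and placeholder values:

```python
def rd_loss(rlmbda, bpp_loss, distortion_terms):
    """Hypothetical helper: rlmbda * rate + sum of unweighted distortion terms."""
    return rlmbda * bpp_loss + sum(distortion_terms.values())


terms = {"cls_token_loss": 0.1, "patch_tokens_loss": 0.2}
mpc2 = rd_loss(0.5, 1.2, terms)                            # MPC2: CLS + patch terms only
mpc12 = rd_loss(0.5, 1.2, {**terms, "h_vqgan_loss": 0.3})  # MPC12 adds the VQGAN term
print(round(mpc12 - mpc2, 6))  # → 0.3
```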
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rlmbda` | | Rate-distortion trade-off parameter: the weight on the BPP term. Higher values penalize bitrate more strongly, favoring lower compression rates. | `1.0` |

forward
forward(output, x, x_shape=None, global_step=None)
Forward pass: compute the MPC2 loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `output` | | Model output dictionary containing: | required |
| `x` | | Input tensor of shape `(B, C, H, W)`. Used to infer the shape if `x_shape` is `None`. | required |
| `x_shape` | | Shape of the input tensor, `(B, C, H, W)`. Required if `x` is `None`. | `None` |
| `global_step` | | Current training step, intended for lambda scheduling. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `loss` | | The total loss tensor. |
| `monitor` | | Dictionary of monitoring metrics: the loss value plus any additional metrics from `output["monitor"]`, if present. |

Raises:

| Type | Description |
|---|---|
| | If both `x` and `x_shape` are `None`. |

get_rlmbda
get_rlmbda(global_step=None)
Get the rate-distortion trade-off parameter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `global_step` | | Current training step. Currently unused; reserved for scheduled lambda values in the future. | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `rlmbda` | | The rate-distortion trade-off parameter. |

SimpleLoss
SimpleLoss(**kwargs)
Simple loss function that takes the loss directly from the model output.
This wrapper retrieves the loss value from the model output dictionary and returns it together with monitoring metrics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | | Additional keyword arguments (currently unused). | `{}` |

forward
forward(output, x)
Forward pass: return the loss reported by the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `output` | | Model output dictionary containing: | required |
| `x` | | Input tensor. Unused; kept for interface compatibility with the other losses. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `loss` | | The loss tensor. |
| `monitor` | | Dictionary of monitoring metrics: the loss value plus any additional metrics from `output["monitor"]`, if present. |
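A plain-Python sketch of the behavior described for `SimpleLoss`, under stated assumptions: the class name `SimpleLossSketch` is hypothetical, the model is assumed to store its loss under `output["loss"]`, the monitor is assumed to use a `"loss"` key, and `forward` is assumed to return a `(loss, monitor)` tuple.

```python
class SimpleLossSketch:
    """Hypothetical stand-in for SimpleLoss: pass the model's own loss through."""

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored, as documented.
        self.kwargs = kwargs

    def forward(self, output, x):
        # x is unused; it only keeps the (output, x) call shape shared with other losses.
        loss = output["loss"]  # assumed key name
        monitor = {"loss": loss}  # assumed key name
        monitor.update(output.get("monitor", {}))
        return loss, monitor


print(SimpleLossSketch(unused=1).forward({"loss": 0.7, "monitor": {"psnr": 30.0}}, None))
# → (0.7, {'loss': 0.7, 'psnr': 30.0})
```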