mpcompress.datasets
ClassificationDataset
ClassificationDataset(root, transform=None, split='', file_list=None, labels_file=None, **kwargs)
Unified image folder dataset for classification tasks.
This dataset loads images from a directory structure and associates them with classification labels from a labels file. It supports loading images from a file list or by scanning the directory.
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory of the dataset. | required |
| `transform` | Data preprocessing transform function. Defaults to None. | None |
| `split` | Subset name (e.g., 'train', 'val'). If empty, the root directory is used directly. Defaults to "". | '' |
| `file_list` | Path to a file list, relative to root. Each line should contain the relative path to one image file. If None, the directory is scanned for image files. Defaults to None. | None |
| `labels_file` | Path to a labels file, relative to root. Each line should contain "image_name label", where label is an integer. Defaults to None. | None |
| `**kwargs` | Additional keyword arguments (unused). | {} |
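The labels file format described above can be read with a few lines of standard-library Python. This is a minimal sketch, assuming whitespace-separated fields with the integer label last; `parse_labels_file` is a hypothetical helper for illustration, not part of mpcompress, and whether `image_name` carries a file extension depends on how your labels file was produced.

```python
from pathlib import Path


def parse_labels_file(path):
    """Hypothetical helper: map each image name to its integer label.

    Assumes one "image_name label" pair per line; blank lines are skipped.
    """
    labels = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        # rsplit keeps image names that themselves contain spaces intact.
        name, label = line.rsplit(maxsplit=1)
        labels[name] = int(label)
    return labels
```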
__getitem__
__getitem__(index)
Get an image sample and its metadata from the dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| `index` | Index of the sample to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `tuple` | A tuple `(img, img_meta)` with the following elements: |

- `img` (PIL.Image.Image or torch.Tensor): The image. If transform is provided, the transformed version is returned (typically a torch.Tensor); otherwise, a PIL Image in RGB format is returned.
- `img_meta` (dict): Metadata dictionary with keys:
    - "img_path" (str): Full path to the image file.
    - "img_name" (str): Image filename without extension.
    - "ori_size" (tuple): Original image size, as (width, height), or (height, width) for tensors.
    - "cls_label" (int or None): Classification label for the image.
__len__
__len__()
Return the number of samples in the dataset.
Returns:
| Name | Description |
|---|---|
| `length` | Number of image files in the dataset. |
FeatureDictPerKeyFolder
FeatureDictPerKeyFolder(root, split='', keys=None, transform=None)
Dataset for loading feature dictionaries organized by keys.
Each key in the feature dictionary is stored in a separate folder. This organization allows loading only the specified keys, reducing disk I/O usage. All keys must have the same number of samples.
Directory structure:
- root/
    - split/
        - key1/
            - sample0.pt
            - sample1.pt
            - ...
        - key2/
            - sample0.pt
            - sample1.pt
            - ...
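The per-key layout above can be sketched as follows: each key folder is listed once, and `__getitem__` opens only the `idx`-th file in each requested folder. `PerKeyFeatureIndex` and its `load_fn` argument are hypothetical (in practice `load_fn` would likely be `torch.load` for `.pt` files), and the sketch assumes that sorting file names yields the same sample order in every key folder.

```python
import os


class PerKeyFeatureIndex:
    """Hypothetical sketch of the per-key layout, not the mpcompress class."""

    def __init__(self, root, load_fn, split="", keys=None):
        base = os.path.join(root, split)  # split="" means root itself
        if keys is None:
            keys = sorted(
                d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d))
            )
        # Sorted file names per key; assumed to line up across keys.
        self.files = {k: sorted(os.listdir(os.path.join(base, k))) for k in keys}
        counts = {len(names) for names in self.files.values()}
        if len(counts) > 1:
            raise ValueError("all keys must have the same number of samples")
        self.length = counts.pop() if counts else 0
        self.base = base
        self.load_fn = load_fn

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Only the requested keys are read from disk.
        return {
            k: self.load_fn(os.path.join(self.base, k, names[idx]))
            for k, names in self.files.items()
        }
```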
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory of the dataset. | required |
| `split` | Subdirectory name within root (e.g., 'train', 'val'). Defaults to "" (the empty string, meaning root itself). | '' |
| `keys` | List of keys to load. If None, all keys found in the directory are loaded. Defaults to None. | None |
| `transform` | A function or transform to apply to each sample dictionary. Defaults to None. | None |
__getitem__
__getitem__(idx)
Get a feature dictionary sample from the dataset.
Loads the idx-th sample for each specified key and combines them into a single dictionary.
Parameters:
| Name | Description | Default |
|---|---|---|
| `idx` | Index of the sample to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `data` | Dictionary containing all specified keys and their corresponding values. If transform is provided, the transformed version is returned. |
__len__
__len__()
Return the number of samples in the dataset.
Returns:
| Name | Description |
|---|---|
| `length` | Number of samples (the same for every key). |
FeatureDictPerSampleFolder
FeatureDictPerSampleFolder(root, transform=None, split='train')
Dataset for loading feature dictionaries stored as separate .pt files.
Each sample is stored as a separate PyTorch .pt file containing a dictionary. This is useful when each sample has different keys or when features are preprocessed and saved individually.
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory of the dataset. | required |
| `transform` | A function or transform to apply to each sample. Defaults to None. | None |
| `split` | Subdirectory name within root (e.g., 'train', 'val'). Defaults to "train". | 'train' |
__getitem__
__getitem__(idx)
Get a feature dictionary sample from the dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| `idx` | Index of the sample to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `data` | Feature dictionary loaded from the .pt file. If transform is provided, the transformed version is returned. |
__len__
__len__()
Return the number of samples in the dataset.
Returns:
| Name | Description |
|---|---|
| `length` | Number of .pt files in the dataset. |
FeatureFolder
FeatureFolder(root, transform=None, split='train', model_type='sd3', task='tti', trun_flag=False, trun_low=-20, trun_high=20, quant_type='uniform', qsamples=0, bit_depth=1, patch_size=(512, 512))
Load a feature folder database.
This dataset loads feature files (numpy .npy files) from a directory and applies preprocessing including truncation, quantization, packing, and random cropping. The features are processed according to the specified model type.
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory containing feature files. | required |
| `transform` | A function or transform that takes in a feature and returns a transformed version. Defaults to None. | None |
| `split` | Split mode ('train' or 'val'). Currently unused; features are loaded directly from the root directory. Defaults to "train". | 'train' |
| `model_type` | Model type used for feature packing. Supported types: "llama3", "dinov2", "sd3". Defaults to "sd3". | 'sd3' |
| `task` | Task type. Defaults to "tti". | 'tti' |
| `trun_flag` | Whether to apply truncation to features. Defaults to False. | False |
| `trun_low` | Lower bound(s) for truncation. If a list, each channel has its own bound. Defaults to -20. | -20 |
| `trun_high` | Upper bound(s) for truncation. If a list, each channel has its own bound. Defaults to 20. | 20 |
| `quant_type` | Quantization type. Currently only "uniform" is supported. Defaults to "uniform". | 'uniform' |
| `qsamples` | Number of quantization samples. Defaults to 0. | 0 |
| `bit_depth` | Bit depth for uniform quantization. Defaults to 1. | 1 |
| `patch_size` | Patch size for random cropping, in (height, width) format. Must be a multiple of 64. Defaults to (512, 512). | (512, 512) |
__getitem__
__getitem__(index)
Get a feature sample from the dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| `index` | Index of the sample to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `feat` | Preprocessed feature array of shape (1, H, W) after truncation, quantization, packing, and random cropping. |
__len__
__len__()
Return the number of samples in the dataset.
Returns:
| Name | Description |
|---|---|
| `length` | Number of feature files in the dataset. |
packing
staticmethod
packing(feat, model_type)
Pack features according to model type.
Reshapes features from (N, C, H, W) format to a 2D array format specific to the model type. This is used for compatibility with different model architectures.
Parameters:
| Name | Description | Default |
|---|---|---|
| `feat` | Input feature array of shape (N, C, H, W). | required |
| `model_type` | Model type. Supported types: "llama3" extracts a single channel and returns (H, W); "dinov2" reshapes to `(N*H, C*W)`; "sd3" reshapes to `(C/4*H, C/4*W)`. | required |
Returns:
| Name | Description |
|---|---|
| `feat` | Packed feature array whose shape depends on model_type. |
random_crop
staticmethod
random_crop(feat, crop_shape)
Randomly crop a feature array to specified shape.
Parameters:
| Name | Description | Default |
|---|---|---|
| `feat` | Input feature array of shape (H, W). | required |
| `crop_shape` | Desired crop size, in (height, width) format. | required |
Returns:
| Name | Description |
|---|---|
| `feat` | Cropped feature array of shape crop_shape. |
Raises:
| Type | Description |
|---|---|
| | If crop_shape exceeds the feature dimensions. |
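A minimal NumPy sketch of the cropping behaviour described above, assuming a uniformly random top-left offset. The optional `rng` argument is an addition for reproducibility, and raising `ValueError` is an assumption, since the reference does not name the exception type.

```python
import numpy as np


def random_crop(feat, crop_shape, rng=None):
    """Sketch: crop an (H, W) array to crop_shape at a random offset."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = feat.shape
    ch, cw = crop_shape
    if ch > h or cw > w:
        # Exception type is an assumption; the reference leaves it unnamed.
        raise ValueError(f"crop {crop_shape} exceeds feature shape {feat.shape}")
    # integers(low, high) excludes high, so every valid offset is reachable.
    top = int(rng.integers(0, h - ch + 1))
    left = int(rng.integers(0, w - cw + 1))
    return feat[top:top + ch, left:left + cw]
```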
truncation
staticmethod
truncation(feat, trun_low, trun_high)
Truncate feature values to specified range.
Clips feature values to be within [trun_low, trun_high]. Supports per-channel truncation when trun_low and trun_high are lists.
Parameters:
| Name | Description | Default |
|---|---|---|
| `feat` | Input feature array of shape (N, C, H, W). | required |
| `trun_low` | Lower bound(s) for truncation: a scalar, or a list with one bound per channel. | required |
| `trun_high` | Upper bound(s) for truncation: a scalar, or a list with one bound per channel. | required |
Returns:
| Name | Description |
|---|---|
| `trun_feat` | Truncated feature array of the same shape as the input. |
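Per-channel truncation, as described above, comes down to clipping with bounds that broadcast over the batch and spatial axes. This is a sketch under the assumption that list bounds hold exactly one entry per channel; it mirrors the documented behaviour rather than reproducing the mpcompress source.

```python
import numpy as np


def truncation(feat, trun_low, trun_high):
    """Sketch: clip an (N, C, H, W) array to [trun_low, trun_high].

    Scalar bounds apply to every channel; list bounds are reshaped to
    (1, C, 1, 1) so that they broadcast over N, H and W.
    """
    low = np.asarray(trun_low, dtype=feat.dtype)
    high = np.asarray(trun_high, dtype=feat.dtype)
    if low.ndim == 1:
        low = low.reshape(1, -1, 1, 1)
    if high.ndim == 1:
        high = high.reshape(1, -1, 1, 1)
    return np.clip(feat, low, high)
```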
uniform_dequantization
staticmethod
uniform_dequantization(feat, min_v, max_v, bit_depth)
Apply uniform dequantization to features.
Converts quantized integer features back to continuous values using uniform dequantization. Supports per-channel dequantization when min_v and max_v are lists.
Parameters:
| Name | Description | Default |
|---|---|---|
| `feat` | Quantized feature array of shape (N, C, H, W). | required |
| `min_v` | Minimum value(s) of the dequantization range. | required |
| `max_v` | Maximum value(s) of the dequantization range. | required |
| `bit_depth` | Number of bits used for quantization. | required |
Returns:
| Name | Description |
|---|---|
| `dequant_feat` | Dequantized feature array of the same shape as the input. |
uniform_quantization
staticmethod
uniform_quantization(feat, min_v, max_v, bit_depth)
Apply uniform quantization to features.
Quantizes features to integer values in the range [0, 2^bit_depth - 1] using uniform quantization. Supports per-channel quantization when min_v and max_v are lists.
Parameters:
| Name | Description | Default |
|---|---|---|
| `feat` | Input feature array of shape (N, C, H, W). | required |
| `min_v` | Minimum value(s) of the quantization range. | required |
| `max_v` | Maximum value(s) of the quantization range. | required |
| `bit_depth` | Number of bits for quantization; the number of quantization levels is 2^bit_depth. | required |
Returns:
| Name | Description |
|---|---|
| `quant_feat` | Quantized feature array of the same shape as the input. |
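The two static methods above are inverses up to rounding error. The sketch below assumes the common min-max scheme q = round((x - min_v) / (max_v - min_v) * (2^bit_depth - 1)) and its inverse; the reference does not spell out the exact formula, clipping, or rounding mode, so treat this as an illustration rather than the mpcompress implementation.

```python
import numpy as np


def _bounds(v):
    # Scalars broadcast as-is; per-channel lists become shape (1, C, 1, 1).
    v = np.asarray(v, dtype=np.float64)
    return v.reshape(1, -1, 1, 1) if v.ndim == 1 else v


def uniform_quantization(feat, min_v, max_v, bit_depth):
    """Map values in [min_v, max_v] to integers 0 .. 2**bit_depth - 1."""
    lo, hi = _bounds(min_v), _bounds(max_v)
    levels = 2 ** bit_depth - 1
    scaled = (np.clip(feat, lo, hi) - lo) / (hi - lo) * levels
    # np.round rounds exact halves to the nearest even integer.
    return np.round(scaled).astype(np.int64)


def uniform_dequantization(feat, min_v, max_v, bit_depth):
    """Map integer levels back to reconstruction points in [min_v, max_v]."""
    lo, hi = _bounds(min_v), _bounds(max_v)
    levels = 2 ** bit_depth - 1
    return feat.astype(np.float64) / levels * (hi - lo) + lo
```

With these definitions the reconstruction error is at most half a quantization step, (max_v - min_v) / (2 * (2^bit_depth - 1)).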
unpacking
staticmethod
unpacking(feat, shape, model_type)
Unpack features according to model type.
Reshapes packed features back to (N, C, H, W) format. This is the inverse operation of packing.
Parameters:
| Name | Description | Default |
|---|---|---|
| `feat` | Packed feature array. | required |
| `shape` | Target shape (N, C, H, W). | required |
| `model_type` | Model type. Supported types: "llama3" expands to (1, 1, H, W); "dinov2" reshapes from `(N*H, C*W)` to (N, C, H, W); "sd3" reshapes from `(C/4*H, C/4*W)` to (N, C, H, W). | required |
Returns:
| Name | Description |
|---|---|
| `feat` | Unpacked feature array of shape (N, C, H, W). |
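The packing rules above amount to tiling channel maps into a 2-D plane. This NumPy sketch shows one tiling order that satisfies the documented shapes, assuming N == C == 1 for "llama3" and N == 1, C == 16 for "sd3" (the channel count for which a (C/4) x (C/4) grid holds every channel exactly once). The tile order actually used by mpcompress is not documented here, so `pack` and `unpack` are illustrative names rather than the library's methods.

```python
import numpy as np


def pack(feat, model_type):
    """Sketch: tile an (N, C, H, W) array into a 2-D plane."""
    n, c, h, w = feat.shape
    if model_type == "llama3":
        # Assumes N == C == 1, so the plane is the single (H, W) map.
        return feat[0, 0]
    if model_type == "dinov2":
        # Sample n fills row block n, channel c fills column block c.
        return feat.transpose(0, 2, 1, 3).reshape(n * h, c * w)
    if model_type == "sd3":
        g = c // 4
        if n != 1 or g * g != c:
            raise ValueError("sketch assumes N == 1 and C == 16 for sd3")
        # Channel gi * g + gj fills grid cell (gi, gj).
        return feat[0].reshape(g, g, h, w).transpose(0, 2, 1, 3).reshape(g * h, g * w)
    raise ValueError(f"unsupported model_type: {model_type}")


def unpack(plane, shape, model_type):
    """Sketch: inverse of pack(), rebuilding the target (N, C, H, W) array."""
    n, c, h, w = shape
    if model_type == "llama3":
        return plane[np.newaxis, np.newaxis]
    if model_type == "dinov2":
        return plane.reshape(n, h, c, w).transpose(0, 2, 1, 3)
    if model_type == "sd3":
        g = c // 4
        return plane.reshape(g, h, g, w).transpose(0, 2, 1, 3).reshape(n, c, h, w)
    raise ValueError(f"unsupported model_type: {model_type}")
```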
ImageFolder
ImageFolder(root, transform=None, split='train')
Load an image folder database.
Training and testing image samples are stored in separate directories:
- rootdir/
    - train/
        - img000.png
        - img001.png
    - test/
        - img000.png
        - img001.png
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory of the dataset. | required |
| `transform` | A function or transform that takes in a PIL image and returns a transformed version. Defaults to None. | None |
| `split` | Split mode ('train' or 'val'). Defaults to "train". | 'train' |
__getitem__
__getitem__(index)
Get an image sample from the dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| `index` | Index of the sample to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `img` | The image. If transform is provided, the transformed version is returned (typically a torch.Tensor); otherwise, a PIL Image in RGB format is returned. |
__len__
__len__()
Return the number of samples in the dataset.
Returns:
| Name | Description |
|---|---|
| `length` | Number of image files in the dataset. |
PngSequenceVideoReader
PngSequenceVideoReader(src_path, width, height, start_num=1)
Video reader for PNG image sequences.
This reader loads frames from a sequence of PNG images with naming conventions like "im1.png", "im2.png" or "im00001.png", "im00002.png". The padding width is automatically detected from the first image found.
Parameters:
| Name | Description | Default |
|---|---|---|
| `src_path` | Path to the directory containing PNG images. | required |
| `width` | Expected width of each frame in pixels. | required |
| `height` | Expected height of each frame in pixels. | required |
| `start_num` | Starting frame number. Defaults to 1. | 1 |
Raises:
| Type | Description |
|---|---|
| | If the image naming convention cannot be determined. |
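The naming detection described above can be approximated by probing candidate zero-padding widths for the first frame. `detect_padding`, `frame_name`, and the `max_width` limit are hypothetical, and `ValueError` stands in for whatever exception the reader actually raises. The "im00001.png" convention used by PngSequenceVideoWriter corresponds to a padding width of 5.

```python
import os


def detect_padding(src_path, start_num=1, max_width=8):
    """Hypothetical sketch: return the zero-padding width of the sequence.

    Tries "im1.png", "im01.png", ... up to max_width digits and returns the
    first width for which the starting frame exists.
    """
    for width in range(1, max_width + 1):
        if os.path.exists(os.path.join(src_path, frame_name(start_num, width))):
            return width
    raise ValueError("cannot determine the image naming convention")


def frame_name(num, width):
    # width=1 gives "im1.png"; width=5 gives "im00001.png".
    return f"im{num:0{width}d}.png"
```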
close
close()
Close the reader and reset frame index.
read_one_frame
read_one_frame()
Read the next frame from the PNG sequence.
Returns:
| Name | Description |
|---|---|
| `frame` | Frame object containing RGB data of shape (3, H, W) as a uint8 NumPy array. Returns None if the end of the sequence is reached or the frame cannot be read. |
restart
restart()
Reset the reader to the starting frame.
PngSequenceVideoWriter
PngSequenceVideoWriter(dst_path, width, height)
Video writer for PNG image sequences.
This writer saves frames as a sequence of PNG images with naming convention "im00001.png", "im00002.png", etc. The output directory is created automatically if it doesn't exist.
Parameters:
| Name | Description | Default |
|---|---|---|
| `dst_path` | Path to the output directory for PNG images. | required |
| `width` | Width of each frame in pixels. | required |
| `height` | Height of each frame in pixels. | required |
close
close()
Close the writer and reset frame index.
write_one_frame
write_one_frame(rgb)
Write a single frame as a PNG image.
Parameters:
| Name | Description | Default |
|---|---|---|
| `rgb` | RGB image data of shape (3, H, W) as a uint8 NumPy array. | required |
write_one_frame_from_tensor
write_one_frame_from_tensor(x, format='yuv444')
Write a single frame from a PyTorch tensor.
Converts the tensor from the specified format to RGB and writes it as a PNG image.
Parameters:
| Name | Description | Default |
|---|---|---|
| `x` | Input tensor in the specified format. | required |
| `format` | Input format. Currently only "yuv444" is supported. Defaults to "yuv444". | 'yuv444' |
Raises:
| Type | Description |
|---|---|
| | If the specified format is not supported. |
SegmentationDataset
SegmentationDataset(root, transform=None, img_path='JPEGImages', seg_map_path='SegmentationClass', file_list=None, reduce_zero_label=False, **kwargs)
Dataset for image segmentation tasks.
This dataset loads images and their corresponding segmentation masks from separate directories. It supports loading images from a file list or by scanning the directory.
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory of the dataset. | required |
| `transform` | Data preprocessing transform function. Defaults to None. | None |
| `img_path` | Name of the image subdirectory within root. Defaults to "JPEGImages". | 'JPEGImages' |
| `seg_map_path` | Name of the segmentation mask subdirectory within root. Defaults to "SegmentationClass". | 'SegmentationClass' |
| `file_list` | Path to a file list, relative to root. Each line should contain an image name (without extension). If None, the directory is scanned for .jpg files. Defaults to None. | None |
| `reduce_zero_label` | Whether to reduce the zero label. If True, 1 is subtracted from every label, so label 0 wraps around to 255 under uint8 arithmetic (e.g., 2 -> 1, 1 -> 0, 0 -> 255). Defaults to False. | False |
| `**kwargs` | Additional keyword arguments (unused). | {} |
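The `reduce_zero_label` mapping above relies on unsigned 8-bit wrap-around. This sketch reproduces the documented mapping on a NumPy mask; note that it also turns an existing 255 into 254, and the reference does not say whether mpcompress preserves a pre-existing 255 ignore label, so treat that detail as an open assumption.

```python
import numpy as np


def reduce_zero_label(seg):
    """Sketch of the documented mapping 0 -> 255, k -> k - 1."""
    seg = np.asarray(seg, dtype=np.uint8)
    # NumPy uint8 array arithmetic wraps around, so 0 - 1 becomes 255.
    reduced = seg - np.uint8(1)
    # seg_label is documented as an int64 array, so widen after the wrap.
    return reduced.astype(np.int64)
```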
__getitem__
__getitem__(index)
Get an image sample and its segmentation mask from the dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| `index` | Index of the sample to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `tuple` | A tuple `(img, img_meta)` with the following elements: |

- `img` (PIL.Image.Image or torch.Tensor): The image. If transform is provided, the transformed version is returned (typically a torch.Tensor); otherwise, a PIL Image in RGB format is returned.
- `img_meta` (dict): Metadata dictionary with keys:
    - "img_path" (str): Full path to the image file.
    - "img_name" (str): Image filename without extension.
    - "ori_size" (tuple): Original image size, as (width, height), or (height, width) for tensors.
    - "seg_label_path" (str): Full path to the segmentation mask file.
    - "seg_label" (numpy.ndarray): Segmentation mask as an int64 array.
__len__
__len__()
Return the number of samples in the dataset.
Returns:
| Name | Description |
|---|---|
| `length` | Number of image files in the dataset. |
VideoFolder
VideoFolder(root, transform=None, split='train', src_type='yuv420', sequences=[])
Load a video folder database. Training and testing video sequences are stored in separate directories:
- rootdir/
    - train/
        - video000.yuv
        - video001.yuv
    - test/
        - video002.yuv
        - video003.yuv
Parameters:
| Name | Description | Default |
|---|---|---|
| `root` | Root directory of the dataset. | required |
| `transform` | A function or transform that takes in a PIL image and returns a transformed version. | None |
| `split` | Split mode ('train' or 'val'). | 'train' |
__getitem__
__getitem__(index)
Get a video reader and its metadata from the dataset.
Parameters:
| Name | Description | Default |
|---|---|---|
| `index` | Index of the video to retrieve. | required |
Returns:
| Name | Description |
|---|---|
| `reader` | Video reader object. |
| `vid_meta` | Video metadata. |
YUV420VideoReader
YUV420VideoReader(src_path, width, height, skip_frame=0)
Video reader for YUV420 format video files.
This reader loads frames from raw YUV420 format video files. YUV420 is a chroma-subsampled format where the Y (luminance) channel is full resolution and the U and V (chrominance) channels are subsampled by a factor of 2 in both dimensions.
Parameters:
| Name | Description | Default |
|---|---|---|
| `src_path` | Path to the YUV file (with or without the .yuv extension). | required |
| `width` | Width of each frame in pixels. | required |
| `height` | Height of each frame in pixels. | required |
| `skip_frame` | Number of frames to skip at the beginning. Defaults to 0. | 0 |
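The frame layout described above fixes the byte size of every frame, which is what makes `skip_frame` a simple seek. This sketch assumes 8-bit planar I420 storage (Y plane, then U, then V) with even width and height; the helper names are hypothetical, and the nearest-neighbour chroma upsampling in `to_yuv444` is an assumption, since the reference does not say how the reader produces its YUV444 data.

```python
import numpy as np


def frame_bytes(width, height):
    # Full-size Y plane plus two quarter-size chroma planes (W * H * 3 / 2).
    return width * height + 2 * (width // 2) * (height // 2)


def skip_frames(f, width, height, skip_frame):
    """Position an open binary file at the start of frame skip_frame."""
    f.seek(skip_frame * frame_bytes(width, height))


def read_yuv420_frame(f, width, height):
    """Read one planar frame (Y, then U, then V); return None at end of file."""
    size = frame_bytes(width, height)
    buf = f.read(size)
    if len(buf) < size:
        return None
    data = np.frombuffer(buf, dtype=np.uint8)
    y_size = width * height
    c_size = (width // 2) * (height // 2)
    y = data[:y_size].reshape(height, width)
    u = data[y_size:y_size + c_size].reshape(height // 2, width // 2)
    v = data[y_size + c_size:].reshape(height // 2, width // 2)
    return y, u, v


def to_yuv444(y, u, v):
    """Upsample chroma to full resolution by nearest-neighbour repetition."""
    def up(c):
        return c.repeat(2, axis=0).repeat(2, axis=1)
    return np.stack([y, up(u), up(v)])
```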
close
close()
Close the file handle.
read_one_frame
read_one_frame()
Read the next frame from the YUV420 file.
Returns:
| Name | Description |
|---|---|
| `frame` | Frame object containing the Y, U, and V channels and the YUV444-converted data. Returns None if the end of the file is reached or the frame cannot be read. |
restart
restart()
Reset the reader to the beginning of the file.
YUV420VideoWriter
YUV420VideoWriter(dst_path, width, height)
Video writer for YUV420 format video files.
This writer saves frames to a raw YUV420 format video file. YUV420 is a chroma-subsampled format where the Y (luminance) channel is full resolution and the U and V (chrominance) channels are subsampled by a factor of 2 in both dimensions.
Parameters:
| Name | Description | Default |
|---|---|---|
| `dst_path` | Path to the output YUV file (with or without the .yuv extension). If a directory path is provided, "out.yuv" is created in that directory. | required |
| `width` | Width of each frame in pixels. | required |
| `height` | Height of each frame in pixels. | required |
close
close()
Close the file handle.
write_one_frame
write_one_frame(y, uv)
Write a single frame in YUV420 format.
Parameters:
| Name | Description | Default |
|---|---|---|
| `y` | Luminance channel of shape (H, W) as a uint8 NumPy array. | required |
| `uv` | Chrominance channels of shape (2, H//2, W//2) as a uint8 NumPy array, where uv[0] is U and uv[1] is V. | required |
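Given the `y` and `uv` shapes documented above, writing a frame amounts to appending the three planes in order. This is a minimal sketch, assuming the same planar layout (Y, then U, then V) that a YUV420 reader would expect; `write_yuv420_frame` is a hypothetical stand-in for the writer method, and the shape check is an addition for illustration.

```python
import numpy as np


def write_yuv420_frame(f, y, uv):
    """Sketch: append one frame in planar YUV420 order (Y, then U, then V)."""
    y = np.ascontiguousarray(y, dtype=np.uint8)
    uv = np.ascontiguousarray(uv, dtype=np.uint8)
    h, w = y.shape
    if uv.shape != (2, h // 2, w // 2):
        raise ValueError(f"expected uv of shape (2, {h // 2}, {w // 2}), got {uv.shape}")
    f.write(y.tobytes())
    f.write(uv[0].tobytes())  # U plane
    f.write(uv[1].tobytes())  # V plane
```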
write_one_frame_from_tensor
write_one_frame_from_tensor(x, format='yuv444')
Write a single frame from a PyTorch tensor.
Converts the tensor from the specified format to YUV420 and writes it to the file.
Parameters:
| Name | Description | Default |
|---|---|---|
| `x` | Input tensor in the specified format. | required |
| `format` | Input format. Currently only "yuv444" is supported. Defaults to "yuv444". | 'yuv444' |
Raises:
| Type | Description |
|---|---|
| | If the specified format is not supported. |
feature_dict_collate_fn
feature_dict_collate_fn(batch)
Custom collate function for feature dictionaries.
Merges multiple feature dictionaries into a single batch. For each key:
- If the value is a torch.Tensor, all values are stacked along dimension 0.
- If the value is a torch.Size, a new Size with a batch dimension is created.
- Otherwise, the first value is taken.
Parameters:
| Name | Description | Default |
|---|---|---|
| `batch` | List of feature dictionaries to collate. | required |
Returns:
| Name | Description |
|---|---|
| `collated` | Collated dictionary with the same keys as the input, where tensor values are stacked and other values are taken from the first sample. |
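The collation rules listed above can be sketched in a few lines of PyTorch. This is an illustration of the documented behaviour, not the library source; in particular, prepending the batch size to a `torch.Size` value is an assumption about what "a new Size with a batch dimension" means.

```python
import torch


def feature_dict_collate_fn(batch):
    """Sketch of the documented collation rules for feature dictionaries."""
    collated = {}
    for key, first in batch[0].items():
        if isinstance(first, torch.Tensor):
            collated[key] = torch.stack([sample[key] for sample in batch], dim=0)
        elif isinstance(first, torch.Size):
            # Assumption: the batch size is prepended to the per-sample shape.
            collated[key] = torch.Size([len(batch), *first])
        else:
            collated[key] = first
    return collated
```

A function like this is passed to a PyTorch `DataLoader` through its `collate_fn` argument, e.g. `DataLoader(dataset, batch_size=4, collate_fn=feature_dict_collate_fn)`.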