mpcompress.datasets

ClassificationDataset

ClassificationDataset(root, transform=None, split='', file_list=None, labels_file=None, **kwargs)

Unified image folder dataset for classification tasks.

This dataset loads images from a directory structure and associates them with classification labels from a labels file. It supports loading images from a file list or by scanning the directory.

Parameters:

Name Type Description Default
root str

Root directory of the dataset.

required
transform callable

Data preprocessing transform function. Defaults to None.

None
split str

Subset name (e.g., 'train', 'val'). If empty, uses root directory directly. Defaults to "".

''
file_list str

Path to file list relative to root. Each line should contain a relative path to an image file. If None, scans the directory for image files. Defaults to None.

None
labels_file str

Path to labels file relative to root. Each line should contain "image_name label", where label is an integer. The docstring marks this argument as required, although the signature default is None.

None
**kwargs dict

Additional keyword arguments (unused).

{}
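
The "image_name label" line format described for labels_file can be parsed as in the minimal sketch below. The function name `parse_labels` and its error handling are illustrative assumptions, not part of mpcompress; the library's own parser may treat blank or malformed lines differently.

```python
def parse_labels(text):
    """Map image name -> integer label from 'image_name label' lines."""
    labels = {}
    for line_no, line in enumerate(text.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split once from the right so the label is always the last field.
        parts = line.rsplit(None, 1)
        if len(parts) != 2:
            raise ValueError(f"line {line_no}: expected 'image_name label'")
        name, label = parts
        labels[name] = int(label)
    return labels


# Hypothetical file content, for illustration only.
example = "img000.png 3\n\nimg001.png 10\n"
labels = parse_labels(example)
```

Splitting from the right keeps the sketch tolerant of tabs or repeated spaces between the two fields.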

__getitem__

__getitem__(index)

Get an image sample and its metadata from the dataset.

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Name Type Description
tuple (Image or Tensor, dict)

A tuple containing:

- img (PIL.Image.Image or torch.Tensor): The image. If transform is provided, returns the transformed version (typically a torch.Tensor). Otherwise, returns a PIL Image in RGB format.
- img_meta (dict): Metadata dictionary with keys:
    - "img_path" (str): Full path to the image file.
    - "img_name" (str): Image filename without extension.
    - "ori_size" (tuple): Original image size (width, height) or (height, width) for tensors.
    - "cls_label" (int or None): Classification label for the image.

__len__

__len__()

Return the number of samples in the dataset.

Returns:

Name Type Description
length int

Number of image files in the dataset.

FeatureDictPerKeyFolder

FeatureDictPerKeyFolder(root, split='', keys=None, transform=None)

Dataset for loading feature dictionaries organized by keys.

Each key in the feature dictionary is stored in a separate folder. This organization allows loading only the specified keys, reducing disk I/O usage. All keys must have the same number of samples.

Directory structure:

root/
    split/
        key1/
            sample0.pt
            sample1.pt
            ...
        key2/
            sample0.pt
            sample1.pt
            ...

Parameters:

Name Type Description Default
root str

Root directory of the dataset.

required
split str

Subdirectory name within root (e.g., 'train', 'val'). Defaults to "" (empty string, meaning root itself).

''
keys list[str]

List of keys to load. If None, all keys found in the directory are loaded. Defaults to None.

None
transform callable

A function or transform to apply to each sample dictionary. Defaults to None.

None
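
The per-key layout can be indexed as in the sketch below, which uses only the standard library and does not load any tensors. The helper name `index_per_key_folders`, the pairing of samples by sorted file name, and the `ValueError` on mismatched counts are assumptions made for illustration; the library may discover and validate files differently.

```python
import os
import tempfile


def index_per_key_folders(root, split="", keys=None):
    """Return {key: sorted list of .pt file names}, checking equal lengths."""
    base = os.path.join(root, split) if split else root
    if keys is None:
        # Treat every subdirectory as a key when none are requested.
        keys = sorted(
            d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d))
        )
    files = {
        key: sorted(
            f for f in os.listdir(os.path.join(base, key)) if f.endswith(".pt")
        )
        for key in keys
    }
    counts = {key: len(names) for key, names in files.items()}
    if len(set(counts.values())) > 1:
        raise ValueError(f"keys have different sample counts: {counts}")
    return files


# Build a tiny throwaway tree with empty placeholder files and index it.
with tempfile.TemporaryDirectory() as tmp:
    for key in ("key1", "key2"):
        os.makedirs(os.path.join(tmp, "train", key))
        for i in range(3):
            open(os.path.join(tmp, "train", key, f"sample{i}.pt"), "w").close()
    all_keys = index_per_key_folders(tmp, split="train")
    only_key2 = index_per_key_folders(tmp, split="train", keys=["key2"])
```

Requesting a subset of keys touches only those folders, which is the I/O saving the class description refers to.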

__getitem__

__getitem__(idx)

Get a feature dictionary sample from the dataset.

Loads the idx-th sample for each specified key and combines them into a single dictionary.

Parameters:

Name Type Description Default
idx int

Index of the sample to retrieve.

required

Returns:

Name Type Description
data dict

Dictionary containing all specified keys with their corresponding values. If transform is provided, the transformed version is returned.

__len__

__len__()

Return the number of samples in the dataset.

Returns:

Name Type Description
length int

Number of samples (same for all keys).

FeatureDictPerSampleFolder

FeatureDictPerSampleFolder(root, transform=None, split='train')

Dataset for loading feature dictionaries stored as separate .pt files.

Each sample is stored as a separate PyTorch .pt file containing a dictionary. This is useful when each sample has different keys or when features are preprocessed and saved individually.

Parameters:

Name Type Description Default
root str

Root directory of the dataset.

required
transform callable

A function or transform to apply to each sample. Defaults to None.

None
split str

Subdirectory name within root (e.g., 'train', 'val'). Defaults to "train".

'train'

__getitem__

__getitem__(idx)

Get a feature dictionary sample from the dataset.

Parameters:

Name Type Description Default
idx int

Index of the sample to retrieve.

required

Returns:

Name Type Description
data dict

Feature dictionary loaded from the .pt file. If transform is provided, the transformed version is returned.

__len__

__len__()

Return the number of samples in the dataset.

Returns:

Name Type Description
length int

Number of .pt files in the dataset.

FeatureFolder

FeatureFolder(root, transform=None, split='train', model_type='sd3', task='tti', trun_flag=False, trun_low=-20, trun_high=20, quant_type='uniform', qsamples=0, bit_depth=1, patch_size=(512, 512))

Load a feature folder database.

This dataset loads feature files (numpy .npy files) from a directory and applies preprocessing including truncation, quantization, packing, and random cropping. The features are processed according to the specified model type.

Parameters:

Name Type Description Default
root str

Root directory containing feature files.

required
transform callable

A function or transform that takes in a feature and returns a transformed version. Defaults to None.

None
split str

Split mode ('train' or 'val'). Currently unused; features are loaded directly from the root directory. Defaults to "train".

'train'
model_type str

Model type for feature packing. Supported types: "llama3", "dinov2", "sd3". Defaults to "sd3".

'sd3'
task str

Task type. Defaults to "tti".

'tti'
trun_flag bool

Whether to apply truncation to features. Defaults to False.

False
trun_low float or list[float]

Lower bound(s) for truncation. If list, each channel has its own bound. Defaults to -20.

-20
trun_high float or list[float]

Upper bound(s) for truncation. If list, each channel has its own bound. Defaults to 20.

20
quant_type str

Quantization type. Currently only "uniform" is supported. Defaults to "uniform".

'uniform'
qsamples int

Number of quantization samples. Defaults to 0.

0
bit_depth int

Bit depth for uniform quantization. Defaults to 1.

1
patch_size tuple[int, int]

Patch size for random cropping in format (height, width). Must be a multiple of 64. Defaults to (512, 512).

(512, 512)

__getitem__

__getitem__(index)

Get a feature sample from the dataset.

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Name Type Description
feat ndarray

Preprocessed feature array of shape (1, H, W) after truncation, quantization, packing, and random cropping.

__len__

__len__()

Return the number of samples in the dataset.

Returns:

Name Type Description
length int

Number of feature files in the dataset.

packing staticmethod

packing(feat, model_type)

Pack features according to model type.

Reshapes features from (N, C, H, W) format to a 2D array format specific to the model type. This is used for compatibility with different model architectures.

Parameters:

Name Type Description Default
feat ndarray

Input feature array of shape (N, C, H, W).

required
model_type str

Model type. Supported types:

- "llama3": Extracts single channel, returns (H, W)
- "dinov2": Reshapes to (N*H, C*W)
- "sd3": Reshapes to (C/4*H, C/4*W)

required

Returns:

Name Type Description
feat ndarray

Packed feature array with shape depending on model_type.
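
As an illustration of the "dinov2" case, the sketch below packs an (N, C, H, W) array into a 2D array, reading the documented target shape as (N*H, C*W). The function name `pack_rows_cols` and the exact element ordering (batch items stacked vertically, channels side by side) are assumptions; this page does not specify the layout the library uses.

```python
import numpy as np


def pack_rows_cols(feat):
    """Pack (N, C, H, W) into (N*H, C*W): batch along rows, channels along columns."""
    n, c, h, w = feat.shape
    # (N, C, H, W) -> (N, H, C, W) -> (N*H, C*W)
    return feat.transpose(0, 2, 1, 3).reshape(n * h, c * w)


feat = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
packed = pack_rows_cols(feat)
```

With this ordering, element `feat[n, c, h, w]` lands at `packed[n * H + h, c * W + w]`.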

random_crop staticmethod

random_crop(feat, crop_shape)

Randomly crop a feature array to specified shape.

Parameters:

Name Type Description Default
feat ndarray

Input feature array of shape (H, W).

required
crop_shape tuple[int, int]

Desired crop size in format (height, width).

required

Returns:

Name Type Description
feat ndarray

Cropped feature array of shape crop_shape.

Raises:

Type Description
ValueError

If crop_shape exceeds the feature dimensions.
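
A minimal sketch of the cropping behaviour described above, assuming a uniformly random top-left corner. The name `random_crop_sketch` and the optional `rng` argument are illustrative additions and not part of the documented signature.

```python
import numpy as np


def random_crop_sketch(feat, crop_shape, rng=None):
    """Return a random (height, width) window of a 2D array."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = feat.shape
    ch, cw = crop_shape
    if ch > h or cw > w:
        raise ValueError(f"crop {crop_shape} exceeds feature size {(h, w)}")
    # Upper bounds are exclusive, so +1 allows the last valid offset.
    top = int(rng.integers(0, h - ch + 1))
    left = int(rng.integers(0, w - cw + 1))
    return feat[top:top + ch, left:left + cw]


feat = np.arange(64).reshape(8, 8)
patch = random_crop_sketch(feat, (4, 6), np.random.default_rng(0))
```

Passing a seeded generator makes the crop reproducible, which can help when debugging a data pipeline.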

truncation staticmethod

truncation(feat, trun_low, trun_high)

Truncate feature values to specified range.

Clips feature values to be within [trun_low, trun_high]. Supports per-channel truncation when trun_low and trun_high are lists.

Parameters:

Name Type Description Default
feat ndarray

Input feature array of shape (N, C, H, W).

required
trun_low float or list[float]

Lower bound(s) for truncation.

required
trun_high float or list[float]

Upper bound(s) for truncation.

required

Returns:

Name Type Description
trun_feat ndarray

Truncated feature array of the same shape as input.
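
Scalar and per-channel truncation can both be expressed with `np.clip` and broadcasting, as sketched below. The name `truncate_sketch` is hypothetical, and reshaping list bounds to (1, C, 1, 1) is an assumption about how "per-channel" is applied.

```python
import numpy as np


def truncate_sketch(feat, trun_low, trun_high):
    """Clip an (N, C, H, W) array with scalar or per-channel bounds."""
    low = np.asarray(trun_low, dtype=feat.dtype)
    high = np.asarray(trun_high, dtype=feat.dtype)
    # A list of C bounds is reshaped so it broadcasts over the channel axis.
    if low.ndim == 1:
        low = low.reshape(1, -1, 1, 1)
    if high.ndim == 1:
        high = high.reshape(1, -1, 1, 1)
    return np.clip(feat, low, high)


feat = np.array([[[[-30.0, 5.0]], [[0.5, 30.0]]]])  # shape (1, 2, 1, 2)
scalar_out = truncate_sketch(feat, -20, 20)
channel_out = truncate_sketch(feat, [-1.0, 0.0], [1.0, 1.0])
```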

uniform_dequantization staticmethod

uniform_dequantization(feat, min_v, max_v, bit_depth)

Apply uniform dequantization to features.

Converts quantized integer features back to continuous values using uniform dequantization. Supports per-channel dequantization when min_v and max_v are lists.

Parameters:

Name Type Description Default
feat ndarray

Quantized feature array of shape (N, C, H, W).

required
min_v float or list[float]

Minimum value(s) for dequantization range.

required
max_v float or list[float]

Maximum value(s) for dequantization range.

required
bit_depth int

Number of bits used for quantization.

required

Returns:

Name Type Description
dequant_feat ndarray

Dequantized feature array of the same shape as input.

uniform_quantization staticmethod

uniform_quantization(feat, min_v, max_v, bit_depth)

Apply uniform quantization to features.

Quantizes features to integer values in the range [0, 2^bit_depth - 1] using uniform quantization. Supports per-channel quantization when min_v and max_v are lists.

Parameters:

Name Type Description Default
feat ndarray

Input feature array of shape (N, C, H, W).

required
min_v float or list[float]

Minimum value(s) for quantization range.

required
max_v float or list[float]

Maximum value(s) for quantization range.

required
bit_depth int

Number of bits for quantization (determines quantization levels).

required

Returns:

Name Type Description
quant_feat ndarray

Quantized feature array of the same shape as input.
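
One common formulation of uniform quantization and its inverse is sketched below. The rounding rule, the clipping of out-of-range inputs, and the function names are assumptions; the library's exact formula is not given on this page and may differ.

```python
import numpy as np


def _per_channel(value):
    """Reshape a list of C values to (1, C, 1, 1); leave scalars unchanged."""
    value = np.asarray(value, dtype=np.float64)
    return value.reshape(1, -1, 1, 1) if value.ndim == 1 else value


def uniform_quantize(feat, min_v, max_v, bit_depth):
    lo, hi = _per_channel(min_v), _per_channel(max_v)
    levels = 2 ** bit_depth - 1
    # Map [min_v, max_v] linearly onto [0, levels], then round to integers.
    scaled = (np.clip(feat, lo, hi) - lo) / (hi - lo) * levels
    return np.round(scaled).astype(np.int64)


def uniform_dequantize(quant, min_v, max_v, bit_depth):
    lo, hi = _per_channel(min_v), _per_channel(max_v)
    levels = 2 ** bit_depth - 1
    return quant.astype(np.float64) / levels * (hi - lo) + lo


feat = np.linspace(-25.0, 25.0, 32).reshape(1, 2, 4, 4)
quant = uniform_quantize(feat, -20, 20, 8)
restored = uniform_dequantize(quant, -20, 20, 8)
```

Under this formulation the round-trip error for in-range values is at most half a quantization step, that is (max_v - min_v) / (2^bit_depth - 1) / 2.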

unpacking staticmethod

unpacking(feat, shape, model_type)

Unpack features according to model type.

Reshapes packed features back to (N, C, H, W) format. This is the inverse operation of packing.

Parameters:

Name Type Description Default
feat ndarray

Packed feature array.

required
shape tuple[int, int, int, int]

Target shape (N, C, H, W).

required
model_type str

Model type. Supported types:

- "llama3": Expands to (1, 1, H, W)
- "dinov2": Reshapes from (N*H, C*W) to (N, C, H, W)
- "sd3": Reshapes from (C/4*H, C/4*W) to (N, C, H, W)

required

Returns:

Name Type Description
feat ndarray

Unpacked feature array of shape (N, C, H, W).
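
To illustrate that unpacking inverts packing, the sketch below tiles a (1, C, H, W) array into a square grid of channel tiles and restores it. It assumes N = 1 and C = 16, so that a grid of C/4 by C/4 tiles holds exactly C channels, and it assumes row-major tile order; neither the function names nor this layout are confirmed by the page.

```python
import numpy as np


def pack_tiles(feat):
    """Tile (1, C, H, W) into (g*H, g*W), where g = C // 4 and g*g == C."""
    n, c, h, w = feat.shape
    g = c // 4
    if n != 1 or g * g != c:
        raise ValueError("this sketch assumes N == 1 and C == 16")
    # (C, H, W) -> (g, g, H, W) -> (g, H, g, W) -> (g*H, g*W)
    return feat[0].reshape(g, g, h, w).transpose(0, 2, 1, 3).reshape(g * h, g * w)


def unpack_tiles(packed, shape):
    """Inverse of pack_tiles for a target shape (1, C, H, W)."""
    n, c, h, w = shape
    g = c // 4
    # (g*H, g*W) -> (g, H, g, W) -> (g, g, H, W) -> (1, C, H, W)
    return packed.reshape(g, h, g, w).transpose(0, 2, 1, 3).reshape(n, c, h, w)


feat = np.arange(16 * 2 * 3).reshape(1, 16, 2, 3)
packed = pack_tiles(feat)
restored = unpack_tiles(packed, feat.shape)
```

The target shape has to be passed in explicitly because the packed 2D array alone does not say how many channels or rows each tile had.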

ImageFolder

ImageFolder(root, transform=None, split='train')

Load an image folder database.

Training and testing image samples are stored in separate directories:

rootdir/
    train/
        img000.png
        img001.png
    test/
        img000.png
        img001.png

Parameters:

Name Type Description Default
root str

Root directory of the dataset.

required
transform callable

A function or transform that takes in a PIL image and returns a transformed version. Defaults to None.

None
split str

Split mode ('train' or 'val'). Defaults to "train".

'train'

__getitem__

__getitem__(index)

Get an image sample from the dataset.

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Name Type Description
img Image or Tensor

The image. If transform is provided, returns the transformed version (typically a torch.Tensor). Otherwise, returns a PIL Image in RGB format.

__len__

__len__()

Return the number of samples in the dataset.

Returns:

Name Type Description
length int

Number of image files in the dataset.

PngSequenceVideoReader

PngSequenceVideoReader(src_path, width, height, start_num=1)

Video reader for PNG image sequences.

This reader loads frames from a sequence of PNG images with naming conventions like "im1.png", "im2.png" or "im00001.png", "im00002.png". The padding width is automatically detected from the first image found.

Parameters:

Name Type Description Default
src_path str

Path to the directory containing PNG images.

required
width int

Expected width of each frame in pixels.

required
height int

Expected height of each frame in pixels.

required
start_num int

Starting frame number. Defaults to 1.

1

Raises:

Type Description
ValueError

If the image naming convention cannot be determined.
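
The padding detection described above could work roughly as in the sketch below, which inspects the first matching file name. The regular expression, the function names, and the rule that a leading zero signals fixed-width numbering are assumptions for illustration; the reader's actual detection logic is not shown on this page.

```python
import re


def detect_padding(filenames):
    """Return the zero-padding width used by 'im<number>.png' names (0 if none)."""
    for name in sorted(filenames):
        match = re.fullmatch(r"im(\d+)\.png", name)
        if match:
            digits = match.group(1)
            # A leading zero implies fixed-width numbering such as im00001.png.
            return len(digits) if digits.startswith("0") else 0
    raise ValueError("cannot determine image naming convention")


def frame_name(index, padding):
    """Build the file name for a frame number under the detected padding."""
    return f"im{str(index).zfill(padding)}.png"


unpadded = detect_padding(["im1.png", "im2.png"])
padded = detect_padding(["im00001.png", "im00002.png"])
```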

close

close()

Close the reader and reset frame index.

read_one_frame

read_one_frame()

Read the next frame from the PNG sequence.

Returns:

Name Type Description
frame Frame or None

Frame object containing RGB data of shape (3, H, W) as uint8 numpy array. Returns None if end of file is reached or frame cannot be read.

restart

restart()

Reset the reader to the starting frame.

PngSequenceVideoWriter

PngSequenceVideoWriter(dst_path, width, height)

Video writer for PNG image sequences.

This writer saves frames as a sequence of PNG images with naming convention "im00001.png", "im00002.png", etc. The output directory is created automatically if it doesn't exist.

Parameters:

Name Type Description Default
dst_path str

Path to the output directory for PNG images.

required
width int

Width of each frame in pixels.

required
height int

Height of each frame in pixels.

required

close

close()

Close the writer and reset frame index.

write_one_frame

write_one_frame(rgb)

Write a single frame as a PNG image.

Parameters:

Name Type Description Default
rgb ndarray

RGB image data of shape (3, H, W) as uint8 numpy array.

required

write_one_frame_from_tensor

write_one_frame_from_tensor(x, format='yuv444')

Write a single frame from a PyTorch tensor.

Converts the tensor from the specified format to RGB and writes it as a PNG image.

Parameters:

Name Type Description Default
x Tensor

Input tensor in the specified format.

required
format str

Input format. Currently only "yuv444" is supported. Defaults to "yuv444".

'yuv444'

Raises:

Type Description
NotImplementedError

If the specified format is not supported.

SegmentationDataset

SegmentationDataset(root, transform=None, img_path='JPEGImages', seg_map_path='SegmentationClass', file_list=None, reduce_zero_label=False, **kwargs)

Dataset for image segmentation tasks.

This dataset loads images and their corresponding segmentation masks from separate directories. It supports loading images from a file list or by scanning the directory.

Parameters:

Name Type Description Default
root str

Root directory of the dataset.

required
transform callable

Data preprocessing transform function. Defaults to None.

None
img_path str

Name of the image subdirectory within root. Defaults to "JPEGImages".

'JPEGImages'
seg_map_path str

Name of the segmentation mask subdirectory within root. Defaults to "SegmentationClass".

'SegmentationClass'
file_list str

Path to file list relative to root. Each line should contain an image name (without extension). If None, scans the directory for .jpg files. Defaults to None.

None
reduce_zero_label bool

Whether to reduce zero label. If True, subtracts 1 from all labels (2->1, 1->0, 0->255 for uint8). Defaults to False.

False
**kwargs dict

Additional keyword arguments (unused).

{}
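
The label mapping given for reduce_zero_label (2->1, 1->0, 0->255 for uint8) matches plain unsigned 8-bit wraparound, as the sketch below shows. The function name is hypothetical, and how the library treats a label that is already 255 is not stated here; under pure wraparound it would become 254.

```python
import numpy as np


def reduce_zero_label_sketch(seg):
    """Subtract 1 from every label using uint8 arithmetic (0 wraps to 255)."""
    seg = np.asarray(seg, dtype=np.uint8)
    # Array arithmetic on uint8 wraps modulo 256, so 0 - 1 becomes 255.
    reduced = seg - np.uint8(1)
    # The dataset documents seg_label as an int64 array.
    return reduced.astype(np.int64)


mask = np.array([[0, 1], [2, 3]])
reduced = reduce_zero_label_sketch(mask)
```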

__getitem__

__getitem__(index)

Get an image sample and its segmentation mask from the dataset.

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Name Type Description
tuple (Image or Tensor, dict)

A tuple containing:

- img (PIL.Image.Image or torch.Tensor): The image. If transform is provided, returns the transformed version (typically a torch.Tensor). Otherwise, returns a PIL Image in RGB format.
- img_meta (dict): Metadata dictionary with keys:
    - "img_path" (str): Full path to the image file.
    - "img_name" (str): Image filename without extension.
    - "ori_size" (tuple): Original image size (width, height) or (height, width) for tensors.
    - "seg_label_path" (str): Full path to the segmentation mask file.
    - "seg_label" (numpy.ndarray): Segmentation mask as int64 array.

__len__

__len__()

Return the number of samples in the dataset.

Returns:

Name Type Description
length int

Number of image files in the dataset.

VideoFolder

VideoFolder(root, transform=None, split='train', src_type='yuv420', sequences=[])

Load a video folder database. Training and testing video samples are stored in separate directories:

rootdir/
    train/
        video000.yuv
        video001.yuv
    test/
        video002.yuv
        video003.yuv

Parameters:

Name Type Description Default
root str

Root directory of the dataset.

required
transform callable

A function or transform that takes in a PIL image and returns a transformed version. Defaults to None.

None
split str

Split mode ('train' or 'val'). Defaults to "train".

'train'

__getitem__

__getitem__(index)

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Name Type Description
reader VideoReader

Video reader object.

vid_meta dict

Video metadata.

YUV420VideoReader

YUV420VideoReader(src_path, width, height, skip_frame=0)

Video reader for YUV420 format video files.

This reader loads frames from raw YUV420 format video files. YUV420 is a chroma-subsampled format where the Y (luminance) channel is full resolution and the U and V (chrominance) channels are subsampled by a factor of 2 in both dimensions.

Parameters:

Name Type Description Default
src_path str

Path to the YUV file (with or without .yuv extension).

required
width int

Width of each frame in pixels.

required
height int

Height of each frame in pixels.

required
skip_frame int

Number of frames to skip at the beginning. Defaults to 0.

0

close

close()

Close the file handle.

read_one_frame

read_one_frame()

Read the next frame from the YUV420 file.

Returns:

Name Type Description
frame Frame or None

Frame object containing Y, U, V channels and YUV444 converted data. Returns None if end of file is reached or frame cannot be read.

Frame attributes:

- y: Luminance channel of shape (H, W) as uint8 numpy array.
- u: U chrominance channel of shape (H//2, W//2) as uint8 numpy array.
- v: V chrominance channel of shape (H//2, W//2) as uint8 numpy array.
- yuv444: YUV444 format data of shape (3, H, W) as uint8 numpy array.
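
The frame attributes above can be derived from raw bytes as in the sketch below. It assumes a planar layout (all of Y, then U, then V), even frame dimensions, and nearest-neighbour chroma upsampling for the YUV444 view; the function name and these choices are illustrative, and the library may use a different upsampling filter.

```python
import io

import numpy as np


def read_yuv420_frame(stream, width, height):
    """Read one planar YUV420 frame; return (y, u, v, yuv444) or None at EOF."""
    y_size = width * height
    c_size = (width // 2) * (height // 2)
    raw = stream.read(y_size + 2 * c_size)
    if len(raw) < y_size + 2 * c_size:
        return None  # end of stream or truncated frame
    buf = np.frombuffer(raw, dtype=np.uint8)
    y = buf[:y_size].reshape(height, width)
    u = buf[y_size:y_size + c_size].reshape(height // 2, width // 2)
    v = buf[y_size + c_size:].reshape(height // 2, width // 2)

    def upsample(plane):
        # Repeat each chroma sample 2x vertically and horizontally.
        return plane.repeat(2, axis=0).repeat(2, axis=1)

    yuv444 = np.stack([y, upsample(u), upsample(v)])
    return y, u, v, yuv444


# Two synthetic 4x2 frames (12 bytes each) held in memory.
stream = io.BytesIO(bytes(range(24)))
first = read_yuv420_frame(stream, width=4, height=2)
second = read_yuv420_frame(stream, width=4, height=2)
third = read_yuv420_frame(stream, width=4, height=2)
```

Each frame occupies width * height * 3 / 2 bytes, which is why the reader needs the frame dimensions up front.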

restart

restart()

Reset the reader to the beginning of the file.

YUV420VideoWriter

YUV420VideoWriter(dst_path, width, height)

Video writer for YUV420 format video files.

This writer saves frames to a raw YUV420 format video file. YUV420 is a chroma-subsampled format where the Y (luminance) channel is full resolution and the U and V (chrominance) channels are subsampled by a factor of 2 in both dimensions.

Parameters:

Name Type Description Default
dst_path str

Path to the output YUV file (with or without .yuv extension). If a directory path is provided, it will create "out.yuv" in that directory.

required
width int

Width of each frame in pixels.

required
height int

Height of each frame in pixels.

required

close

close()

Close the file handle.

write_one_frame

write_one_frame(y, uv)

Write a single frame in YUV420 format.

Parameters:

Name Type Description Default
y ndarray

Luminance channel of shape (H, W) as uint8 numpy array.

required
uv ndarray

Chrominance channels of shape (2, H//2, W//2) as uint8 numpy array, where uv[0] is U and uv[1] is V.

required
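
A sketch of how the y and uv arguments might map to bytes on disk, assuming a planar layout in which the Y plane is followed by U and then V. The function name and the shape check are illustrative assumptions rather than the writer's actual code.

```python
import io

import numpy as np


def write_yuv420_frame(stream, y, uv):
    """Append one planar YUV420 frame (Y, then U, then V) to a binary stream."""
    height, width = y.shape
    if uv.shape != (2, height // 2, width // 2):
        raise ValueError(f"uv has shape {uv.shape}, expected (2, H//2, W//2)")
    # Write each plane as contiguous uint8 bytes in row-major order.
    for plane in (y, uv[0], uv[1]):
        stream.write(np.ascontiguousarray(plane, dtype=np.uint8).tobytes())


out = io.BytesIO()
y = np.arange(8, dtype=np.uint8).reshape(2, 4)
uv = np.array([[[100, 101]], [[200, 201]]], dtype=np.uint8)  # shape (2, 1, 2)
write_yuv420_frame(out, y, uv)
```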

write_one_frame_from_tensor

write_one_frame_from_tensor(x, format='yuv444')

Write a single frame from a PyTorch tensor.

Converts the tensor from the specified format to YUV420 and writes it to the file.

Parameters:

Name Type Description Default
x Tensor

Input tensor in the specified format.

required
format str

Input format. Currently only "yuv444" is supported. Defaults to "yuv444".

'yuv444'

Raises:

Type Description
NotImplementedError

If the specified format is not supported.

feature_dict_collate_fn

feature_dict_collate_fn(batch)

Custom collate function for feature dictionaries.

Merges multiple feature dictionaries into a single batch. For each key:

- If the value is a torch.Tensor, stacks all values along dimension 0.
- If the value is a torch.Size, creates a new Size with batch dimension.
- Otherwise, takes the first value.

Parameters:

Name Type Description Default
batch list[dict]

List of feature dictionaries to collate.

required

Returns:

Name Type Description
collated dict

Collated dictionary with the same keys as input, where tensor values are stacked and other values are taken from the first sample.
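
The three collation rules can be sketched without PyTorch by letting NumPy arrays stand in for torch.Tensor and plain tuples stand in for torch.Size. The function name is hypothetical, and prepending the batch size to a shape is only one reading of "creates a new Size with batch dimension".

```python
import numpy as np


def collate_feature_dicts(batch):
    """Merge a list of dicts key by key, following the three rules above."""
    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        first = values[0]
        if isinstance(first, np.ndarray):  # stand-in for torch.Tensor
            collated[key] = np.stack(values, axis=0)
        elif isinstance(first, tuple):  # stand-in for torch.Size
            collated[key] = (len(batch), *first)
        else:
            collated[key] = first  # non-tensor values: keep the first sample's
    return collated


batch = [
    {"feat": np.zeros((3, 2)), "shape": (3, 2), "name": "a"},
    {"feat": np.ones((3, 2)), "shape": (3, 2), "name": "b"},
]
collated = collate_feature_dicts(batch)
```

In a PyTorch pipeline, a function of this kind would typically be passed to `torch.utils.data.DataLoader` through its `collate_fn` argument.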