CollaborativeCoding.dataloaders.uspsh5_7_9

Classes

USPSH5_Digit_7_9_Dataset

This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file.

Module Contents

class CollaborativeCoding.dataloaders.uspsh5_7_9.USPSH5_Digit_7_9_Dataset(data_path, sample_ids, train=False, transform=None, nr_channels=1)

Bases: torch.utils.data.Dataset

This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file. It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.

Parameters

data_path : str or Path

Path to the directory containing the USPS HDF5 file (usps.h5). The file is expected to store its data in “train” and “test” groups.

sample_ids : list of int

A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.

train : bool, optional, default=False

If True, the dataset is loaded in training mode (using the “train” group). If False, the dataset is loaded in test mode (using the “test” group).

transform : callable, optional, default=None

A transformation function to apply to each image. If None, no transformation is applied. Typically used for data augmentation or normalization.

nr_channels : int, optional, default=1

The number of channels in the image. USPS images are grayscale, so this should normally be 1; the parameter exists so that other channel counts can be supported later.
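A minimal construction sketch using only the parameters documented above. The helper name `build_usps_7_9` and its `n_samples` argument are hypothetical and not part of the module, and the sketch assumes that usps.h5 sits directly inside `data_path`. The import is deferred into the function so that the snippet can be read and loaded without the package installed.

```python
from pathlib import Path


def build_usps_7_9(data_path, n_samples=100, train=True):
    """Hypothetical helper: build the dataset from the first n_samples indices."""
    # Imported lazily; calling this requires the CollaborativeCoding package.
    from CollaborativeCoding.dataloaders.uspsh5_7_9 import USPSH5_Digit_7_9_Dataset

    return USPSH5_Digit_7_9_Dataset(
        data_path=Path(data_path),          # directory assumed to hold usps.h5
        sample_ids=list(range(n_samples)),  # candidate indices; only digits 7, 8, 9 are kept
        train=train,                        # True -> "train" group, False -> "test" group
        transform=None,                     # no transformation applied
        nr_channels=1,                      # grayscale
    )
```

Under these assumptions, `build_usps_7_9("path/to/data")` would be expected to return a dataset whose length is the number of 7, 8, and 9 samples among the selected indices.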

Attributes

images : numpy.ndarray

Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.

labels : numpy.ndarray

Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.

transform : callable, optional

A transformation function to apply to the images. This is passed as an argument during initialization.

label_shift : function

A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).

label_restore : function

A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).

num_classes : int

The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
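The filtering and label mapping described by these attributes can be illustrated with a small NumPy sketch. The arrays below are made-up stand-ins for the HDF5 contents, and the order of operations (subset selection combined with digit filtering, then shifting) is an assumption rather than the class's confirmed implementation. The functions reuse the attribute names `label_shift` and `label_restore` purely for illustration.

```python
import numpy as np

# Hypothetical stand-ins for what would be read from the HDF5 file.
all_images = np.arange(6 * 4).reshape(6, 4)   # six fake images, 4 pixels each
all_labels = np.array([7, 1, 9, 8, 3, 7])     # original digit labels
sample_ids = [0, 1, 2, 3]                     # indices requested by the caller

# Keep a sample only if it is in sample_ids AND its label is 7, 8, or 9.
in_subset = np.isin(np.arange(len(all_labels)), sample_ids)
is_target = np.isin(all_labels, [7, 8, 9])
keep = in_subset & is_target

images = all_images[keep]
labels = all_labels[keep]


def label_shift(label):
    """Map original labels (7, 8, 9) to class indices (0, 1, 2)."""
    return label - 7


def label_restore(index):
    """Map class indices (0, 1, 2) back to original labels (7, 8, 9)."""
    return index + 7


shifted = label_shift(labels)
num_classes = len(np.unique(shifted))
```

Here the sample at index 5 is a 7 but is dropped because it is not in `sample_ids`, while indices 1 and 4 are dropped because their digits fall outside 7 to 9.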

filename = 'usps.h5'
filepath
transform = None
mode = 'test'
h5_path
sample_ids
nr_channels = 1
num_classes = 3
images
labels
label_shift
label_restore
__len__()

Returns the total number of samples in the dataset.

This method is required for PyTorch’s Dataset class, as it allows PyTorch to determine the size of the dataset.

Returns

int

The number of images in the dataset (after filtering for digits 7, 8, and 9).

__getitem__(id)

Returns a sample from the dataset given an index.

This method is required for PyTorch’s Dataset class, as it allows indexing into the dataset to retrieve specific samples.

Parameters

id : int

The index of the sample to retrieve from the dataset.

Returns

tuple

A tuple containing:

- image (PIL Image): The image at the specified index.
- label (int): The label corresponding to the image, shifted to be in the range [0, 2] for classification.
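The indexing protocol that `__len__` and `__getitem__` implement can be sketched without PyTorch. `ToyDigitDataset` below is a hypothetical stand-in, not the real class: it uses plain lists instead of HDF5-backed arrays and strings instead of images, and it assumes the transform is applied before the label is shifted by 7.

```python
class ToyDigitDataset:
    """Hypothetical stand-in that mimics the indexing protocol only."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples available after filtering.
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        # Shift the original label (7, 8, 9) into the range [0, 2].
        label = self.labels[idx] - 7
        return image, label


toy = ToyDigitDataset(
    images=["img_a", "img_b", "img_c"],  # placeholder "images"
    labels=[7, 9, 8],
    transform=str.upper,                 # placeholder transform
)
image, label = toy[1]
```

Any object that provides these two methods can be indexed and measured with `len()`, which is the contract a map-style PyTorch dataset is expected to satisfy.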