CollaborativeCoding.dataloaders.mnist_0_3

Classes

MNISTDataset0_3

A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.

Module Contents

class CollaborativeCoding.dataloaders.mnist_0_3.MNISTDataset0_3(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels: int = 1)

Bases: torch.utils.data.Dataset

A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.

Args

data_path : Path

The root directory where the MNIST folder with data is stored.

sample_ids : list

A list of indices specifying which samples to load.

train : bool, optional

If True, load the training data; otherwise, load the test data. Default is False.

transform : callable, optional

A function/transform to apply to the images. Default is None.

Attributes

mnist_path : Path

The directory where the MNIST dataset is located within the root directory.

idx : list

A list of indices specifying which samples to load.

train : bool

Indicates whether to load training data or test data.

transform : callable

A function/transform to apply to the images.

num_classes : int

The number of classes in the dataset (4, one for each digit 0 to 3).

images_path : Path

The path to the image file (train or test) based on the train flag.

labels_path : Path

The path to the label file (train or test) based on the train flag.

length : int

The number of samples in the dataset.
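The `mnist_path`, `images_path` and `labels_path` attributes describe how the `train` flag selects between the training and test files. A minimal sketch of that selection follows. It assumes the MNIST folder is named `MNIST` and holds the standard uncompressed IDX file names (`train-images-idx3-ubyte`, `t10k-images-idx3-ubyte`, and the matching label files). Both the folder name and the file names are assumptions, and `resolve_mnist_paths` is a hypothetical helper.

```python
from pathlib import Path
from typing import Tuple


def resolve_mnist_paths(data_path: Path, train: bool = False) -> Tuple[Path, Path]:
    """Return (images_path, labels_path) for the split chosen by ``train``.

    Assumed layout: <data_path>/MNIST/<standard MNIST IDX file names>.
    """
    mnist_path = Path(data_path) / "MNIST"  # assumed folder name
    prefix = "train" if train else "t10k"   # standard MNIST split prefixes
    images_path = mnist_path / f"{prefix}-images-idx3-ubyte"
    labels_path = mnist_path / f"{prefix}-labels-idx1-ubyte"
    return images_path, labels_path


imgs, lbls = resolve_mnist_paths(Path("data"), train=True)
print(imgs.name, lbls.name)  # train-images-idx3-ubyte train-labels-idx1-ubyte
```

Only path objects are built here. Nothing touches the filesystem until the files are actually opened.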

Methods

__len__()

Returns the number of samples in the dataset.

__getitem__(index)

Retrieves the image and label at the specified index.
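`__len__` and `__getitem__` are the two methods a map-style `torch.utils.data.Dataset` must provide. The sketch below illustrates that pattern using only the standard library. It assumes the files use the uncompressed MNIST IDX format: a 16-byte image header holding the magic number, count, rows and cols as big-endian uint32, an 8-byte label header, and one unsigned byte per pixel or label. `IdxSubsetSketch` and `write_toy_idx` are hypothetical. The class does not subclass `torch.utils.data.Dataset` and is not the package's implementation. It only shows how a positional `index` can be mapped through `sample_ids` to a record offset in the files.

```python
import struct
import tempfile
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple


class IdxSubsetSketch:
    """Hypothetical map-style dataset over a subset of MNIST IDX files."""

    IMAGE_HEADER = 16  # magic, count, rows, cols: four big-endian uint32
    LABEL_HEADER = 8   # magic, count: two big-endian uint32

    def __init__(self, images_path: Path, labels_path: Path,
                 sample_ids: Sequence[int],
                 transform: Optional[Callable] = None):
        self.images_path = Path(images_path)
        self.labels_path = Path(labels_path)
        self.idx = list(sample_ids)
        self.transform = transform
        self.length = len(self.idx)
        # Read the image dimensions from the header once.
        with open(self.images_path, "rb") as f:
            _, _, self.rows, self.cols = struct.unpack(">IIII", f.read(self.IMAGE_HEADER))

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> Tuple[List[List[int]], int]:
        sample_id = self.idx[index]  # position in the subset -> position in the file
        size = self.rows * self.cols
        with open(self.images_path, "rb") as f:
            f.seek(self.IMAGE_HEADER + sample_id * size)
            pixels = f.read(size)
        with open(self.labels_path, "rb") as f:
            f.seek(self.LABEL_HEADER + sample_id)
            label = f.read(1)[0]
        # Split the flat pixel bytes into rows of ints.
        image = [list(pixels[r * self.cols:(r + 1) * self.cols]) for r in range(self.rows)]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def write_toy_idx(folder: Path, images: List[bytes], labels: List[int],
                  rows: int, cols: int) -> Tuple[Path, Path]:
    """Write tiny IDX image/label files for testing (hypothetical helper)."""
    img_path, lbl_path = folder / "images-idx3-ubyte", folder / "labels-idx1-ubyte"
    with open(img_path, "wb") as f:
        f.write(struct.pack(">IIII", 2051, len(images), rows, cols))
        for img in images:
            f.write(img)
    with open(lbl_path, "wb") as f:
        f.write(struct.pack(">II", 2049, len(labels)))
        f.write(bytes(labels))
    return img_path, lbl_path


with tempfile.TemporaryDirectory() as tmp:
    imgs, lbls = write_toy_idx(
        Path(tmp),
        [bytes([0, 1, 2, 3]), bytes([4, 5, 6, 7]), bytes([8, 9, 10, 11])],
        [7, 2, 0], rows=2, cols=2,
    )
    ds = IdxSubsetSketch(imgs, lbls, sample_ids=[2, 1])
    print(len(ds))  # 2
    print(ds[0])    # ([[8, 9], [10, 11]], 0)
```

Reading one record per call keeps memory use flat, at the cost of a file open per sample. Loading both files into memory once in `__init__` is the usual alternative when the subset is small.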

mnist_path
idx
train = False
transform = None
num_classes = 4
images_path
labels_path
length
__len__()
__getitem__(index)