CollaborativeCoding.dataloaders.mnist_4_9

Classes

MNISTDataset4_9

MNIST dataset restricted to the digits 4-9.

Module Contents

class CollaborativeCoding.dataloaders.mnist_4_9.MNISTDataset4_9(data_path: pathlib.Path, sample_ids: numpy.ndarray, train: bool = False, transform=None, nr_channels: int = 1)

Bases: torch.utils.data.Dataset

MNIST dataset restricted to the digits 4-9.

Parameters

data_path : Path

Root directory where the MNIST dataset is stored.

sample_ids : np.ndarray

Array of indices specifying which samples to load. This determines the samples used by the dataloader.

train : bool, optional

Whether the dataset is used for training or not, by default False

transform : callable, optional

Transform to apply to the images, by default None

nr_channels : int, optional

Number of channels in the images, by default 1
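
The sample_ids parameter above is described only as an array of indices. One plausible way to build it, sketched here under the assumption that the caller selects the positions of the digits 4-9 from the full label list, is shown below. The labels list is made-up illustration data, not real MNIST labels, and plain Python lists are used so the sketch has no dependencies.

```python
# Hypothetical labels for ten samples; real labels would be read
# from the MNIST label file under data_path.
labels = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4]

# Keep the positions whose label falls in the 4-9 range.
# These positions are what sample_ids is assumed to contain.
sample_ids = [i for i, label in enumerate(labels) if 4 <= label <= 9]

# Labels that the selected indices point to, for a quick sanity check.
selected_labels = [labels[i] for i in sample_ids]
```

Since the signature expects a numpy.ndarray, such a list would presumably be converted with numpy.asarray before being passed to the constructor.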

data_path
mnist_path
samples
train = False
transform = None
num_classes = 6
images_path
labels_path
label_shift
label_restore
__len__()
__getitem__(idx)
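
The attributes num_classes = 6, label_shift and label_restore suggest that the original digits 4-9 are mapped to class indices 0-5 and back. The page does not document how these attributes are implemented, so the following is only a minimal sketch under that assumption; the constant OFFSET and the function bodies are hypothetical and not taken from the library.

```python
NUM_CLASSES = 6  # matches num_classes listed above
OFFSET = 4       # assumption: the smallest digit in the 4-9 subset


def label_shift(label: int) -> int:
    """Map an original digit (4-9) to a class index (0-5)."""
    return label - OFFSET


def label_restore(index: int) -> int:
    """Map a class index (0-5) back to the original digit (4-9)."""
    return index + OFFSET


# Round trip over every digit in the subset.
shifted = [label_shift(digit) for digit in range(4, 10)]
restored = [label_restore(index) for index in shifted]
```

A contiguous zero-based range of class indices is what classification losses such as cross-entropy typically expect, which is a likely reason for keeping a shift and its inverse side by side.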