CollaborativeCoding.dataloaders

Submodules

Classes

Downloader

Class used to verify availability and potentially download implemented datasets.

MNISTDataset0_3

A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.

MNISTDataset4_9

MNIST dataset containing digits 4 to 9.

SVHNDataset

Dataset class for the SVHN (Street View House Numbers) dataset.

USPSDataset0_6

Dataset class for USPS dataset with labels 0-6.

USPSH5_Digit_7_9_Dataset

This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file.

Package Contents

class CollaborativeCoding.dataloaders.Downloader

Class used to verify availability and potentially download implemented datasets.

Methods

mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]

Checks whether the MNIST dataset is available. If it is not present, downloads it into an MNIST folder in data_dir.

svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]

Download the SVHN dataset and save it as an HDF5 file to data_dir.

usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]

Download the USPS dataset and save it as an HDF5 file to data_dir.

Raises

NotImplementedError

If the download method is not implemented for the dataset.

Examples

>>> from pathlib import Path
>>> from CollaborativeCoding.dataloaders import Downloader
>>> data_dir = Path('tmp')
>>> data_dir.mkdir(exist_ok=True)
>>> train, test = Downloader().usps(data_dir)

mnist(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]

Checks whether the MNIST dataset is available. If it is not present, downloads it into an MNIST folder in data_dir.
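As a rough illustration of the availability check, the sketch below lists which expected MNIST files are missing under data_dir/MNIST. The four file names are an assumption (the standard uncompressed IDX names); the files Downloader.mnist actually looks for are not documented here, and missing_mnist_files is a hypothetical helper.

```python
from pathlib import Path

# Assumption: the standard uncompressed MNIST IDX file names.
# Downloader.mnist may check for different files.
MNIST_FILES = (
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
)


def missing_mnist_files(data_dir: Path) -> list[str]:
    """Hypothetical helper: expected MNIST files absent from data_dir / 'MNIST'."""
    mnist_dir = Path(data_dir) / "MNIST"
    # Path.is_file() returns False for paths that do not exist.
    return [name for name in MNIST_FILES if not (mnist_dir / name).is_file()]
```

An empty list means nothing needs to be fetched; any names returned would be the files to download.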

svhn(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]

usps(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]

Download the USPS dataset and save it as an HDF5 file to data_dir/usps.h5.

__extract_usps(src: pathlib.Path, dest: pathlib.Path, mode: str)
static __reporthook(blocknum, blocksize, totalsize)

Reports download progress; intended to be passed as the reporthook callback of urllib.request.urlretrieve.
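urllib.request.urlretrieve calls its reporthook once when the connection is established and once after each block is read, passing the block count, the block size in bytes, and the total file size (-1 when the size is unknown). The standalone function below is a sketch of such a callback under those semantics; it is not the class's private __reporthook, and returning the percentage is an addition for testing (urlretrieve ignores the return value).

```python
import sys


def reporthook(blocknum: int, blocksize: int, totalsize: int) -> float:
    """Print download progress; usable as urlretrieve's reporthook.

    Returns the completed percentage, or -1.0 if the total size is unknown.
    """
    downloaded = blocknum * blocksize
    if totalsize <= 0:
        # No Content-Length: only the byte count can be reported.
        sys.stdout.write(f"\rDownloaded {downloaded} bytes")
        sys.stdout.flush()
        return -1.0
    # The last block may overshoot the file size, so clamp at 100%.
    percent = min(100.0, downloaded * 100.0 / totalsize)
    sys.stdout.write(f"\rDownloading: {percent:5.1f}%")
    sys.stdout.flush()
    return percent
```

It would be used as `urllib.request.urlretrieve(url, filename, reporthook=reporthook)`.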

static __check_integrity(filepath, checksum)

Check the integrity of the USPS dataset file.

Args

filepath : pathlib.Path

Path to the USPS dataset file.

checksum : str

MD5 checksum of the dataset file.

Returns

bool

True if the checksum of the file matches the expected checksum, False otherwise.
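An MD5 integrity check of this kind can be sketched with hashlib. The function below is a standalone illustration, not the class's private __check_integrity; reading in chunks and treating a missing file as a failed check are design choices of this sketch.

```python
import hashlib
from pathlib import Path


def check_integrity(filepath: Path, checksum: str, chunk_size: int = 1 << 20) -> bool:
    """Return True if the MD5 digest of filepath equals checksum (hex string)."""
    filepath = Path(filepath)
    if not filepath.is_file():
        return False  # a missing file cannot match
    md5 = hashlib.md5()
    with filepath.open("rb") as f:
        # Hash in 1 MiB chunks so large archives need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest() == checksum.lower()
```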

class CollaborativeCoding.dataloaders.MNISTDataset0_3(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels: int = 1)

Bases: torch.utils.data.Dataset

A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.

Args

data_path : Path

The root directory where the MNIST folder with data is stored.

sample_ids : list

A list of indices specifying which samples to load.

train : bool, optional

If True, load training data, otherwise load test data. Default is False.

transform : callable, optional

A function/transform to apply to the images. Default is None.

Attributes

mnist_path : Path

The directory where the MNIST dataset is located within the root directory.

idx : list

A list of indices specifying which samples to load.

train : bool

Indicates whether to load training data or test data.

transform : callable

A function/transform to apply to the images.

num_classes : int

The number of classes in the dataset (4, for digits 0 to 3).

images_path : Path

The path to the image file (train or test) based on the train flag.

labels_path : Path

The path to the label file (train or test) based on the train flag.

length : int

The number of samples in the dataset.

Methods

__len__()

Returns the number of samples in the dataset.

__getitem__(index)

Retrieves the image and label at the specified index.
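MNIST is commonly distributed as IDX files: a big-endian header followed by raw unsigned bytes. The sketch below assumes that images_path and labels_path point to uncompressed IDX files read into memory and that sample_ids are record numbers in those files; neither detail is stated in this reference. All names are hypothetical, and a real implementation would subclass torch.utils.data.Dataset and typically return tensors.

```python
import struct


def read_idx_image(data: bytes, record: int) -> list[list[int]]:
    """Return image `record` from the raw bytes of an IDX3 (MNIST image) file."""
    # Header: magic 2051, image count, rows, cols as big-endian uint32.
    magic, count, rows, cols = struct.unpack_from(">IIII", data, 0)
    if magic != 2051:
        raise ValueError(f"not an IDX3 image file (magic {magic})")
    if not 0 <= record < count:
        raise IndexError(record)
    start = 16 + record * rows * cols
    return [list(data[start + r * cols:start + (r + 1) * cols]) for r in range(rows)]


def read_idx_label(data: bytes, record: int) -> int:
    """Return label `record` from the raw bytes of an IDX1 (MNIST label) file."""
    # Header: magic 2049, label count as big-endian uint32; one byte per label.
    magic, count = struct.unpack_from(">II", data, 0)
    if magic != 2049:
        raise ValueError(f"not an IDX1 label file (magic {magic})")
    if not 0 <= record < count:
        raise IndexError(record)
    return data[8 + record]


def ids_for_digits(label_bytes: bytes, digits=range(4)) -> list[int]:
    """Record numbers whose label is in `digits` (0 to 3 by default)."""
    _magic, count = struct.unpack_from(">II", label_bytes, 0)
    wanted = set(digits)
    return [i for i in range(count) if label_bytes[8 + i] in wanted]


class MNISTSubsetSketch:
    """Hypothetical map-style dataset over the records listed in sample_ids."""

    def __init__(self, image_bytes: bytes, label_bytes: bytes, sample_ids: list[int]):
        self.image_bytes = image_bytes
        self.label_bytes = label_bytes
        self.idx = list(sample_ids)

    def __len__(self) -> int:
        return len(self.idx)

    def __getitem__(self, index: int):
        record = self.idx[index]  # dataset position -> record number in the files
        image = read_idx_image(self.image_bytes, record)
        return image, read_idx_label(self.label_bytes, record)
```

For example, `ids_for_digits(labels_path.read_bytes())` selects the records labelled 0 to 3, and passing those ids to MNISTSubsetSketch exposes only those samples.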

mnist_path
idx
train = False
transform = None
num_classes = 4
images_path
labels_path
length
__len__()
__getitem__(index)
class CollaborativeCoding.dataloaders.MNISTDataset4_9(data_path: pathlib.Path, sample_ids: numpy.ndarray, train: bool = False, transform=None, nr_channels: int = 1)

Bases: torch.utils.data.Dataset

MNIST dataset containing digits 4 to 9.

Parameters

data_path : Path

Root directory where the MNIST dataset is stored.

sample_ids : np.ndarray

Array of indices specifying which samples to load. This determines the samples used by the dataloader.

train : bool, optional

If True, load the training split; otherwise load the test split. By default False.

transform : callable, optional

Transform to apply to the images, by default None.

nr_channels : int, optional

Number of channels in the images, by default 1.
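How nr_channels is applied is not documented here. One common approach, assumed in the sketch below, is to replicate the single grayscale channel so that a model expecting multi-channel input can consume MNIST; to_channels is a hypothetical helper.

```python
import numpy as np


def to_channels(image: np.ndarray, nr_channels: int = 1) -> np.ndarray:
    """Return a (nr_channels, H, W) array built from a single-channel (H, W) image.

    Assumption: extra channels are exact copies of the grayscale channel.
    """
    if image.ndim != 2:
        raise ValueError("expected a 2-D (H, W) grayscale image")
    if nr_channels < 1:
        raise ValueError("nr_channels must be >= 1")
    # Add a leading channel axis, then repeat it nr_channels times.
    return np.repeat(image[np.newaxis, :, :], nr_channels, axis=0)
```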

data_path
mnist_path
samples
train = False
transform = None
num_classes = 6
images_path
labels_path
label_shift
label_restore
__len__()
__getitem__(idx)
class CollaborativeCoding.dataloaders.SVHNDataset(data_path: pathlib.Path, sample_ids: list, train: bool, transform=None, nr_channels=3)

Bases: torch.utils.data.Dataset

Dataset class for the SVHN (Street View House Numbers) dataset.

As a map-style torch.utils.data.Dataset subclass, it overrides __getitem__() to fetch the sample for a given index and __len__() to report the dataset size, which many Sampler implementations and the default options of DataLoader rely on. The base class also lets subclasses optionally implement __getitems__(), which accepts a list of sample indices for a batch and returns a list of samples, to speed up batched loading.

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
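The map-style contract above can be sketched without PyTorch. In the hypothetical class below, position i of the dataset maps to record indexes[i] of the underlying storage, so DataLoader's default integer sampler (0 to len - 1) works even when the selected record ids are sparse. The attribute names indexes and transforms mirror those listed for SVHNDataset, but in-memory sequences stand in for the HDF5 file; a real implementation would subclass torch.utils.data.Dataset.

```python
from typing import Callable, Optional, Sequence


class IndexedSubsetSketch:
    """Hypothetical map-style dataset exposing only the records in `indexes`."""

    def __init__(self, images: Sequence, labels: Sequence[int],
                 indexes: Sequence[int], transforms: Optional[Callable] = None):
        self.images = images
        self.labels = labels
        self.indexes = list(indexes)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.indexes)

    def __getitem__(self, index: int):
        record = self.indexes[index]  # dataset position -> stored record
        image = self.images[record]
        if self.transforms is not None:
            image = self.transforms(image)
        return image, self.labels[record]

    def __getitems__(self, indices: list[int]) -> list:
        # Optional batched fetch; this sketch simply loops over __getitem__.
        return [self[i] for i in indices]
```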

data_path
indexes
split = 'train'
nr_channels = 3
transforms = None
num_classes
_create_h5py(path: str)

Downloads the SVHN dataset to the specified directory.

Args

path : str

The directory where the dataset will be downloaded.

__len__()

Returns the number of samples in the dataset.

Returns

int

The number of samples.

__getitem__(index)

Retrieves the image and label at the specified index.

Args

index : int

The index of the sample to retrieve.

Returns

tuple

A tuple containing the image and its corresponding label.

class CollaborativeCoding.dataloaders.USPSDataset0_6(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels=1)

Bases: torch.utils.data.Dataset

Dataset class for USPS dataset with labels 0-6.

Args

data_path : pathlib.Path

Path to the data directory.

sample_ids : list

A list of indices specifying which samples to load.

train : bool, optional

If True, use the training split; otherwise use the test split. By default False.

transform : callable, optional

A function/transform that takes in a sample and returns a transformed version.

Attributes

filepath : pathlib.Path

Path to the USPS dataset file.

mode : str

Mode of the dataset, either 'train' or 'test'.

transform : callable

A function/transform that takes in a sample and returns a transformed version.

idx : numpy.ndarray

Indices of samples with labels 0-6.

num_classes : int

Number of classes in the dataset (7, for digits 0-6).

Methods

_index()

Get indices of samples with labels 0-6.

_load_data(idx)

Load data and target label from the dataset.

__len__()

Get the number of samples in the dataset.

__getitem__(idx)

Get a sample from the dataset.
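The documented _index() step and a one-hot target can be sketched in plain Python. Both helpers are hypothetical; the sketch assumes labels are integers 0 to 9 and that targets are float one-hot vectors of length num_classes = 7, and the real class may store labels differently.

```python
NUM_CLASSES = 7  # digits 0-6


def index_digits_0_6(labels: list[int]) -> list[int]:
    """Positions of samples whose label is in 0..6 (mirrors the documented _index())."""
    return [i for i, y in enumerate(labels) if 0 <= y < NUM_CLASSES]


def one_hot(label: int, num_classes: int = NUM_CLASSES) -> list[float]:
    """Float one-hot vector with a 1.0 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} outside 0..{num_classes - 1}")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec
```

With PyTorch, `torch.nn.functional.one_hot(torch.tensor(label), num_classes=7).float()` produces the same vector as a tensor.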

Examples

>>> from pathlib import Path
>>> from torchvision import transforms
>>> from CollaborativeCoding.dataloaders import USPSDataset0_6
>>> transform = transforms.Compose([
...     transforms.Resize((16, 16)),
...     transforms.ToTensor()
... ])
>>> dataset = USPSDataset0_6(
...     data_path=Path("data"),
...     sample_ids=list(range(10)),  # hypothetical subset of sample indices
...     train=True,
...     transform=transform,
... )
>>> data, target = dataset[0]
>>> data.shape
torch.Size([1, 16, 16])
>>> target.shape
torch.Size([7])
filename = 'usps.h5'
num_classes = 7
filepath
transform = None
mode = 'test'
sample_ids
nr_channels = 1
__len__()
__getitem__(id)
class CollaborativeCoding.dataloaders.USPSH5_Digit_7_9_Dataset(data_path, sample_ids, train=False, transform=None, nr_channels=1)

Bases: torch.utils.data.Dataset

This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file. It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.

Parameters

data_path : str or Path

Path to the directory containing the USPS .h5 file. This file should contain the data in the “train” or “test” group.

sample_ids : list of int

A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.

train : bool, optional, default=False

If True, the dataset is loaded in training mode (using the “train” group). If False, the dataset is loaded in test mode (using the “test” group).

transform : callable, optional, default=None

A transformation function to apply to each image. If None, no transformation is applied. Typically used for data augmentation or normalization.

nr_channels : int, optional, default=1

The number of channels in the image. USPS images are typically grayscale, so this should generally be set to 1. This parameter allows for potential future flexibility.

Attributes

images : numpy.ndarray

Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.

labels : numpy.ndarray

Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.

transform : callable, optional

A transformation function to apply to the images. This is passed as an argument during initialization.

label_shift : function

A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).

label_restore : function

A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).

num_classes : int

The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
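The label mapping described above is a fixed offset of 7. The sketch below applies it to parallel image and label lists; it assumes labels are plain integers and uses in-memory lists in place of the HDF5 arrays. label_shift and label_restore are standalone functions named after the documented attributes (the class's own implementation may differ), and filter_digits_7_9 is a hypothetical helper.

```python
DIGITS = (7, 8, 9)


def label_shift(label: int) -> int:
    """Map an original label 7, 8 or 9 to a class index 0, 1 or 2."""
    return label - 7


def label_restore(label: int) -> int:
    """Inverse of label_shift: map 0, 1, 2 back to 7, 8, 9."""
    return label + 7


def filter_digits_7_9(images: list, labels: list[int]) -> tuple[list, list[int]]:
    """Keep samples labelled 7, 8 or 9 and return them with shifted labels."""
    kept = [(img, label_shift(y)) for img, y in zip(images, labels) if y in DIGITS]
    return [img for img, _ in kept], [y for _, y in kept]
```

With NumPy arrays, the same filter is typically written with a boolean mask such as `np.isin(labels, [7, 8, 9])`.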

filename = 'usps.h5'
filepath
transform = None
mode = 'test'
h5_path
sample_ids
nr_channels = 1
num_classes = 3
images
labels
label_shift
label_restore
__len__()

Returns the total number of samples in the dataset.

PyTorch’s samplers and the default options of DataLoader use this method to determine the size of the dataset.

Returns

int

The number of images in the dataset (after filtering for digits 7, 8, and 9).

__getitem__(id)

Returns a sample from the dataset given an index.

This method is required for PyTorch’s Dataset class, as it allows indexing into the dataset to retrieve specific samples.

Parameters

id : int

The index of the sample to retrieve from the dataset.

Returns

tuple

A tuple containing:

- image (PIL Image): The image at the specified index.
- label (int): The label corresponding to the image, shifted to the range [0, 2] for classification.