CollaborativeCoding.dataloaders
Submodules
Classes
- Downloader: Class used to verify availability and potentially download implemented datasets.
- MNISTDataset0_3: A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
- MNISTDataset4_9: MNIST dataset restricted to digits 4 to 9.
- SVHNDataset: An abstract class representing a Dataset.
- USPSDataset0_6: Dataset class for the USPS dataset with labels 0-6.
- USPSH5_Digit_7_9_Dataset: Loads a subset of the USPS dataset, specifically images of digits 7, 8 and 9, from an HDF5 file.
Package Contents
- class CollaborativeCoding.dataloaders.Downloader
Class used to verify availability and potentially download implemented datasets.
Methods
- mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Check whether the MNIST dataset is available and, if not, download it into the MNIST folder in data_dir.
- svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the SVHN dataset and save it as an HDF5 file to data_dir.
- usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
Download the USPS dataset and save it as an HDF5 file to data_dir.
Raises
- NotImplementedError
If the download method is not implemented for the dataset.
Examples
>>> from pathlib import Path
>>> from CollaborativeCoding.dataloaders import Downloader
>>> data_dir = Path("tmp")
>>> data_dir.mkdir(exist_ok=True)
>>> train, test = Downloader().usps(data_dir)
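The Downloader methods follow a check-then-download pattern: use the data if it is already on disk, fetch it otherwise. The following is a minimal sketch of that pattern only, assuming a caller-supplied fetch callable; ensure_available and fake_fetch are hypothetical names, not part of CollaborativeCoding, and the real methods additionally return NumPy arrays per their signatures, which this sketch does not attempt.

```python
import tempfile
from pathlib import Path
from typing import Callable


def ensure_available(target: Path, fetch: Callable[[Path], None]) -> Path:
    """Return target, calling fetch(target) first only if the file is missing."""
    if not target.exists():
        # Create the dataset folder (e.g. data_dir / "MNIST") before fetching.
        target.parent.mkdir(parents=True, exist_ok=True)
        fetch(target)  # hypothetical download step supplied by the caller
    return target


# Usage with a stand-in fetch that writes a placeholder instead of downloading.
calls = []


def fake_fetch(path: Path) -> None:
    calls.append(path)
    path.write_bytes(b"placeholder")


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "MNIST" / "labels.bin"
    ensure_available(target, fake_fetch)
    ensure_available(target, fake_fetch)  # file now exists, so fetch is skipped
    print(len(calls))  # 1
```

Passing the fetch step in as a callable keeps the availability check testable without network access.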
- mnist(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]
Check whether the MNIST dataset is available and, if not, download it into the MNIST folder in data_dir.
- svhn(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]
- usps(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]
Download the USPS dataset and save it as an HDF5 file to data_dir/usps.h5.
- __extract_usps(src: pathlib.Path, dest: pathlib.Path, mode: str)
- static __reporthook(blocknum, blocksize, totalsize)
Use this function to report download progress for the urllib.request.urlretrieve function.
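urllib.request.urlretrieve calls its reporthook with the number of blocks transferred so far, the block size in bytes, and the total file size (which may be -1 when the size is unknown), matching the (blocknum, blocksize, totalsize) signature above. The following is a minimal sketch of what such a hook can compute; download_progress and reporthook are hypothetical names and do not reproduce the private __reporthook implementation.

```python
import sys


def download_progress(blocknum: int, blocksize: int, totalsize: int) -> float:
    """Fraction of the file received so far, clamped to [0, 1].

    urlretrieve may pass totalsize=-1 when the size is unknown; report 0.0 then.
    """
    if totalsize <= 0:
        return 0.0
    # The last block is usually partial, so blocknum * blocksize can overshoot.
    return min(blocknum * blocksize / totalsize, 1.0)


def reporthook(blocknum: int, blocksize: int, totalsize: int) -> None:
    """Overwrite the current terminal line with a percentage."""
    fraction = download_progress(blocknum, blocksize, totalsize)
    sys.stdout.write(f"\rDownloading: {fraction:6.1%}")
    sys.stdout.flush()


# Hypothetical usage (needs network access, so it is not run here):
# urllib.request.urlretrieve(url, filename, reporthook=reporthook)
print(download_progress(5, 100, 1000))  # 0.5
```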
- class CollaborativeCoding.dataloaders.MNISTDataset0_3(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels: int = 1)
Bases: torch.utils.data.Dataset
A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
Args
- data_path : Path
The root directory where the MNIST folder with data is stored.
- sample_ids : list
A list of indices specifying which samples to load.
- train : bool, optional
If True, load training data; otherwise, load test data. Default is False.
- transform : callable, optional
A function/transform to apply to the images. Default is None.
Attributes
- mnist_path : Path
The directory where the MNIST dataset is located within the root directory.
- idx : list
A list of indices specifying which samples to load.
- train : bool
Indicates whether to load training data or test data.
- transform : callable
A function/transform to apply to the images.
- num_classes : int
The number of classes in the dataset (4, for digits 0 to 3).
- images_path : Path
The path to the image file (train or test) based on the train flag.
- labels_path : Path
The path to the label file (train or test) based on the train flag.
- length : int
The number of samples in the dataset.
Methods
- __len__()
Returns the number of samples in the dataset.
- __getitem__(index)
Retrieves the image and label at the specified index.
- mnist_path
- idx
- train = False
- transform = None
- num_classes = 4
- images_path
- labels_path
- length
- __len__()
- __getitem__(index)
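The class keeps separate images_path and labels_path for the chosen split. Assuming these point to MNIST's raw IDX files (an assumption; the format is not stated here), a single sample can be read by seeking past the big-endian header instead of loading the whole file. In the IDX format, image files start with magic number 2051, the image count, rows and columns (16 bytes), and label files start with magic number 2049 and the count (8 bytes), followed by unsigned bytes. read_idx_image and read_idx_label are hypothetical helpers, not part of CollaborativeCoding.

```python
import struct
import tempfile
from pathlib import Path


def read_idx_label(labels_path: Path, index: int) -> int:
    """Read one label from an IDX1 label file (8-byte header, then uint8 labels)."""
    with open(labels_path, "rb") as f:
        magic, count = struct.unpack(">II", f.read(8))
        if magic != 2049:
            raise ValueError("not an IDX label file")
        if not 0 <= index < count:
            raise IndexError(index)
        f.seek(8 + index)
        return f.read(1)[0]


def read_idx_image(images_path: Path, index: int) -> list[list[int]]:
    """Read one image from an IDX3 image file as a list of pixel rows."""
    with open(images_path, "rb") as f:
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError("not an IDX image file")
        if not 0 <= index < count:
            raise IndexError(index)
        f.seek(16 + index * rows * cols)  # skip header and earlier images
        raw = f.read(rows * cols)
    return [list(raw[r * cols:(r + 1) * cols]) for r in range(rows)]


# Usage on a tiny synthetic file: two 2x2 images and two labels.
with tempfile.TemporaryDirectory() as tmp:
    img_file, lbl_file = Path(tmp) / "images.idx3", Path(tmp) / "labels.idx1"
    img_file.write_bytes(struct.pack(">IIII", 2051, 2, 2, 2) + bytes(range(8)))
    lbl_file.write_bytes(struct.pack(">II", 2049, 2) + bytes([3, 1]))
    print(read_idx_image(img_file, 1), read_idx_label(lbl_file, 1))  # [[4, 5], [6, 7]] 1
```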
- class CollaborativeCoding.dataloaders.MNISTDataset4_9(data_path: pathlib.Path, sample_ids: numpy.ndarray, train: bool = False, transform=None, nr_channels: int = 1)
Bases: torch.utils.data.Dataset
MNIST dataset restricted to digits 4 to 9.
Parameters
- data_path : Path
Root directory where the MNIST dataset is stored.
- sample_ids : np.ndarray
Array of indices specifying which samples to load. This determines the samples used by the dataloader.
- train : bool, optional
If True, load the training split; otherwise, load the test split. By default False.
- transform : callable, optional
Transform to apply to the images, by default None.
- nr_channels : int, optional
Number of channels in the images, by default 1.
- data_path
- mnist_path
- samples
- train = False
- transform = None
- num_classes = 6
- images_path
- labels_path
- label_shift
- label_restore
- __len__()
- __getitem__(idx)
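MNIST digits are grayscale, so one plausible way for a dataset to honour nr_channels > 1 is to repeat the single channel. The following is a minimal sketch of that idea on nested lists, assuming channels-first layout; to_channels is a hypothetical helper, and whether MNISTDataset4_9 uses exactly this approach is an assumption.

```python
def to_channels(image: list[list[int]], nr_channels: int = 1) -> list[list[list[int]]]:
    """Return a channels-first copy of a grayscale image, repeated nr_channels times.

    Hypothetical helper: repeating the grey channel lets a 1-channel digit feed a
    model that expects 3-channel (RGB-like) input.
    """
    if nr_channels < 1:
        raise ValueError("nr_channels must be >= 1")
    # Copy each row so the channels do not share list objects.
    return [[list(row) for row in image] for _ in range(nr_channels)]


digit = [[0, 255], [128, 64]]
print(len(to_channels(digit, nr_channels=3)))  # 3
```

With PyTorch tensors the same effect is typically achieved with image.repeat(nr_channels, 1, 1) on a (1, H, W) tensor.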
- class CollaborativeCoding.dataloaders.SVHNDataset(data_path: pathlib.Path, sample_ids: list, train: bool, transform=None, nr_channels=3)
Bases: torch.utils.data.Dataset
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__(), for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples.
Note
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
- data_path
- indexes
- split = 'train'
- nr_channels = 3
- transforms = None
- num_classes
- _create_h5py(path: str)
Downloads the SVHN dataset to the specified directory.
Args
- path : str
The directory where the dataset will be downloaded.
- __len__()
Returns the number of samples in the dataset.
Returns
- int
The number of samples.
- __getitem__(index)
Retrieves the image and label at the specified index.
Args
- index : int
The index of the sample to retrieve.
Returns
- tuple
A tuple containing the image and its corresponding label.
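SVHNDataset inherits the map-style protocol described above: __getitem__ is required, __len__ is expected by samplers and DataLoader, and __getitems__ is an optional batched fetch. The following is a plain-Python sketch of that protocol, with a hypothetical sample_ids indirection modelled on the constructor's sample_ids argument. MapStyleDataset is not part of CollaborativeCoding, and it deliberately does not subclass torch.utils.data.Dataset so that it runs without PyTorch.

```python
from typing import Any, Callable, Optional, Sequence


class MapStyleDataset:
    """Minimal map-style dataset: position -> (sample, label)."""

    def __init__(self, samples: Sequence[Any], labels: Sequence[int],
                 sample_ids: Sequence[int], transform: Optional[Callable] = None):
        self.samples = samples
        self.labels = labels
        self.sample_ids = list(sample_ids)  # which storage rows this dataset exposes
        self.transform = transform

    def __len__(self) -> int:
        return len(self.sample_ids)

    def __getitem__(self, index: int):
        row = self.sample_ids[index]  # map dataset position -> storage row
        image = self.samples[row]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[row]

    def __getitems__(self, indices: list[int]) -> list:
        # Optional batched fetch: one call returns every sample of a batch.
        return [self[i] for i in indices]


ds = MapStyleDataset(["a", "b", "c", "d"], [0, 1, 2, 3], sample_ids=[3, 1], transform=str.upper)
print(len(ds), ds[0], ds.__getitems__([0, 1]))  # 2 ('D', 3) [('D', 3), ('B', 1)]
```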
- class CollaborativeCoding.dataloaders.USPSDataset0_6(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels=1)
Bases: torch.utils.data.Dataset
Dataset class for the USPS dataset with labels 0-6.
Args
- data_path : pathlib.Path
Path to the data directory.
- train : bool, optional
If True, use the training split; otherwise, use the test split.
- transform : callable, optional
A function/transform that takes in a sample and returns a transformed version.
- download : bool, optional
Whether to download the dataset.
Attributes
- filepath : pathlib.Path
Path to the USPS dataset file.
- mode : str
Mode of the dataset, either 'train' or 'test'.
- transform : callable
A function/transform that takes in a sample and returns a transformed version.
- idx : numpy.ndarray
Indices of samples with labels 0-6.
- num_classes : int
Number of classes in the dataset.
Methods
- _index()
Get indices of samples with labels 0-6.
- _load_data(idx)
Load data and target label from the dataset.
- __len__()
Get the number of samples in the dataset.
- __getitem__(idx)
Get a sample from the dataset.
Examples
>>> from torchvision import transforms
>>> from CollaborativeCoding.dataloaders import USPSDataset0_6
>>> transform = transforms.Compose([
...     transforms.Resize((16, 16)),
...     transforms.ToTensor()
... ])
>>> dataset = USPSDataset0_6(
...     data_path="data",
...     transform=transform,
...     download=True,
...     train=True,
... )
>>> len(dataset)
5460
>>> data, target = dataset[0]
>>> data.shape
torch.Size([1, 16, 16])
>>> target
tensor([1., 0., 0., 0., 0., 0., 0.])
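The _index method selects samples with labels 0-6, and the example shows the target returned as a one-hot float vector of length num_classes = 7. The following is a plain-Python sketch of both steps; index_of_labels and one_hot are hypothetical helpers, and whether USPSDataset0_6 builds its target this way is an assumption. In PyTorch, torch.nn.functional.one_hot(label, num_classes).float() gives an equivalent tensor.

```python
def index_of_labels(labels: list[int], keep: range = range(7)) -> list[int]:
    """Return positions of samples whose label is in keep (digits 0-6 by default)."""
    return [i for i, y in enumerate(labels) if y in keep]


def one_hot(label: int, num_classes: int = 7) -> list[float]:
    """Encode label as a float one-hot vector of length num_classes."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} outside 0..{num_classes - 1}")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


labels = [0, 7, 3, 9, 6]
print(index_of_labels(labels))  # [0, 2, 4]
print(one_hot(labels[0]))       # [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```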
- filename = 'usps.h5'
- num_classes = 7
- filepath
- transform = None
- mode = 'test'
- sample_ids
- nr_channels = 1
- __len__()
- __getitem__(id)
- class CollaborativeCoding.dataloaders.USPSH5_Digit_7_9_Dataset(data_path, sample_ids, train=False, transform=None, nr_channels=1)
Bases: torch.utils.data.Dataset
This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9, from an HDF5 file. It allows for applying transformations to the images and provides methods to retrieve images and their corresponding labels.
Parameters
- data_path : str or Path
Path to the directory containing the USPS .h5 file. This file should contain the data in the “train” or “test” group.
- sample_ids : list of int
A list of sample indices to be used from the dataset. This allows for filtering or selecting a subset of the full dataset.
- train : bool, optional, default=False
If True, the dataset is loaded in training mode (using the “train” group). If False, the dataset is loaded in test mode (using the “test” group).
- transform : callable, optional, default=None
A transformation function to apply to each image. If None, no transformation is applied. Typically used for data augmentation or normalization.
- nr_channels : int, optional, default=1
The number of channels in the image. USPS images are typically grayscale, so this should generally be set to 1. This parameter allows for potential future flexibility.
Attributes
- images : numpy.ndarray
Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The images are loaded from the HDF5 file and filtered based on the labels.
- labels : numpy.ndarray
Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are retained, and they are mapped to 0, 1, and 2 for classification tasks.
- transform : callable, optional
A transformation function to apply to the images. This is passed as an argument during initialization.
- label_shift : function
A function to shift the labels for classification purposes. It maps the original labels (7, 8, 9) to (0, 1, 2).
- label_restore : function
A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).
- num_classes : int
The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).
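The attributes above describe two steps: keep only samples labelled 7, 8 or 9, then map those labels to 0, 1, 2 with label_shift and back with label_restore. The following is a minimal sketch of those steps on plain lists; filter_digits is a hypothetical helper, and implementing the shift as subtraction by 7 is an assumption about how the class does it.

```python
def filter_digits(images: list, labels: list[int], digits=(7, 8, 9)):
    """Keep only samples whose label is in digits, preserving order."""
    keep = [i for i, y in enumerate(labels) if y in digits]
    return [images[i] for i in keep], [labels[i] for i in keep]


def label_shift(y: int) -> int:
    return y - 7  # 7, 8, 9 -> 0, 1, 2


def label_restore(y: int) -> int:
    return y + 7  # 0, 1, 2 -> 7, 8, 9


imgs, ys = filter_digits(["a", "b", "c", "d"], [7, 2, 9, 8])
print(ys, [label_shift(y) for y in ys])  # [7, 9, 8] [0, 2, 1]
```

Keeping label_restore as the exact inverse of label_shift lets predictions in [0, 2] be reported as the original digits.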
- filename = 'usps.h5'
- filepath
- transform = None
- mode = 'test'
- h5_path
- sample_ids
- nr_channels = 1
- num_classes = 3
- images
- labels
- label_shift
- label_restore
- __len__()
Returns the total number of samples in the dataset.
PyTorch’s samplers and the default DataLoader options use this method to determine the size of the dataset.
Returns
- int
The number of images in the dataset (after filtering for digits 7, 8, and 9).
- __getitem__(id)
Returns a sample from the dataset given an index.
This method is required for PyTorch’s Dataset class, as it allows indexing into the dataset to retrieve specific samples.
Parameters
- id : int
The index of the sample to retrieve from the dataset.
Returns
- tuple
A tuple (image, label), where image (PIL Image) is the image at the specified index and label (int) is the label corresponding to the image, shifted to be in the range [0, 2] for classification.