CollaborativeCoding.dataloaders
===============================

.. py:module:: CollaborativeCoding.dataloaders


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/CollaborativeCoding/dataloaders/datasources/index
   /autoapi/CollaborativeCoding/dataloaders/download/index
   /autoapi/CollaborativeCoding/dataloaders/mnist_0_3/index
   /autoapi/CollaborativeCoding/dataloaders/mnist_4_9/index
   /autoapi/CollaborativeCoding/dataloaders/svhn/index
   /autoapi/CollaborativeCoding/dataloaders/usps_0_6/index
   /autoapi/CollaborativeCoding/dataloaders/uspsh5_7_9/index


Classes
-------

.. autoapisummary::

   CollaborativeCoding.dataloaders.Downloader
   CollaborativeCoding.dataloaders.MNISTDataset0_3
   CollaborativeCoding.dataloaders.MNISTDataset4_9
   CollaborativeCoding.dataloaders.SVHNDataset
   CollaborativeCoding.dataloaders.USPSDataset0_6
   CollaborativeCoding.dataloaders.USPSH5_Digit_7_9_Dataset


Package Contents
----------------

.. py:class:: Downloader

   Verify the availability of the implemented datasets and download them if necessary.

   Methods
   -------
   mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
       Check the availability of the MNIST dataset. If it is not present, download it
       into the MNIST folder in ``data_dir``.
   svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
       Download the SVHN dataset and save it as an HDF5 file to ``data_dir``.
   usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
       Download the USPS dataset and save it as an HDF5 file to ``data_dir``.

   Raises
   ------
   NotImplementedError
       If the download method is not implemented for the dataset.

   Examples
   --------
   >>> from pathlib import Path
   >>> from CollaborativeCoding.dataloaders import Downloader
   >>> data_dir = Path('tmp')
   >>> data_dir.mkdir(exist_ok=True)
   >>> train, test = Downloader().usps(data_dir)


   .. py:method:: mnist(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]

      Check the availability of the MNIST dataset. If it is not present, download it
      into the MNIST folder in ``data_dir``.


   .. py:method:: svhn(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]


   .. py:method:: usps(data_dir: pathlib.Path) -> tuple[numpy.ndarray, numpy.ndarray]

      Download the USPS dataset and save it as an HDF5 file to ``data_dir/usps.h5``.


   .. py:method:: __extract_usps(src: pathlib.Path, dest: pathlib.Path, mode: str)


   .. py:method:: __reporthook(blocknum, blocksize, totalsize)
      :staticmethod:

      Report download progress. Intended to be passed as the ``reporthook`` argument
      of :func:`urllib.request.urlretrieve`.


   .. py:method:: __check_integrity(filepath, checksum)
      :staticmethod:

      Check the integrity of the USPS dataset file.

      Args
      ----
      filepath : pathlib.Path
          Path to the USPS dataset file.
      checksum : str
          MD5 checksum of the dataset file.

      Returns
      -------
      bool
          True if the checksum of the file matches the expected checksum, False otherwise.


.. py:class:: MNISTDataset0_3(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels: int = 1)

   Bases: :py:obj:`torch.utils.data.Dataset`

   A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.

   Args
   ----
   data_path : Path
       The root directory where the MNIST folder with data is stored.
   sample_ids : list
       A list of indices specifying which samples to load.
   train : bool, optional
       If True, load training data; otherwise, load test data. Default is False.
   transform : callable, optional
       A function/transform to apply to the images. Default is None.
   nr_channels : int, optional
       Number of channels in the images. Default is 1.

   Attributes
   ----------
   mnist_path : Path
       The directory where the MNIST dataset is located within the root directory.
   idx : list
       A list of indices specifying which samples to load.
   train : bool
       Indicates whether to load training data or test data.
   transform : callable
       A function/transform to apply to the images.
   num_classes : int
       The number of classes in the dataset (4, for digits 0 to 3).
   images_path : Path
       The path to the image file (train or test), chosen by the ``train`` flag.
   labels_path : Path
       The path to the label file (train or test), chosen by the ``train`` flag.
   length : int
       The number of samples in the dataset.

   Methods
   -------
   __len__()
       Returns the number of samples in the dataset.
   __getitem__(index)
       Retrieves the image and label at the specified index.


   .. py:attribute:: mnist_path


   .. py:attribute:: idx


   .. py:attribute:: train
      :value: False


   .. py:attribute:: transform
      :value: None


   .. py:attribute:: num_classes
      :value: 4


   .. py:attribute:: images_path


   .. py:attribute:: labels_path


   .. py:attribute:: length


   .. py:method:: __len__()


   .. py:method:: __getitem__(index)


.. py:class:: MNISTDataset4_9(data_path: pathlib.Path, sample_ids: numpy.ndarray, train: bool = False, transform=None, nr_channels: int = 1)

   Bases: :py:obj:`torch.utils.data.Dataset`

   MNIST dataset containing digits 4 to 9.

   Parameters
   ----------
   data_path : Path
       Root directory where the MNIST dataset is stored.
   sample_ids : np.ndarray
       Array of indices specifying which samples to load. This determines the samples
       used by the dataloader.
   train : bool, optional
       If True, load the training data; otherwise, load the test data. By default False.
   transform : callable, optional
       Transform to apply to the images, by default None.
   nr_channels : int, optional
       Number of channels in the images, by default 1.


   .. py:attribute:: data_path


   .. py:attribute:: mnist_path


   .. py:attribute:: samples


   .. py:attribute:: train
      :value: False


   .. py:attribute:: transform
      :value: None


   .. py:attribute:: num_classes
      :value: 6


   .. py:attribute:: images_path


   .. py:attribute:: labels_path


   .. py:attribute:: label_shift


   .. py:attribute:: label_restore


   .. py:method:: __len__()


   .. py:method:: __getitem__(idx)


.. py:class:: SVHNDataset(data_path: pathlib.Path, sample_ids: list, train: bool, transform=None, nr_channels=3)

   Bases: :py:obj:`torch.utils.data.Dataset`

   An abstract class representing a :class:`Dataset`.

   All datasets that represent a map from keys to data samples should subclass
   it. All subclasses should override :meth:`__getitem__`, supporting fetching a
   data sample for a given key.

   Subclasses could also optionally override :meth:`__len__`, which is expected to return
   the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and
   the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also
   optionally implement :meth:`__getitems__` to speed up loading of batched samples. This
   method accepts a list of sample indices for a batch and returns a list of samples.

   .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.


   .. py:attribute:: data_path


   .. py:attribute:: indexes


   .. py:attribute:: split
      :value: 'train'


   .. py:attribute:: nr_channels
      :value: 3


   .. py:attribute:: transforms
      :value: None


   .. py:attribute:: num_classes


   .. py:method:: _create_h5py(path: str)

      Downloads the SVHN dataset to the specified directory.

      Args:
          path (str): The directory where the dataset will be downloaded.


   .. py:method:: __len__()

      Returns the number of samples in the dataset.

      Returns:
          int: The number of samples.


   .. py:method:: __getitem__(index)

      Retrieves the image and label at the specified index.

      Args:
          index (int): The index of the sample to retrieve.

      Returns:
          tuple: A tuple containing the image and its corresponding label.


.. py:class:: USPSDataset0_6(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels=1)

   Bases: :py:obj:`torch.utils.data.Dataset`

   Dataset class for the USPS dataset with labels 0-6.

   Args
   ----
   data_path : pathlib.Path
       Path to the data directory.
   sample_ids : list
       Indices of the samples to load.
   train : bool, optional
       If True, use the training split; otherwise, use the test split.
   transform : callable, optional
       A function/transform that takes in a sample and returns a transformed version.
   download : bool, optional
       Whether to download the dataset.
   nr_channels : int, optional
       Number of channels in the images. Default is 1.

   Attributes
   ----------
   filepath : pathlib.Path
       Path to the USPS dataset file.
   mode : str
       Mode of the dataset, either ``'train'`` or ``'test'``.
   transform : callable
       A function/transform that takes in a sample and returns a transformed version.
   idx : numpy.ndarray
       Indices of samples with labels 0-6.
   num_classes : int
       Number of classes in the dataset (7, for digits 0-6).

   Methods
   -------
   _index()
       Get indices of samples with labels 0-6.
   _load_data(idx)
       Load data and target label from the dataset.
   __len__()
       Get the number of samples in the dataset.
   __getitem__(idx)
       Get a sample from the dataset.

   Examples
   --------
   >>> from torchvision import transforms
   >>> from CollaborativeCoding.dataloaders import USPSDataset0_6
   >>> transform = transforms.Compose([
   ...     transforms.Resize((16, 16)),
   ...     transforms.ToTensor()
   ... ])
   >>> dataset = USPSDataset0_6(
   ...     data_path="data",
   ...     transform=transform,
   ...     download=True,
   ...     train=True,
   ... )
   >>> len(dataset)
   5460
   >>> data, target = dataset[0]
   >>> data.shape
   torch.Size([1, 16, 16])
   >>> target
   tensor([1., 0., 0., 0., 0., 0., 0.])


   .. py:attribute:: filename
      :value: 'usps.h5'


   .. py:attribute:: num_classes
      :value: 7


   .. py:attribute:: filepath


   .. py:attribute:: transform
      :value: None


   .. py:attribute:: mode
      :value: 'test'


   .. py:attribute:: sample_ids


   .. py:attribute:: nr_channels
      :value: 1


   .. py:method:: __len__()


   .. py:method:: __getitem__(id)


.. py:class:: USPSH5_Digit_7_9_Dataset(data_path, sample_ids, train=False, transform=None, nr_channels=1)

   Bases: :py:obj:`torch.utils.data.Dataset`

   Loads a subset of the USPS dataset (images of the digits 7, 8, and 9) from an HDF5
   file. It can apply transformations to the images and provides methods to retrieve
   images and their corresponding labels.

   Parameters
   ----------
   data_path : str or Path
       Path to the directory containing the USPS ``.h5`` file. This file should contain
       the data in the "train" or "test" group.
   sample_ids : list of int
       A list of sample indices to be used from the dataset. This allows for filtering
       or selecting a subset of the full dataset.
   train : bool, optional, default=False
       If ``True``, the dataset is loaded in training mode (using the "train" group).
       If ``False``, the dataset is loaded in test mode (using the "test" group).
   transform : callable, optional, default=None
       A transformation function to apply to each image. If ``None``, no transformation
       is applied. Typically used for data augmentation or normalization.
   nr_channels : int, optional, default=1
       The number of channels in the image. USPS images are typically grayscale, so this
       should generally be set to 1. This parameter allows for potential future flexibility.

   Attributes
   ----------
   images : numpy.ndarray
       Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The
       images are loaded from the HDF5 file and filtered based on the labels.
   labels : numpy.ndarray
       Array of labels corresponding to the images. Only labels of digits 7, 8, and 9
       are retained, and they are mapped to 0, 1, and 2 for classification tasks.
   transform : callable, optional
       A transformation function to apply to the images. This is passed as an argument
       during initialization.
   label_shift : function
       A function to shift the labels for classification purposes. It maps the original
       labels (7, 8, 9) to (0, 1, 2).
   label_restore : function
       A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).
   num_classes : int
       The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).


   .. py:attribute:: filename
      :value: 'usps.h5'


   .. py:attribute:: filepath


   .. py:attribute:: transform
      :value: None


   .. py:attribute:: mode
      :value: 'test'


   .. py:attribute:: h5_path


   .. py:attribute:: sample_ids


   .. py:attribute:: nr_channels
      :value: 1


   .. py:attribute:: num_classes
      :value: 3


   .. py:attribute:: images


   .. py:attribute:: labels


   .. py:attribute:: label_shift


   .. py:attribute:: label_restore


   .. py:method:: __len__()

      Returns the total number of samples in the dataset.

      PyTorch's default samplers and :class:`~torch.utils.data.DataLoader` use this
      method to determine the size of the dataset.

      Returns
      -------
      int
          The number of images in the dataset (after filtering for digits 7, 8, and 9).


   .. py:method:: __getitem__(id)

      Returns a sample from the dataset given an index.

      This method is required for PyTorch's map-style Dataset class, as it allows
      indexing into the dataset to retrieve specific samples.

      Parameters
      ----------
      id : int
          The index of the sample to retrieve from the dataset.

      Returns
      -------
      tuple
          A tuple containing:

          - image (PIL Image): The image at the specified index.
          - label (int): The label corresponding to the image, shifted to be in the
            range [0, 2] for classification.
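
``USPSH5_Digit_7_9_Dataset`` keeps only the digits 7, 8, and 9. ``label_shift`` maps those labels to (0, 1, 2), and ``label_restore`` maps them back. The sketch below shows that label handling in plain Python, using strings as stand-ins for images. It is an illustration under stated assumptions, not the package's implementation. The helpers ``label_shift``, ``label_restore``, and ``filter_digits`` are hypothetical, and the shift is assumed to be a subtraction of 7, which is one mapping consistent with the documented correspondence.

```python
from typing import Sequence

# Hypothetical helpers illustrating the documented label mapping
# (7, 8, 9) -> (0, 1, 2); they are NOT taken from the package source.
KEPT_DIGITS = (7, 8, 9)
OFFSET = KEPT_DIGITS[0]  # assumption: the shift is a plain subtraction of 7


def label_shift(label: int) -> int:
    """Map an original USPS label in {7, 8, 9} to a class index in {0, 1, 2}."""
    if label not in KEPT_DIGITS:
        raise ValueError(f"label {label} is not one of {KEPT_DIGITS}")
    return label - OFFSET


def label_restore(class_index: int) -> int:
    """Inverse of label_shift: map a class index in {0, 1, 2} back to {7, 8, 9}."""
    if not 0 <= class_index < len(KEPT_DIGITS):
        raise ValueError(f"class index {class_index} is out of range")
    return class_index + OFFSET


def filter_digits(images: Sequence, labels: Sequence[int]) -> tuple[list, list[int]]:
    """Keep only samples labelled 7, 8 or 9 and return them with shifted labels."""
    kept = [i for i, y in enumerate(labels) if y in KEPT_DIGITS]
    return [images[i] for i in kept], [label_shift(labels[i]) for i in kept]


# Tiny stand-in data: strings play the role of images.
images, labels = filter_digits(["a", "b", "c", "d"], [7, 3, 9, 8])
print(images, labels)  # ['a', 'c', 'd'] [0, 2, 1]
print([label_restore(y) for y in labels])  # [7, 9, 8]
```

Keeping ``label_restore`` as the exact inverse of ``label_shift`` lets predictions in the range [0, 2] be reported as the original digits.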