CollaborativeCoding.dataloaders.uspsh5_7_9
==========================================

.. py:module:: CollaborativeCoding.dataloaders.uspsh5_7_9


Classes
-------

.. autoapisummary::

   CollaborativeCoding.dataloaders.uspsh5_7_9.USPSH5_Digit_7_9_Dataset


Module Contents
---------------

.. py:class:: USPSH5_Digit_7_9_Dataset(data_path, sample_ids, train=False, transform=None, nr_channels=1)

   Bases: :py:obj:`torch.utils.data.Dataset`

   This class loads a subset of the USPS dataset, specifically images of digits 7, 8, and 9,
   from an HDF5 file. It allows for applying transformations to the images and provides
   methods to retrieve images and their corresponding labels.

   Parameters
   ----------
   data_path : str or Path
       Path to the directory containing the USPS ``.h5`` file. This file should contain
       the data in the "train" or "test" group.
   sample_ids : list of int
       A list of sample indices to be used from the dataset. This allows for filtering
       or selecting a subset of the full dataset.
   train : bool, optional, default=False
       If ``True``, the dataset is loaded in training mode (using the "train" group).
       If ``False``, the dataset is loaded in test mode (using the "test" group).
   transform : callable, optional, default=None
       A transformation function to apply to each image. If ``None``, no transformation
       is applied. Typically used for data augmentation or normalization.
   nr_channels : int, optional, default=1
       The number of channels in the image. USPS images are typically grayscale, so this
       should generally be set to 1. This parameter allows for potential future flexibility.

   Attributes
   ----------
   images : numpy.ndarray
       Array of images corresponding to digits 7, 8, and 9 from the USPS dataset. The
       images are loaded from the HDF5 file and filtered based on the labels.
   labels : numpy.ndarray
       Array of labels corresponding to the images. Only labels of digits 7, 8, and 9 are
       retained, and they are mapped to 0, 1, and 2 for classification tasks.
   transform : callable, optional
       A transformation function to apply to the images. This is passed as an argument
       during initialization.
   label_shift : function
       A function to shift the labels for classification purposes. It maps the original
       labels (7, 8, 9) to (0, 1, 2).
   label_restore : function
       A function to restore the original labels (7, 8, 9) from the shifted labels (0, 1, 2).
   num_classes : int
       The number of unique labels in the dataset, which is 3 (for digits 7, 8, and 9).


   .. py:attribute:: filename
      :value: 'usps.h5'


   .. py:attribute:: filepath


   .. py:attribute:: transform
      :value: None


   .. py:attribute:: mode
      :value: 'test'


   .. py:attribute:: h5_path


   .. py:attribute:: sample_ids


   .. py:attribute:: nr_channels
      :value: 1


   .. py:attribute:: num_classes
      :value: 3


   .. py:attribute:: images


   .. py:attribute:: labels


   .. py:attribute:: label_shift


   .. py:attribute:: label_restore


   .. py:method:: __len__()

      Returns the total number of samples in the dataset.

      This method is required by PyTorch's Dataset class, as it allows PyTorch to
      determine the size of the dataset.

      Returns
      -------
      int
          The number of images in the dataset (after filtering for digits 7, 8, and 9).


   .. py:method:: __getitem__(id)

      Returns a sample from the dataset given an index.

      This method is required by PyTorch's Dataset class, as it allows indexing into
      the dataset to retrieve specific samples.

      Parameters
      ----------
      id : int
          The index of the sample to retrieve from the dataset.

      Returns
      -------
      tuple
          A tuple containing:

          - image (PIL Image): The image at the specified index.
          - label (int): The label corresponding to the image, shifted to be in the
            range [0, 2] for classification.
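
The label handling described above can be sketched without the HDF5 file or PyTorch. The snippet below is a minimal illustration, not the class's actual implementation: the arithmetic used for ``label_shift`` and ``label_restore``, the ``DIGITS`` constant, and the ``raw_labels`` sample data are all assumptions chosen to match the documented mapping of (7, 8, 9) to (0, 1, 2). Selection by ``sample_ids`` is not modelled here.

```python
# Hypothetical stand-ins for the documented label_shift / label_restore
# attributes; the real class may implement them differently.
DIGITS = (7, 8, 9)


def label_shift(label):
    """Map an original digit label (7, 8, 9) to a class index (0, 1, 2)."""
    return label - DIGITS[0]


def label_restore(shifted):
    """Map a class index (0, 1, 2) back to the original digit (7, 8, 9)."""
    return shifted + DIGITS[0]


# Made-up raw labels standing in for the labels stored in the HDF5 file.
raw_labels = [7, 3, 9, 8, 0, 7]

# Keep only samples whose label is one of the three digits of interest.
kept_labels = [label for label in raw_labels if label in DIGITS]

# Shift the retained labels into the range [0, 2] for classification.
shifted_labels = [label_shift(label) for label in kept_labels]

# The dataset length would correspond to the number of retained samples.
num_samples = len(shifted_labels)
num_classes = len(set(DIGITS))
```

Applying ``label_restore`` to each shifted label recovers the original digits, which is useful when reporting predictions in terms of the actual digit classes.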