CollaborativeCoding.dataloaders.usps_0_6
Dataset class for USPS dataset with labels 0-6.
This module contains the Dataset class for the USPS dataset with labels 0-6.
Classes
- USPSDataset0_6: Dataset class for USPS dataset with labels 0-6.
Module Contents
- class CollaborativeCoding.dataloaders.usps_0_6.USPSDataset0_6(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels=1)
Bases: torch.utils.data.Dataset
Dataset class for USPS dataset with labels 0-6.
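Args
- data_path : pathlib.Path
Path to the data directory.
- sample_ids : list
List of sample indices that make up the dataset.
- train : bool, optional
If True, use the training split; otherwise use the test split. Defaults to False.
- transform : callable, optional
A function/transform that takes in a sample and returns a transformed version.
- nr_channels : int, optional
Number of image channels. Defaults to 1.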
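The `sample_ids` argument presumably selects which samples of the file the dataset exposes. As a minimal sketch, assuming the labels are available as an integer NumPy array (the `labels` array and the `select_ids_0_6` helper are hypothetical, not part of this module), the indices of all samples with labels 0-6 could be computed like this:

```python
import numpy as np


def select_ids_0_6(labels) -> list:
    """Return the positions of all samples whose label is in 0-6.

    Hypothetical helper: it only illustrates how a ``sample_ids`` list
    could be built; it is not part of CollaborativeCoding.
    """
    labels = np.asarray(labels)
    # Boolean mask that is True for labels 0..6.
    mask = (labels >= 0) & (labels <= 6)
    # Indices where the mask is True, as plain Python ints.
    return np.flatnonzero(mask).tolist()


# Toy labels standing in for the USPS targets.
labels = np.array([0, 7, 3, 9, 6, 1, 8])
print(select_ids_0_6(labels))  # [0, 2, 4, 5]
```

A subset of these indices (for example, one side of a train/validation split) could then be passed as `sample_ids`.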
Attributes
- filepath : pathlib.Path
Path to the USPS dataset file.
- mode : str
Mode of the dataset, either 'train' or 'test'.
- transform : callable
A function/transform that takes in a sample and returns a transformed version.
- idx : numpy.ndarray
Indices of samples with labels 0-6.
- num_classes : int
Number of classes in the dataset.
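The attributes suggest that `filepath` combines `data_path` with the class attribute `filename` (`'usps.h5'`) and that `mode` follows from `train`, defaulting to `'test'`. A minimal sketch of that derivation, assuming exactly this behavior (`resolve_paths` is a hypothetical helper, not part of the module):

```python
from pathlib import Path

FILENAME = "usps.h5"  # same value as the class attribute ``filename``


def resolve_paths(data_path, train: bool = False):
    """Hypothetical sketch: derive ``filepath`` and ``mode`` from the arguments."""
    filepath = Path(data_path) / FILENAME
    # train=False (the default) corresponds to mode 'test'.
    mode = "train" if train else "test"
    return filepath, mode


filepath, mode = resolve_paths("data")
print(filepath.name, mode)  # usps.h5 test
```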
Methods
- _index()
Get indices of samples with labels 0-6.
- _load_data(idx)
Load data and target label from the dataset.
- __len__()
Get the number of samples in the dataset.
- __getitem__(idx)
Get a sample from the dataset.
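The methods above follow PyTorch's map-style dataset protocol (`__len__` plus `__getitem__`). The sketch below mirrors that interface under stated assumptions: the images and labels are already loaded into NumPy arrays (the real class reads them from `usps.h5`), and the class name `InMemoryUSPS0_6` is hypothetical. It returns one-hot targets of length 7, matching the `target` shown in the example below.

```python
import numpy as np


class InMemoryUSPS0_6:
    """Hypothetical map-style dataset mirroring USPSDataset0_6's interface.

    Assumption: data is held in memory as NumPy arrays instead of being
    read from the HDF5 file.
    """

    num_classes = 7

    def __init__(self, images, labels, sample_ids, transform=None):
        self.images = np.asarray(images, dtype=np.float32)
        self.labels = np.asarray(labels)
        self.sample_ids = list(sample_ids)
        self.transform = transform

    def _load_data(self, idx):
        # Map the dataset position to a row in the underlying arrays.
        row = self.sample_ids[idx]
        return self.images[row], int(self.labels[row])

    def __len__(self):
        # The dataset size is the number of selected samples.
        return len(self.sample_ids)

    def __getitem__(self, idx):
        data, label = self._load_data(idx)
        if self.transform is not None:
            data = self.transform(data)
        # One-hot target, e.g. label 0 -> [1, 0, 0, 0, 0, 0, 0].
        target = np.eye(self.num_classes, dtype=np.float32)[label]
        return data, target


images = np.zeros((5, 256))
labels = np.array([0, 8, 2, 6, 9])
ds = InMemoryUSPS0_6(images, labels, sample_ids=[0, 2, 3])
print(len(ds))   # 3
data, target = ds[0]
print(target)    # [1. 0. 0. 0. 0. 0. 0.]
```

Indexing goes through `sample_ids`, so position 1 in the dataset refers to row 2 of the underlying arrays in this example.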
Examples
>>> from pathlib import Path
>>> from torchvision import transforms
>>> from CollaborativeCoding.dataloaders.usps_0_6 import USPSDataset0_6
>>> transform = transforms.Compose([
...     transforms.Resize((16, 16)),
...     transforms.ToTensor()
... ])
>>> dataset = USPSDataset0_6(
...     data_path=Path("data"),
...     sample_ids=sample_ids,  # list of sample indices to include
...     transform=transform,
...     train=True,
... )
>>> len(dataset)
5460
>>> data, target = dataset[0]
>>> data.shape
torch.Size([1, 16, 16])
>>> target
tensor([1., 0., 0., 0., 0., 0., 0.])
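The target in the example is one-hot encoded over the seven classes. Assuming position i corresponds to label i, the leading 1 means the sample is a 0. A minimal NumPy sketch for recovering the integer label (`decode_target` is a hypothetical helper):

```python
import numpy as np


def decode_target(target) -> int:
    """Return the class index encoded by a one-hot target vector."""
    # argmax picks the position of the single 1.0 entry.
    return int(np.argmax(np.asarray(target)))


target = np.array([1., 0., 0., 0., 0., 0., 0.])
print(decode_target(target))  # 0
```

For a CPU torch tensor, `target.argmax().item()` gives the same result without converting to NumPy.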
- filename = 'usps.h5'
- num_classes = 7
- filepath
- transform = None
- mode = 'test'
- sample_ids
- nr_channels = 1
- __len__()
- __getitem__(id)
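USPS images are 16×16 grayscale, i.e. 256 pixel values per sample. How this class uses `nr_channels` is not documented here; assuming it sets how many channels each image gets (for example 3 for models that expect RGB input) and that a sample is stored as a flat vector of 256 values, a minimal sketch of reshaping and repeating the gray channel could look like this (`to_channels` is a hypothetical helper):

```python
import numpy as np


def to_channels(flat, nr_channels: int = 1) -> np.ndarray:
    """Reshape a flat 256-pixel sample to (nr_channels, 16, 16).

    Hypothetical helper: the single gray channel is copied nr_channels times.
    """
    image = np.asarray(flat, dtype=np.float32).reshape(1, 16, 16)
    # np.repeat along axis 0 duplicates the gray channel.
    return np.repeat(image, nr_channels, axis=0)


sample = np.linspace(0.0, 1.0, 256)
print(to_channels(sample).shape)     # (1, 16, 16)
print(to_channels(sample, 3).shape)  # (3, 16, 16)
```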