CollaborativeCoding.dataloaders.usps_0_6

Dataset class for the USPS dataset restricted to labels 0-6.

This module contains the Dataset class for the USPS dataset restricted to labels 0-6.

Classes

USPSDataset0_6

Dataset class for the USPS dataset restricted to labels 0-6.

Module Contents

class CollaborativeCoding.dataloaders.usps_0_6.USPSDataset0_6(data_path: pathlib.Path, sample_ids: list, train: bool = False, transform=None, nr_channels=1)

Bases: torch.utils.data.Dataset

Dataset class for the USPS dataset restricted to labels 0-6.

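Args

data_path : pathlib.Path

Path to the data directory.

sample_ids : list

Indices of the samples to include in the dataset.

train : bool, optional

If True, use the training split; otherwise use the test split. Defaults to False.

transform : callable, optional

A function/transform that takes in a sample and returns a transformed version. Defaults to None.

nr_channels : int, optional

Number of image channels. Defaults to 1.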

Attributes

filepath : pathlib.Path

Path to the USPS dataset file.

mode : str

Mode of the dataset, either "train" or "test".

transform : callable

A function/transform that takes in a sample and returns a transformed version.
idx : numpy.ndarray

Indices of samples with labels 0-6.

num_classes : int

Number of classes in the dataset (7, one per digit 0-6).
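The `_index()` method is documented as collecting the indices of samples whose label lies in 0-6. A minimal sketch of that selection, assuming the labels are available as a plain sequence of integers (0-9 for USPS); the helper name `index_labels_0_6` is hypothetical and not part of the package:

```python
def index_labels_0_6(labels, max_label=6):
    """Return the positions of all samples whose label is in 0..max_label.

    Hypothetical helper illustrating what ``_index()`` is documented to do;
    assumes ``labels`` is a sequence of integer class labels.
    """
    return [i for i, y in enumerate(labels) if 0 <= y <= max_label]


# Usage: samples labelled 7, 8 and 9 are dropped.
labels = [0, 7, 3, 9, 6, 1, 8]
print(index_labels_0_6(labels))  # [0, 2, 4, 5]
```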

Methods

_index()

Get indices of samples with labels 0-6.

_load_data(idx)

Load data and target label from the dataset.

__len__()

Get the number of samples in the dataset.

__getitem__(idx)

Get a sample from the dataset.
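The methods listed above follow the map-style dataset protocol: `__len__` reports the number of samples and `__getitem__` returns one `(data, target)` pair. The sketch below illustrates that structure on in-memory lists instead of the HDF5 file, and uses plain Python instead of subclassing `torch.utils.data.Dataset`, so it runs without PyTorch. The class name `InMemoryUSPS0_6`, its constructor, and the assumption that `sample_ids` are positions into the full image and label lists (already restricted to labels 0-6) are all hypothetical; the one-hot target mirrors `num_classes = 7` and the example output.

```python
class InMemoryUSPS0_6:
    """Hypothetical, simplified stand-in for USPSDataset0_6.

    images: sequence of images (any type accepted by ``transform``)
    labels: sequence of integer labels aligned with ``images``
    sample_ids: positions in ``images``/``labels`` to expose (assumed 0-6 only)
    """

    num_classes = 7

    def __init__(self, images, labels, sample_ids, transform=None):
        self.images = images
        self.labels = labels
        self.sample_ids = list(sample_ids)
        self.transform = transform

    def _load_data(self, idx):
        # Map the dataset position to the underlying sample.
        sample_id = self.sample_ids[idx]
        return self.images[sample_id], self.labels[sample_id]

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        data, label = self._load_data(idx)
        if self.transform is not None:
            data = self.transform(data)
        # One-hot target with num_classes entries, e.g. label 0 -> [1.0, 0.0, ...].
        target = [0.0] * self.num_classes
        target[label] = 1.0
        return data, target


# Usage: expose samples 0 and 2 only, doubling every pixel value.
images = [[1, 2], [3, 4], [5, 6]]
labels = [0, 9, 6]
ds = InMemoryUSPS0_6(images, labels, sample_ids=[0, 2],
                     transform=lambda x: [v * 2 for v in x])
print(len(ds))   # 2
data, target = ds[1]
print(data)      # [10, 12]
print(target)    # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
```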

Examples

>>> from pathlib import Path
>>> from torchvision import transforms
>>> from CollaborativeCoding.dataloaders.usps_0_6 import USPSDataset0_6
>>> transform = transforms.Compose([
...     transforms.Resize((16, 16)),
...     transforms.ToTensor()
... ])
>>> sample_ids = list(range(5460))  # assumed: indices of the samples to include
>>> dataset = USPSDataset0_6(
...     data_path=Path("data"),
...     sample_ids=sample_ids,
...     train=True,
...     transform=transform,
... )
>>> len(dataset)
5460
>>> data, target = dataset[0]
>>> data.shape
torch.Size([1, 16, 16])
>>> target
tensor([1., 0., 0., 0., 0., 0., 0.])
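Because the target is returned as a one-hot vector of length `num_classes`, the integer label is the position of its largest entry; with PyTorch this is `target.argmax().item()`. The dependency-free sketch below does the same for any sequence of numbers and also checks that the vector has the expected length. The helper name `decode_one_hot` is hypothetical:

```python
def decode_one_hot(target, num_classes=7):
    """Return the integer label encoded by a one-hot target.

    Hypothetical helper; assumes ``target`` has a single maximal entry.
    """
    if len(target) != num_classes:
        raise ValueError(f"expected {num_classes} entries, got {len(target)}")
    # Index of the largest entry; for a one-hot vector this is the "hot" position.
    return max(range(num_classes), key=lambda i: target[i])


# A target like the one in the example encodes label 0.
print(decode_one_hot([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))  # 0
print(decode_one_hot([0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]))  # 5
```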
filename = 'usps.h5'
num_classes = 7
filepath
transform = None
mode = 'test'
sample_ids
nr_channels = 1
__len__()
__getitem__(id)
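The constructor takes an explicit `sample_ids` list, which lets the caller decide which samples each dataset instance exposes, for example to split the training data into disjoint training and validation subsets. A minimal sketch of producing such lists, assuming the ids are integer positions; the helper `split_sample_ids` and the 10 % validation fraction are hypothetical choices, not part of the package:

```python
import random


def split_sample_ids(sample_ids, val_fraction=0.1, seed=0):
    """Shuffle ``sample_ids`` reproducibly and split them into (train, val).

    Hypothetical helper: the two lists are disjoint and together contain
    every input id exactly once.
    """
    ids = list(sample_ids)
    random.Random(seed).shuffle(ids)  # local RNG, so the global state is untouched
    n_val = int(len(ids) * val_fraction)
    return ids[n_val:], ids[:n_val]


# e.g. 5460 ids, the dataset size reported in the example above
train_ids, val_ids = split_sample_ids(range(5460))
print(len(train_ids), len(val_ids))  # 4914 546
```

Each list could then be passed as `sample_ids` to its own `USPSDataset0_6(..., train=True)` instance; fixing `seed` keeps the split identical across runs.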