CollaborativeCoding.load_data

Functions

filter_labels(→ list)

load_data(→ tuple)

Load the dataset based on the dataset name.

Module Contents

CollaborativeCoding.load_data.filter_labels(samples: list, wanted_labels: list) → list
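
This page gives no description of filter_labels. Judging only from its signature, one plausible reading is that it keeps the samples whose label appears in wanted_labels. The sketch below assumes each sample is an (item, label) pair. Both that layout and the name filter_labels_sketch are hypothetical, and this is not the package's implementation.

```python
from typing import Any, Hashable, List, Sequence, Tuple

# Hypothetical sample layout: (item, label). The real package may store samples differently.
Sample = Tuple[Any, Hashable]


def filter_labels_sketch(samples: Sequence[Sample], wanted_labels: Sequence[Hashable]) -> List[Sample]:
    """Keep only samples whose label is in wanted_labels (assumed behaviour)."""
    wanted = set(wanted_labels)  # set gives constant-time membership tests
    return [sample for sample in samples if sample[1] in wanted]


# Usage: keep digits 0-6, as a dataset name such as "usps_0-6" suggests (assumption).
samples = [("img_a", 3), ("img_b", 8), ("img_c", 0), ("img_d", 9)]
print(filter_labels_sketch(samples, list(range(7))))
# → [('img_a', 3), ('img_c', 0)]
```
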
CollaborativeCoding.load_data.load_data(dataset: str, *args, **kwargs) → tuple

Load the dataset based on the dataset name.

Args

dataset : str

Name of the dataset to load.

*args : list

Additional positional arguments passed to the dataset class.

**kwargs : dict

Additional keyword arguments passed to the dataset class.

Returns

tuple

Tuple of (train, validation, test) datasets, in that order.
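
The docstring does not say how the validation split is produced. The figures in the example further down are consistent with holding out 10% of a 5460-sample training portion (4914 + 546 = 5460 and 546 / 5460 = 0.1), but that is an inference, not documented behaviour. A minimal sketch of such a split, using a hypothetical split_train_val helper with a hypothetical val_fraction parameter:

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val(
    samples: Sequence[T], val_fraction: float = 0.1, seed: int = 0
) -> Tuple[List[T], List[T]]:
    """Shuffle samples and hold out val_fraction of them for validation (hypothetical helper)."""
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = round(len(samples) * val_fraction)
    val = [samples[i] for i in indices[:n_val]]
    train = [samples[i] for i in indices[n_val:]]
    return train, val


train, val = split_train_val(list(range(5460)))
print(len(train), len(val))  # → 4914 546
```
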

Raises

NotImplementedError

If no dataset with the given name is implemented.
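
A common way to implement "load by name, raise NotImplementedError otherwise" is a registry that maps dataset names to classes and forwards *args and **kwargs to the selected class. The sketch below assumes that pattern and shows only the dispatch step (it returns one dataset, not the documented tuple). ToyDataset, DATASETS, and load_data_sketch are hypothetical stand-ins, not the package's own names.

```python
from typing import Any, Dict, Type


class ToyDataset:
    """Hypothetical dataset class that just records its constructor arguments."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args = args
        self.kwargs = kwargs


# Hypothetical registry from dataset name to dataset class.
DATASETS: Dict[str, Type[ToyDataset]] = {"toy": ToyDataset}


def load_data_sketch(dataset: str, *args: Any, **kwargs: Any) -> ToyDataset:
    """Look up the class registered under `dataset` and forward all extra arguments."""
    try:
        cls = DATASETS[dataset]
    except KeyError:
        # Mirrors the documented behaviour: unknown names raise NotImplementedError.
        raise NotImplementedError(f"Dataset {dataset!r} is not implemented") from None
    return cls(*args, **kwargs)


ds = load_data_sketch("toy", "data", download=True)
print(ds.args, ds.kwargs)  # → ('data',) {'download': True}
```
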

Examples

>>> from CollaborativeCoding.load_data import load_data
>>> train, val, test = load_data("usps_0-6", data_path="data", train=True, download=True)
>>> len(train), len(val), len(test)
(4914, 546, 1782)