CollaborativeCoding

Submodules

Classes

MetricWrapper

A wrapper class for evaluating multiple metrics on the same dataset.

Functions

get_args()

createfolders(→ None)

Creates folders for storing data, results, and model weights.

load_data(→ tuple)

Load the dataset based on the dataset name.

load_model(→ torch.nn.Module)

Load the model based on the model name.

Package Contents

CollaborativeCoding.get_args()
CollaborativeCoding.createfolders(*dirs: pathlib.Path) → None

Creates folders for storing data, results, and model weights.

Parameters

*dirs : pathlib.Path

Paths to the directories to be created.
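
Examples

A minimal usage sketch (the directory names below are illustrative, not required by the function):

>>> from pathlib import Path
>>> from CollaborativeCoding import createfolders
>>> createfolders(Path("data"), Path("results"), Path("models"))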

CollaborativeCoding.load_data(dataset: str, *args, **kwargs) → tuple

Load the dataset based on the dataset name.

Args

dataset : str

Name of the dataset to load.

*args : list

Additional arguments for the dataset class.

**kwargs : dict

Additional keyword arguments for the dataset class.

Returns

tuple

Tuple of train, validation and test datasets.

Raises

NotImplementedError

If the dataset is not implemented.

Examples

>>> from CollaborativeCoding import load_data
>>> train, val, test = load_data("usps_0-6", data_path="data", train=True, download=True)
>>> len(train), len(val), len(test)
(4914, 546, 1782)
class CollaborativeCoding.MetricWrapper(*metrics, num_classes, macro_averaging=False, **kwargs)

Bases: torch.nn.Module

A wrapper class for evaluating multiple metrics on the same dataset. This class allows you to compute several metrics simultaneously on given true and predicted labels. It supports a variety of common metrics and provides methods to accumulate results and reset the state.

Args

num_classes : int

The number of classes in the classification task.

metrics : list[str]

A list of metric names to be evaluated.

macro_averaging : bool

Whether to compute macro-averaged metrics for multi-class classification.

Attributes

metrics : dict

A dictionary mapping metric names to their corresponding functions.

num_classes : int

The number of classes for the classification task.

Methods

__call__(y_true, y_pred)

Passes the true labels and predicted logits to the metric functions.

getmetrics(str_prefix: str = None)

Retrieves the dictionary of computed metrics; optionally, all keys can be prefixed with a string (see the sketch below this class entry).

resetmetric()

Resets the state of all metric computations.

Examples

>>> from CollaborativeCoding import MetricWrapper
>>> metrics = MetricWrapper("entropy", "f1", "precision", num_classes=2)
>>> y_true = [0, 1, 0, 1]
>>> y_pred = [[0.8, -1.9],
...           [0.1,   9.0],
...           [-1.9, -0.1],
...           [1.9,   1.8]]
>>> metrics(y_true, y_pred)
>>> metrics.getmetrics()
{'entropy': 0.3292665, 'f1': 0.5, 'precision': 0.5}
>>> metrics.resetmetric()
>>> metrics.getmetrics()
{'entropy': [], 'f1': [], 'precision': []}
metrics
params
_get_metric(key)

Retrieves the metric function based on the provided key.

Args

key : str

The name of the metric.

Returns

metric : callable

The function that computes the metric.

__call__(y_true, y_pred)
getmetrics(str_prefix: str = None)
resetmetric()
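
A brief sketch of the optional macro_averaging and str_prefix arguments (the displayed metric values are placeholders; actual results depend on the accumulated predictions):

>>> from CollaborativeCoding import MetricWrapper
>>> wrapped = MetricWrapper("f1", "precision", num_classes=2, macro_averaging=True)
>>> wrapped([0, 1, 0, 1], [[0.8, -1.9], [0.1, 9.0], [-1.9, -0.1], [1.9, 1.8]])
>>> wrapped.getmetrics(str_prefix="val_")
{'val_f1': ..., 'val_precision': ...}
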
CollaborativeCoding.load_model(modelname: str, *args, **kwargs) → torch.nn.Module

Load the model based on the model name.

Args

modelname : str

Name of the model to load.

*args : list

Additional arguments for the model class.

**kwargs : dict

Additional keyword arguments for the model class.

Returns

model : torch.nn.Module

Model object.

Raises

NotImplementedError

If the model is not implemented.

Examples

>>> from CollaborativeCoding import load_model
>>> model = load_model("magnusmodel", num_classes=10)
>>> model
MagnusModel(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=10, bias=True)
)
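
A forward-pass sketch (assumptions: the model accepts a flattened batch whose feature dimension matches fc1's in_features of 784, and num_classes=10 fixes the output width):

>>> import torch
>>> x = torch.randn(1, 784)  # dummy batch matching fc1's in_features (assumption)
>>> model(x).shape
torch.Size([1, 10])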