CollaborativeCoding.load_metric

Attributes

metrics

Classes

MetricWrapper

A wrapper class for evaluating multiple metrics on the same dataset.

Module Contents

class CollaborativeCoding.load_metric.MetricWrapper(*metrics, num_classes, macro_averaging=False, **kwargs)

Bases: torch.nn.Module

A wrapper class for evaluating multiple metrics on the same dataset. This class computes several metrics at once from the given true labels and predicted logits. It supports the metrics named in CollaborativeCoding.load_metric.metrics (entropy, f1, recall, precision, and accuracy) and provides methods to accumulate results and reset the state.

Args

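num_classes : int

The number of classes in the classification task. Because it follows *metrics in the signature, it must be passed as a keyword argument.

metrics : list[str]

The names of the metrics to evaluate, passed as separate positional arguments (*metrics in the signature). Each name should be one of the entries in CollaborativeCoding.load_metric.metrics.

macro_averaging : bool, default False

Whether to compute macro-averaged metrics for multi-class classification.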
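The effect of macro averaging can be illustrated with precision. The following standard-library sketch is not taken from CollaborativeCoding; it assumes that macro_averaging=True means the unweighted mean of per-class scores. This page does not say what the non-macro path computes, so micro averaging (counts pooled over all classes) is shown only for contrast. The functions per_class_counts, macro_precision and micro_precision are hypothetical.

```python
from collections import Counter


def per_class_counts(y_true, y_pred, num_classes):
    """Return (true positives, false positives) for each class index."""
    tp = Counter()
    fp = Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[p] += 1  # predicted class p, and it was correct
        else:
            fp[p] += 1  # predicted class p, but the true class differs
    return [(tp[c], fp[c]) for c in range(num_classes)]


def macro_precision(y_true, y_pred, num_classes):
    """Unweighted mean of per-class precision (0.0 for never-predicted classes)."""
    scores = [
        tp / (tp + fp) if tp + fp else 0.0
        for tp, fp in per_class_counts(y_true, y_pred, num_classes)
    ]
    return sum(scores) / num_classes


def micro_precision(y_true, y_pred, num_classes):
    """Precision from true/false positive counts pooled over all classes."""
    counts = per_class_counts(y_true, y_pred, num_classes)
    tp = sum(c[0] for c in counts)
    fp = sum(c[1] for c in counts)
    return tp / (tp + fp) if tp + fp else 0.0


# Imbalanced labels: class 1 is rare.
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1]
print(macro_precision(y_true, y_pred, 2))  # 0.9
print(micro_precision(y_true, y_pred, 2))  # 0.8333333333333334
```

Here class 0 has precision 4/5 = 0.8 and class 1 has precision 1/1 = 1.0, so the macro average is 0.9, while pooling gives 5/6. Macro averaging gives the rare class the same weight as the common one, which is usually why it is chosen for imbalanced multi-class data.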

Attributes

metrics : dict

A dictionary mapping metric names to their corresponding functions.

num_classes : int

The number of classes for the classification task.

Methods

__call__(y_true, y_pred)

Passes the true labels and the predicted logits to each metric function.

getmetrics(str_prefix: str = None)

Retrieves the dictionary of computed metrics. If str_prefix is given, it is prepended to every key.

resetmetric()

Resets the state of all metric computations.
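The methods above describe an accumulate, summarize, and reset workflow. The following standard-library sketch shows one way such a wrapper could behave; it is not the MetricWrapper implementation. It assumes that __call__ turns each row of logits into a label with an argmax, that each call appends one value per metric, that getmetrics returns the mean of the accumulated values (or an empty list, mirroring the Examples output after a reset), and that str_prefix is prepended to each key unchanged. SketchMetricWrapper, argmax, accuracy and the "train/" prefix are hypothetical.

```python
def argmax(row):
    """Return the index of the largest value in one row of logits."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(y_true, labels):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, labels)) / len(y_true)


class SketchMetricWrapper:
    """Hypothetical accumulate/get/reset wrapper; not the real MetricWrapper."""

    def __init__(self, **metric_fns):
        # Maps a metric name to a function (y_true, labels) -> float.
        self.metric_fns = metric_fns
        self.resetmetric()

    def __call__(self, y_true, y_pred):
        # Assumption: each row of y_pred holds one sample's class logits.
        labels = [argmax(row) for row in y_pred]
        for name, fn in self.metric_fns.items():
            self.values[name].append(fn(y_true, labels))

    def getmetrics(self, str_prefix=None):
        # Assumption: the prefix is prepended to each key unchanged.
        prefix = str_prefix or ""
        return {
            # Mean over all calls so far; an empty list if nothing was recorded.
            prefix + name: (sum(vals) / len(vals) if vals else [])
            for name, vals in self.values.items()
        }

    def resetmetric(self):
        # Drop every accumulated value, keeping the configured metric names.
        self.values = {name: [] for name in self.metric_fns}


# Usage with the labels and logits from the Examples section.
wrapper = SketchMetricWrapper(accuracy=accuracy)
wrapper([0, 1, 0, 1], [[0.8, -1.9], [0.1, 9.0], [-1.9, -0.1], [1.9, 1.8]])
print(wrapper.getmetrics("train/"))  # {'train/accuracy': 0.5}
wrapper.resetmetric()
print(wrapper.getmetrics())  # {'accuracy': []}
```

Averaging per-call values weights every batch equally regardless of its size; accumulating raw counts instead would weight by the number of samples. This page does not say which approach MetricWrapper takes.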

Examples

>>> from CollaborativeCoding.load_metric import MetricWrapper
>>> metrics = MetricWrapper("entropy", "f1", "precision", num_classes=2)
>>> y_true = [0, 1, 0, 1]
>>> y_pred = [[0.8, -1.9],
...           [0.1,  9.0],
...           [-1.9, -0.1],
...           [1.9,  1.8]]
>>> metrics(y_true, y_pred)
>>> metrics.getmetrics()
{'entropy': 0.3292665, 'f1': 0.5, 'precision': 0.5}
>>> metrics.resetmetric()
>>> metrics.getmetrics()
{'entropy': [], 'f1': [], 'precision': []}

metrics

params

_get_metric(key)

Retrieves the metric function based on the provided key.

Args

key (str): The name of the metric.

key (str): The name of the metric.

Returns

metric (callable): The function that computes the metric.

__call__(y_true, y_pred)

getmetrics(str_prefix: str = None)

resetmetric()

CollaborativeCoding.load_metric.metrics = ['entropy', 'f1', 'recall', 'precision', 'accuracy']