CollaborativeCoding.metrics.precision

Classes

Precision

Metric module for precision. It can calculate both micro- and macro-averaged precision.

Module Contents

class CollaborativeCoding.metrics.precision.Precision(num_classes: int, macro_averaging: bool = False)

Bases: torch.nn.Module

Metric module for precision. It can calculate both micro- and macro-averaged precision.

Parameters

num_classes : int

Number of classes in the dataset.

macro_averaging : bool

Performs macro-averaging if True, otherwise micro-averaging.
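For reference, the standard definitions of the two averaging modes are given below, with C = num_classes and TP_c, FP_c the true and false positives of class c. This assumes the module follows these textbook definitions; the handling of a class with TP_c + FP_c = 0 (a 0/0 term in the macro average) is not specified here.

```latex
\text{Precision}_{\text{micro}} = \frac{\sum_{c=1}^{C} TP_c}{\sum_{c=1}^{C} \left(TP_c + FP_c\right)},
\qquad
\text{Precision}_{\text{macro}} = \frac{1}{C} \sum_{c=1}^{C} \frac{TP_c}{TP_c + FP_c}
```

Micro-averaging pools the counts before dividing, so frequent classes dominate; macro-averaging gives every class the same weight.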

num_classes
macro_averaging = False
y_true = []
y_pred = []
forward(y_true: torch.tensor, logits: torch.tensor) → torch.tensor

Appends the true labels and the predicted values to the class-global lists y_true and y_pred.

Parameters

y_true : torch.tensor

True labels

logits : torch.tensor

Model output logits (unnormalized class scores) from which the predicted labels are obtained.
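A minimal sketch of the accumulation step that forward describes. It assumes the predicted label is the argmax of the logits over the class dimension; the module-level lists and the function name accumulate are hypothetical stand-ins for the class-global y_true and y_pred lists, and the sketch does not reproduce whatever tensor forward returns.

```python
import torch

# Hypothetical stand-ins for the module's class-global lists.
y_true_store: list[torch.Tensor] = []
y_pred_store: list[torch.Tensor] = []


def accumulate(y_true: torch.Tensor, logits: torch.Tensor) -> None:
    """Store a batch of true labels and the labels predicted from its logits."""
    # Assumption: the predicted class is the index of the largest logit.
    y_pred = logits.argmax(dim=-1)
    # Detach so stored predictions do not keep the autograd graph alive.
    y_true_store.append(y_true.detach().cpu())
    y_pred_store.append(y_pred.detach().cpu())


# Two batches of two samples each, three classes.
accumulate(torch.tensor([0, 2]), torch.tensor([[2.0, 0.1, 0.3], [0.2, 0.5, 1.5]]))
accumulate(torch.tensor([1, 1]), torch.tensor([[0.1, 3.0, 0.2], [1.0, 0.0, 0.4]]))

# At metric time the per-batch tensors are concatenated into one long vector.
all_true = torch.cat(y_true_store)
all_pred = torch.cat(y_pred_store)
print(all_true.tolist())  # [0, 2, 1, 1]
print(all_pred.tolist())  # [0, 2, 1, 0]
```

Storing per-batch tensors and concatenating once at the end avoids repeatedly reallocating a growing tensor during evaluation.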

_micro_avg_precision(y_true: torch.tensor, y_pred: torch.tensor) → torch.tensor

Compute micro-averaged precision by first counting true and false positives across all classes and then computing the precision from the pooled counts.

Parameters

y_true : torch.tensor

True labels

y_pred : torch.tensor

Predicted labels

Returns

torch.tensor

Micro-averaged precision
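A minimal sketch of micro-averaging as described above, written as a standalone function. The name micro_avg_precision and the explicit num_classes argument are hypothetical (inside the module the class count would come from num_classes); this is an illustration, not the module's actual code.

```python
import torch


def micro_avg_precision(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Pool true and false positives over all classes, then divide once."""
    tp = torch.tensor(0.0)
    fp = torch.tensor(0.0)
    for c in range(num_classes):
        predicted_c = y_pred == c                    # samples predicted as class c
        tp += (predicted_c & (y_true == c)).sum()    # ...that really are class c
        fp += (predicted_c & (y_true != c)).sum()    # ...that belong to another class
    return tp / (tp + fp)


y_true = torch.tensor([0, 2, 1, 1])
y_pred = torch.tensor([0, 2, 1, 0])
p_micro = micro_avg_precision(y_true, y_pred, num_classes=3)
accuracy = (y_pred == y_true).float().mean()
print(p_micro.item())  # 0.75
```

In single-label multiclass classification every prediction is a positive for exactly one class, so the pooled denominator TP + FP equals the number of samples and micro-averaged precision equals accuracy, as the example shows.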

_macro_avg_precision(y_true: torch.tensor, y_pred: torch.tensor) → torch.tensor

Compute macro-averaged precision by computing the precision of each class separately from its own true and false positives, then averaging these per-class values across all classes.

Parameters

y_true : torch.tensor

True labels

y_pred : torch.tensor

Predicted labels

Returns

torch.tensor

Macro-averaged precision
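A minimal sketch of macro-averaging as described above. The function name macro_avg_precision is hypothetical, and scoring a never-predicted class (a 0/0 precision) as 0 is an assumption; the module may treat that case differently.

```python
import torch


def macro_avg_precision(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Compute precision per class, then take the unweighted mean."""
    per_class = torch.zeros(num_classes)
    for c in range(num_classes):
        predicted_c = y_pred == c
        tp = (predicted_c & (y_true == c)).sum()
        fp = (predicted_c & (y_true != c)).sum()
        if tp + fp > 0:
            per_class[c] = tp / (tp + fp)
        # else: class c was never predicted; its 0/0 precision stays 0 here
        # (an assumption, not necessarily the module's behaviour).
    return per_class.mean()


y_true = torch.tensor([0, 2, 1, 1])
y_pred = torch.tensor([0, 2, 1, 0])
p3 = macro_avg_precision(y_true, y_pred, num_classes=3)  # (0.5 + 1.0 + 1.0) / 3
p4 = macro_avg_precision(y_true, y_pred, num_classes=4)  # class 3 never predicted
print(p4.item())  # 0.625
```

On the same data micro-averaging gives 0.75, while macro-averaging gives about 0.833 because class 0's lower precision (0.5) carries only one third of the weight. Note also how the macro result depends on num_classes when some classes are never predicted.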

__returnmetric__()

Return the micro- or macro-averaged precision, depending on macro_averaging.

Returns

torch.tensor

Micro- or macro-averaged precision
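A minimal sketch of the dispatch that __returnmetric__ describes: concatenate the stored batches, then return the micro- or macro-averaged value. The function return_metric and its arguments are hypothetical. It counts per-class totals with torch.bincount instead of a loop, and it assumes never-predicted classes score 0 under macro-averaging.

```python
import torch


def return_metric(y_true_batches, y_pred_batches, num_classes, macro_averaging=False):
    """Concatenate stored batches and return micro- or macro-averaged precision."""
    y_true = torch.cat(y_true_batches)
    y_pred = torch.cat(y_pred_batches)
    # TP per class: bincount the class index of every correct prediction.
    tp = torch.bincount(y_pred[y_pred == y_true], minlength=num_classes).float()
    # TP + FP per class is simply how often each class was predicted.
    predicted = torch.bincount(y_pred, minlength=num_classes).float()
    if macro_averaging:
        # Never-predicted classes have tp == 0, so clamping the denominator
        # to 1 scores them 0 (an assumption about the 0/0 case).
        return (tp / predicted.clamp(min=1)).mean()
    return tp.sum() / predicted.sum()


true_batches = [torch.tensor([0, 2]), torch.tensor([1, 1])]
pred_batches = [torch.tensor([0, 2]), torch.tensor([1, 0])]
micro = return_metric(true_batches, pred_batches, num_classes=3)
macro = return_metric(true_batches, pred_batches, num_classes=3, macro_averaging=True)
print(micro.item())  # 0.75
```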

__reset__()

Resets the class-global lists of true and predicted values to empty lists.

Returns

None

Returns None
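Because the lists keep growing across calls, the metric has to be reset between evaluation runs (for example, between epochs). The sketch below is a compact, hypothetical reimplementation (PrecisionSketch is an invented name, not the library's class) that reuses the documented method names to show that lifecycle. It assumes predicted labels are the argmax of the logits and that never-predicted classes score 0 under macro-averaging.

```python
import torch
from torch import nn


class PrecisionSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not the library's code."""

    def __init__(self, num_classes: int, macro_averaging: bool = False):
        super().__init__()
        self.num_classes = num_classes
        self.macro_averaging = macro_averaging
        # Per-instance lists, so separate metric objects never share state.
        self.y_true: list[torch.Tensor] = []
        self.y_pred: list[torch.Tensor] = []

    def forward(self, y_true: torch.Tensor, logits: torch.Tensor) -> None:
        # Assumption: predicted label = argmax over the class dimension.
        self.y_true.append(y_true.detach())
        self.y_pred.append(logits.argmax(dim=-1).detach())

    def __returnmetric__(self) -> torch.Tensor:
        y_true = torch.cat(self.y_true)
        y_pred = torch.cat(self.y_pred)
        tp = torch.bincount(y_pred[y_pred == y_true], minlength=self.num_classes).float()
        predicted = torch.bincount(y_pred, minlength=self.num_classes).float()
        if self.macro_averaging:
            return (tp / predicted.clamp(min=1)).mean()
        return tp.sum() / predicted.sum()

    def __reset__(self) -> None:
        # Drop everything accumulated so far.
        self.y_true = []
        self.y_pred = []


logits_a = torch.tensor([[2.0, 0.1, 0.3], [0.2, 0.5, 1.5]])  # argmax -> [0, 2]
logits_b = torch.tensor([[0.1, 3.0, 0.2], [1.0, 0.0, 0.4]])  # argmax -> [1, 0]
logits_c = torch.tensor([[0.9, 0.1, 0.0], [0.2, 0.7, 0.1]])  # argmax -> [0, 1]

metric = PrecisionSketch(num_classes=3)
metric(torch.tensor([0, 2]), logits_a)
metric(torch.tensor([1, 1]), logits_b)
epoch_1 = metric.__returnmetric__()  # 3 of 4 predictions correct -> 0.75
metric.__reset__()                   # start the next epoch with empty lists
metric(torch.tensor([0, 1]), logits_c)
epoch_2 = metric.__returnmetric__()  # both predictions correct -> 1.0
```

Without the __reset__() call, the second value would be computed over all six samples from both runs instead of only the two from the second run.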