CollaborativeCoding.metrics.precision
Classes
Metric module for precision. Can calculate both the micro- and macro-averaged precision.
Module Contents
- class CollaborativeCoding.metrics.precision.Precision(num_classes: int, macro_averaging: bool = False)
Bases: torch.nn.Module
Metric module for precision. Can calculate both the micro- and macro-averaged precision.
Parameters
- num_classes (int)
Number of classes in the dataset.
- macro_averaging (bool)
Performs macro-averaging if True, otherwise micro-averaging.
- num_classes
- macro_averaging = False
- y_true = []
- y_pred = []
- forward(y_true: torch.tensor, logits: torch.tensor) → torch.tensor
Append the true and predicted values to the y_true and y_pred lists held by the module.
Parameters
- y_true (torch.tensor)
True labels
- logits (torch.tensor)
Model outputs (logits) from which the predicted labels are taken
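The accumulation step described for forward can be sketched as follows. This is a minimal illustration and not the module's actual code: LabelAccumulator and argmax are hypothetical names, plain Python lists stand in for torch tensors, and it is an assumption that predicted labels are obtained by taking the argmax over each row of logits.

```python
def argmax(row):
    """Index of the largest value in a list of scores (hypothetical helper)."""
    return max(range(len(row)), key=row.__getitem__)


class LabelAccumulator:
    """Hypothetical stand-in that stores labels batch by batch."""

    def __init__(self):
        self.y_true = []
        self.y_pred = []

    def update(self, y_true, logits):
        # Assumption: the predicted label is the argmax over each row of logits.
        self.y_true.extend(y_true)
        self.y_pred.extend(argmax(row) for row in logits)


acc = LabelAccumulator()
acc.update([0, 1], [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5]])  # first batch
acc.update([2], [[0.0, 0.1, 0.9]])                        # second batch
print(acc.y_true, acc.y_pred)
```

Storing labels across batches lets the precision be computed once over the whole dataset instead of averaging per-batch values.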
- _micro_avg_precision(y_true: torch.tensor, y_pred: torch.tensor) → torch.tensor
Compute micro-averaged precision by first summing the true and false positives across all classes and then computing the precision from those totals.
Parameters
- y_true (torch.tensor)
True labels
- y_pred (torch.tensor)
Predicted labels
Returns
- torch.tensor
Micro-averaged precision
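As a rough sketch of the micro-averaging described above, the function below pools true and false positives over all classes before dividing. It is an illustration only: micro_avg_precision is a hypothetical free function, plain lists stand in for torch tensors, and returning 0.0 when there are no predictions is an assumption.

```python
def micro_avg_precision(y_true, y_pred, num_classes):
    """Micro-averaged precision: sum(TP) / (sum(TP) + sum(FP)) over all classes."""
    tp_total = 0
    fp_total = 0
    for c in range(num_classes):
        # True positives: predicted c and actually c.
        tp_total += sum(1 for t, p in zip(y_true, y_pred) if p == c and t == c)
        # False positives: predicted c but actually another class.
        fp_total += sum(1 for t, p in zip(y_true, y_pred) if p == c and t != c)
    denom = tp_total + fp_total
    # Assumption: define precision as 0.0 when nothing was predicted.
    return tp_total / denom if denom > 0 else 0.0


y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]
micro = micro_avg_precision(y_true, y_pred, num_classes=3)
print(round(micro, 4))  # 0.6667 (4 true positives, 2 false positives)
```

For single-label multiclass data every prediction is either a true or a false positive for some class, so this value coincides with plain accuracy.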
- _macro_avg_precision(y_true: torch.tensor, y_pred: torch.tensor) → torch.tensor
Compute macro-averaged precision by finding the true and false positives of each class separately, computing the precision per class, and then averaging across all classes.
Parameters
- y_true (torch.tensor)
True labels
- y_pred (torch.tensor)
Predicted labels
Returns
- torch.tensor
Macro-averaged precision
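The macro-averaging described above can be sketched in the same way: precision is computed per class and the per-class values are then averaged with equal weight. This is an illustration, not the module's code: macro_avg_precision is a hypothetical free function, plain lists stand in for torch tensors, and assigning a precision of 0.0 to classes that are never predicted is an assumption.

```python
def macro_avg_precision(y_true, y_pred, num_classes):
    """Macro-averaged precision: unweighted mean of the per-class precisions."""
    per_class = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if p == c and t == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if p == c and t != c)
        # Assumption: a class that is never predicted contributes 0.0.
        per_class.append(tp / (tp + fp) if tp + fp > 0 else 0.0)
    return sum(per_class) / num_classes


y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]
macro = macro_avg_precision(y_true, y_pred, num_classes=3)
print(round(macro, 4))  # 0.7222 (per-class precisions 1, 1/2, 2/3)
```

Because each class counts equally, rare classes influence the macro average as much as frequent ones, whereas the micro average is dominated by classes that are predicted often.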