CollaborativeCoding.metrics.recall
Classes
- Recall: Recall metric.
Functions
- one_hot_encode(vec, num_classes): One-hot encode the target tensor.
Module Contents
- CollaborativeCoding.metrics.recall.one_hot_encode(vec, num_classes)
One-hot encode the target tensor.
Args
- vec : torch.Tensor
Target tensor.
- num_classes : int
Number of classes in the dataset.
Returns
- torch.Tensor
One-hot encoded tensor.
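The docstring does not show the function body. To illustrate what one-hot encoding produces, here is a minimal pure-Python sketch. The name `one_hot_encode_sketch` is hypothetical, and it uses lists instead of `torch.Tensor` so that it runs without PyTorch. In PyTorch, `torch.nn.functional.one_hot(tensor, num_classes)` builds the same row layout for an integer tensor.

```python
def one_hot_encode_sketch(vec, num_classes):
    """Hypothetical helper: map each integer label to a row containing a single 1."""
    rows = []
    for label in vec:
        # Reject labels that have no column in a num_classes-wide row.
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside [0, {num_classes})")
        row = [0] * num_classes
        row[label] = 1  # mark the column of the true class
        rows.append(row)
    return rows


print(one_hot_encode_sketch([0, 2, 1], num_classes=3))
# [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```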
- class CollaborativeCoding.metrics.recall.Recall(num_classes, macro_averaging=False)
Bases: torch.nn.Module
Recall metric.
Args
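- num_classes : int
Number of classes in the dataset.
- macro_averaging : bool
If True, compute the recall of each class separately and return their unweighted mean (macro averaging). If False, compute a single recall from the counts pooled over the entire dataset (micro averaging). Defaults to False.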
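Under the usual definitions for single-label multiclass data, the two settings can be written as below, where TP_c and FN_c are the true positives and false negatives of class c, and C equals num_classes. The macro formula assumes that every class occurs in the targets so that each Recall_c is defined. The docstring does not say how the implementation handles absent classes.

```latex
\mathrm{Recall}_c = \frac{TP_c}{TP_c + FN_c}, \qquad
\mathrm{Recall}_{\mathrm{macro}} = \frac{1}{C} \sum_{c=0}^{C-1} \mathrm{Recall}_c, \qquad
\mathrm{Recall}_{\mathrm{micro}} = \frac{\sum_{c} TP_c}{\sum_{c} \left( TP_c + FN_c \right)}
```

Because each sample has exactly one true class, the micro denominator equals the number of samples N. Micro recall therefore coincides with accuracy in this setting.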
Methods
- forward(y_true, y_pred)
Compute the recall metric.
Examples
>>> y_true = torch.tensor([0, 1, 2, 3, 4])
>>> y_pred = torch.randn(5, 5).argmax(dim=-1)  # random predictions, so the outputs below vary between runs
>>> recall = Recall(num_classes=5)
>>> recall(y_true, y_pred)
0.2
>>> recall = Recall(num_classes=5, macro_averaging=True)
>>> recall(y_true, y_pred)
0.2
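The following pure-Python sketch uses fixed predictions so that the two averaging modes can be checked by hand. The helper `recall_sketch` is hypothetical and is not this module's code. It assumes integer labels in [0, num_classes). As a design choice, the macro score is averaged only over classes that actually occur in y_true.

```python
def recall_sketch(y_true, y_pred, num_classes, macro_averaging=False):
    """Hypothetical helper: micro- or macro-averaged recall for integer labels."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    tp = [0] * num_classes  # true positives per class
    fn = [0] * num_classes  # false negatives per class
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fn[t] += 1
    if macro_averaging:
        # Recall of each class present in y_true, then the unweighted mean.
        per_class = [tp[c] / (tp[c] + fn[c]) for c in range(num_classes) if tp[c] + fn[c] > 0]
        return sum(per_class) / len(per_class)
    # Micro: pool the counts over all classes before dividing.
    return sum(tp) / (sum(tp) + sum(fn))


y_true = [0, 1, 2, 3, 4]
y_pred = [0, 0, 0, 0, 0]  # only the first sample is correct
print(recall_sketch(y_true, y_pred, num_classes=5))                        # 0.2
print(recall_sketch(y_true, y_pred, num_classes=5, macro_averaging=True))  # 0.2

# With imbalanced targets the two modes diverge.
print(recall_sketch([0, 0, 0, 1], [0, 0, 0, 0], num_classes=2))                        # 0.75
print(recall_sketch([0, 0, 0, 1], [0, 0, 0, 0], num_classes=2, macro_averaging=True))  # 0.5
```

When every class occurs exactly once, as in the docstring example, both modes reduce to the fraction of correct predictions. This is why that example shows the same value for both settings.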
- num_classes
- macro_averaging = False
- __y_true = []
- __y_pred = []
- forward(true, logits)
- compute(y_true, y_pred)
- __compute_macro_averaging(y_true, y_pred)
- __compute_micro_averaging(y_true, y_pred)
- __returnmetric__()
- __reset__()
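The private `__y_true` and `__y_pred` lists, together with `__returnmetric__()` and `__reset__()`, suggest that the class accumulates labels across batches and computes the score on demand. This reading is an assumption, because the listing gives no docstrings for these members. The pure-Python sketch below shows that accumulate-then-compute pattern. The class `StreamingRecallSketch` and its `update`, `compute` and `reset` methods are hypothetical stand-ins for `forward`, `__returnmetric__` and `__reset__`.

```python
class StreamingRecallSketch:
    """Hypothetical accumulator: store labels per batch, score them all at the end."""

    def __init__(self, num_classes, macro_averaging=False):
        self.num_classes = num_classes
        self.macro_averaging = macro_averaging
        self._y_true = []  # true labels from every batch seen so far
        self._y_pred = []  # predicted labels, aligned with _y_true

    def update(self, y_true, y_pred):
        # Assumed role of forward(): record one batch without scoring it.
        self._y_true.extend(y_true)
        self._y_pred.extend(y_pred)

    def compute(self):
        # Assumed role of __returnmetric__(): recall over all stored batches.
        if not self._y_true:
            raise ValueError("no batches recorded")
        tp = [0] * self.num_classes
        support = [0] * self.num_classes  # TP + FN per class
        for t, p in zip(self._y_true, self._y_pred):
            support[t] += 1
            tp[t] += int(t == p)
        if self.macro_averaging:
            per_class = [tp[c] / support[c] for c in range(self.num_classes) if support[c]]
            return sum(per_class) / len(per_class)
        return sum(tp) / sum(support)

    def reset(self):
        # Assumed role of __reset__(): forget stored batches, e.g. between epochs.
        self._y_true.clear()
        self._y_pred.clear()


metric = StreamingRecallSketch(num_classes=5)
metric.update([0, 1], [0, 0])        # batch 1
metric.update([2, 3, 4], [0, 0, 0])  # batch 2
print(metric.compute())              # 0.2
metric.reset()
```

Scoring the pooled labels once avoids averaging per-batch recalls. Averaging per-batch scores would weight a small final batch as heavily as a full one.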