CollaborativeCoding.metrics.recall

Classes

Recall

Recall metric.

Functions

one_hot_encode(vec, num_classes)

One-hot encode the target tensor.

Module Contents

CollaborativeCoding.metrics.recall.one_hot_encode(vec, num_classes)

One-hot encode the target tensor.

Args

vec : torch.Tensor

Target tensor.

num_classes : int

Number of classes in the dataset.

Returns

torch.Tensor

One-hot encoded tensor.
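
To make the transformation concrete, here is a minimal plain-Python sketch of one-hot encoding. It is an illustration of the idea only, not this module's implementation: the helper name `one_hot_encode_sketch` is hypothetical, and it works on lists rather than `torch.Tensor`. For tensors, PyTorch provides `torch.nn.functional.one_hot`, which performs the same mapping.

```python
def one_hot_encode_sketch(labels, num_classes):
    """Map each integer label to a row with a single 1 at the label's index.

    Hypothetical helper for illustration; operates on plain lists.
    """
    return [
        [1 if column == label else 0 for column in range(num_classes)]
        for label in labels
    ]


# Three samples with labels 0, 2 and 1 in a 3-class problem.
encoded = one_hot_encode_sketch([0, 2, 1], num_classes=3)
print(encoded)  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```

Each output row has length `num_classes`, so a target of shape `(N,)` becomes an encoding of shape `(N, num_classes)`.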

class CollaborativeCoding.metrics.recall.Recall(num_classes, macro_averaging=False)

Bases: torch.nn.Module

Recall metric.

Args

num_classes : int

Number of classes in the dataset.

macro_averaging : bool

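If True, compute recall separately for each class and return the unweighted mean of the per-class values (macro averaging). If False, compute a single recall over all samples in the dataset (micro averaging). Defaults to False.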
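
The difference between the two averaging modes can be sketched in plain Python. This is a worked illustration under stated assumptions, not the class's actual code: the function name `recall_scores` is hypothetical, labels are plain integer lists, and a class with no true samples is assumed to contribute a recall of 0.0 to the macro average (the real implementation may handle that case differently).

```python
def recall_scores(y_true, y_pred, num_classes):
    """Return (micro, macro) recall for integer class labels.

    Hypothetical helper for illustration only.
    """
    tp = [0] * num_classes  # true positives per class
    fn = [0] * num_classes  # false negatives per class
    for true_label, pred_label in zip(y_true, y_pred):
        if true_label == pred_label:
            tp[true_label] += 1
        else:
            fn[true_label] += 1

    # Micro averaging: pool the counts over all classes, then divide once.
    total = sum(tp) + sum(fn)
    micro = sum(tp) / total if total else 0.0

    # Macro averaging: recall per class, then the unweighted mean.
    # Assumption: a class with no true samples counts as 0.0.
    per_class = [
        tp[c] / (tp[c] + fn[c]) if (tp[c] + fn[c]) else 0.0
        for c in range(num_classes)
    ]
    macro = sum(per_class) / num_classes
    return micro, macro


# Imbalanced example: class 0 has four samples, classes 1 and 2 have one each.
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 1, 1, 0]
micro, macro = recall_scores(y_true, y_pred, num_classes=3)
print(round(micro, 4), round(macro, 4))  # 0.6667 0.5833
```

Here the per-class recalls are 0.75, 1.0 and 0.0. The micro value (4 correct out of 6) is dominated by the frequent class, while the macro value weights every class equally, which is why the two can differ on imbalanced data.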

Methods

forward(y_true, y_pred)

Compute the recall metric.

Examples

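>>> import torch
>>> from CollaborativeCoding.metrics.recall import Recall
>>> y_true = torch.tensor([0, 1, 2, 3, 4])
>>> # Random predictions, so the values shown below vary between runs.
>>> y_pred = torch.randn(5, 5).argmax(dim=-1)
>>> recall = Recall(num_classes=5)
>>> recall(y_true, y_pred)
0.2
>>> recall = Recall(num_classes=5, macro_averaging=True)
>>> recall(y_true, y_pred)
0.2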
num_classes
macro_averaging = False
__y_true = []
__y_pred = []
forward(true, logits)
compute(y_true, y_pred)
__compute_macro_averaging(y_true, y_pred)
__compute_micro_averaging(y_true, y_pred)
__returnmetric__()
__reset__()