CollaborativeCoding.load_metric
Attributes
- metrics
Classes
- MetricWrapper: A wrapper class for evaluating multiple metrics on the same dataset.
Module Contents
- class CollaborativeCoding.load_metric.MetricWrapper(*metrics, num_classes, macro_averaging=False, **kwargs)
Bases: torch.nn.Module
A wrapper class for evaluating multiple metrics on the same dataset. It computes several metrics at once from the true labels and the predicted logits, supports the metric names listed in the module attribute metrics (entropy, f1, recall, precision, accuracy), and provides methods to accumulate results and to reset the accumulated state.
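In the class signature, num_classes comes after *metrics, which makes it a keyword-only parameter in Python: metric names are collected positionally, and the class count must be passed by name. The stand-in function below is hypothetical; it only mirrors the parameter layout of MetricWrapper and performs no metric computation.

```python
def metric_wrapper_signature(*metrics: str, num_classes: int,
                             macro_averaging: bool = False, **kwargs):
    """Hypothetical stand-in with the same parameter layout as MetricWrapper."""
    return metrics, num_classes, macro_averaging


# Metric names are collected into the *metrics tuple; num_classes must be named.
print(metric_wrapper_signature("f1", "precision", num_classes=2))
# (('f1', 'precision'), 2, False)

# Passing the class count positionally turns it into a "metric name" instead,
# so the required keyword-only num_classes is missing and Python raises TypeError.
try:
    metric_wrapper_signature(2, "f1", "precision")
except TypeError as err:
    print(f"TypeError: {err}")
```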
Args
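- num_classes: int
The number of classes in the classification task. Keyword-only, since it follows *metrics in the signature.
- metrics: str
The names of the metrics to evaluate, passed as separate positional arguments (for example "f1", "precision").
- macro_averaging: bool, default False
Whether to compute macro-averaged metrics for multi-class classification.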
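As a sketch of what macro averaging usually means (the standard definition; whether MetricWrapper computes it exactly this way is an assumption), the hypothetical helper f1_scores below returns both the micro-averaged F1, which pools true positives, false positives and false negatives over all classes, and the macro-averaged F1, the unweighted mean of the per-class F1 scores.

```python
from collections import Counter


def f1_scores(y_true: list[int], y_pred: list[int],
              num_classes: int) -> tuple[float, float]:
    """Return (micro_f1, macro_f1) for single-label integer predictions."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted but is wrong
            fn[t] += 1  # class t was the truth but was missed

    def f1(tp_c: int, fp_c: int, fn_c: int) -> float:
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the class never appears.
        denom = 2 * tp_c + fp_c + fn_c
        return 2 * tp_c / denom if denom else 0.0

    # Micro: pool the counts over all classes, then compute a single F1.
    micro = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    # Macro: compute F1 per class, then take the unweighted mean.
    macro = sum(f1(tp[c], fp[c], fn[c]) for c in range(num_classes)) / num_classes
    return micro, macro


# Class 0 is frequent and always right; classes 1 and 2 are swapped.
micro, macro = f1_scores([0, 0, 0, 0, 1, 2], [0, 0, 0, 0, 2, 1], num_classes=3)
print(micro, macro)  # micro = 8/12, macro = 1/3
```

On this imbalanced input the two averages diverge: micro F1 is pulled up by the frequent, correctly predicted class 0, while macro F1 gives the two misclassified minority classes equal weight.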
Attributes
- metrics: dict
A dictionary mapping metric names to their corresponding functions.
- num_classes: int
The number of classes in the classification task.
Methods
- __call__(y_true, y_pred)
Passes the true labels and the predicted logits to each metric function.
- getmetrics(str_prefix: str = None)
Returns the dictionary of computed metrics. If str_prefix is given, it is added as a prefix to every key.
- resetmetric()
Resets the state of all metric computations.
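The accumulate, read and reset cycle described by these methods can be sketched in plain Python. RunningMetrics and accuracy below are hypothetical names, the sketch drops the torch.nn.Module base, and averaging the stored per-batch values is an assumed design choice; pooling raw counts across batches gives different numbers for metrics such as F1 when batch sizes differ, and which approach MetricWrapper takes is not stated here. Unlike the documented example, which shows empty lists after a reset, this sketch omits metrics that have no data.

```python
from typing import Callable, Optional

MetricFn = Callable[[list[int], list[int]], float]


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of positions where the predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


class RunningMetrics:
    """Hypothetical sketch of the __call__ / getmetrics / resetmetric pattern."""

    def __init__(self, **metric_fns: MetricFn) -> None:
        self.metric_fns = metric_fns
        self.values: dict[str, list[float]] = {}
        self.resetmetric()

    def __call__(self, y_true: list[int], y_pred: list[int]) -> None:
        # Evaluate every metric on this batch and store the batch result.
        for name, fn in self.metric_fns.items():
            self.values[name].append(fn(y_true, y_pred))

    def getmetrics(self, str_prefix: Optional[str] = None) -> dict[str, float]:
        # Average the stored batch results; metrics with no data are omitted.
        prefix = str_prefix or ""
        return {prefix + name: sum(vals) / len(vals)
                for name, vals in self.values.items() if vals}

    def resetmetric(self) -> None:
        # Discard everything accumulated so far.
        self.values = {name: [] for name in self.metric_fns}


running = RunningMetrics(accuracy=accuracy)
running([0, 1, 0, 1], [0, 1, 1, 0])  # batch accuracy 0.5
running([1, 1], [1, 1])              # batch accuracy 1.0
print(running.getmetrics("val_"))    # {'val_accuracy': 0.75}
running.resetmetric()
print(running.getmetrics())          # {}
```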
Examples
>>> from CollaborativeCoding.load_metric import MetricWrapper
>>> metrics = MetricWrapper("entropy", "f1", "precision", num_classes=2)
>>> y_true = [0, 1, 0, 1]
>>> y_pred = [[0.8, -1.9], [0.1, 9.0], [-1.9, -0.1], [1.9, 1.8]]
>>> metrics(y_true, y_pred)
>>> metrics.getmetrics()
{'entropy': 0.3292665, 'f1': 0.5, 'precision': 0.5}
>>> metrics.resetmetric()
>>> metrics.getmetrics()
{'entropy': [], 'f1': [], 'precision': []}
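The f1 and precision values of 0.5 in the example can be checked by hand, assuming the wrapper turns each row of logits into a label by taking its argmax (an assumption; this page does not say so). With that assumption the predictions are [0, 1, 1, 0], so each class is predicted twice and is right once. The helpers argmax_rows and macro_precision are hypothetical and only reproduce that arithmetic.

```python
def argmax_rows(logits: list[list[float]]) -> list[int]:
    # Index of the largest logit in each row (max keeps the first index on ties).
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def macro_precision(y_true: list[int], y_pred: list[int], num_classes: int) -> float:
    # Per class: correct predictions of c divided by all predictions of c.
    scores = []
    for c in range(num_classes):
        truths_when_predicted_c = [t for t, p in zip(y_true, y_pred) if p == c]
        if truths_when_predicted_c:
            scores.append(truths_when_predicted_c.count(c) / len(truths_when_predicted_c))
        else:
            scores.append(0.0)  # class never predicted
    return sum(scores) / num_classes


y_true = [0, 1, 0, 1]
y_pred = argmax_rows([[0.8, -1.9], [0.1, 9.0], [-1.9, -0.1], [1.9, 1.8]])
print(y_pred)                              # [0, 1, 1, 0]
print(macro_precision(y_true, y_pred, 2))  # 0.5
```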
- metrics
- params
- _get_metric(key)
Retrieves the metric function based on the provided key.
Args
- key: str
The name of the metric.
Returns
- metric: callable
The function that computes the metric.
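A common way to implement a name-to-function lookup like _get_metric is a dictionary registry that fails with a clear error for unknown names. The sketch below assumes that approach; METRIC_REGISTRY, get_metric and the error message are hypothetical, only accuracy is registered to keep it short, and the real method may handle unknown keys differently.

```python
from typing import Callable

MetricFn = Callable[[list[int], list[int]], float]


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of positions where the predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical registry. A complete version would also map "entropy", "f1",
# "recall" and "precision" to their functions.
METRIC_REGISTRY: dict[str, MetricFn] = {"accuracy": accuracy}


def get_metric(key: str) -> MetricFn:
    """Return the metric function registered under `key`."""
    try:
        return METRIC_REGISTRY[key]
    except KeyError:
        known = ", ".join(sorted(METRIC_REGISTRY))
        raise ValueError(f"Unknown metric {key!r}; expected one of: {known}") from None


metric = get_metric("accuracy")
print(metric([0, 1, 0, 1], [0, 1, 1, 0]))  # 0.5
```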
- __call__(y_true, y_pred)
- getmetrics(str_prefix: str = None)
- resetmetric()
- CollaborativeCoding.load_metric.metrics = ['entropy', 'f1', 'recall', 'precision', 'accuracy']