CollaborativeCoding
===================

.. py:module:: CollaborativeCoding


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/CollaborativeCoding/arg_parser/index
   /autoapi/CollaborativeCoding/createfolders/index
   /autoapi/CollaborativeCoding/dataloaders/index
   /autoapi/CollaborativeCoding/load_data/index
   /autoapi/CollaborativeCoding/load_metric/index
   /autoapi/CollaborativeCoding/load_model/index
   /autoapi/CollaborativeCoding/metrics/index
   /autoapi/CollaborativeCoding/models/index


Classes
-------

.. autoapisummary::

   CollaborativeCoding.MetricWrapper


Functions
---------

.. autoapisummary::

   CollaborativeCoding.get_args
   CollaborativeCoding.createfolders
   CollaborativeCoding.load_data
   CollaborativeCoding.load_model


Package Contents
----------------

.. py:function:: get_args()


.. py:function:: createfolders(*dirs: pathlib.Path) -> None

   Creates folders for storing data, results, and model weights.

   Parameters
   ----------
   *dirs : pathlib.Path
       Paths of the folders to create.


.. py:function:: load_data(dataset: str, *args, **kwargs) -> tuple

   Load a dataset by name.

   Args
   ----
   dataset : str
       Name of the dataset to load.
   *args : list
       Additional positional arguments for the dataset class.
   **kwargs : dict
       Additional keyword arguments for the dataset class.

   Returns
   -------
   tuple
       The train, validation, and test datasets, in that order.

   Raises
   ------
   NotImplementedError
       If the dataset is not implemented.

   Examples
   --------
   >>> from CollaborativeCoding import load_data
   >>> train, val, test = load_data("usps_0-6", data_path="data", train=True, download=True)
   >>> len(train), len(val), len(test)
   (4914, 546, 1782)


.. py:class:: MetricWrapper(*metrics, num_classes, macro_averaging=False, **kwargs)

   Bases: :py:obj:`torch.nn.Module`


   A wrapper class for evaluating several metrics on the same data. It computes all requested metrics at once from the given true labels and predictions.
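
   A minimal sketch of such a wrapper, assuming metrics are computed once over all accumulated labels rather than averaged per batch. ``SimpleMetricWrapper``, ``accuracy``, and ``macro_precision`` are hypothetical names for illustration, not this package's implementation, and plain Python lists stand in for tensors so the sketch runs without torch.

   .. code-block:: python

      from typing import Callable, Dict, List, Optional, Sequence


      def accuracy(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
          """Fraction of correct predictions (num_classes unused, kept for a uniform signature)."""
          if not y_true:
              return float("nan")
          return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


      def macro_precision(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
          """Unweighted mean of per-class precision; classes that are never predicted are skipped."""
          scores = []
          for c in range(num_classes):
              # For every sample predicted as class c, record whether it truly is class c.
              hits = [t == c for t, p in zip(y_true, y_pred) if p == c]
              if hits:
                  scores.append(sum(hits) / len(hits))
          return sum(scores) / len(scores) if scores else float("nan")


      class SimpleMetricWrapper:
          """Hypothetical stand-in: stores labels across calls and evaluates every metric on demand."""

          # Hypothetical name -> function table; the real package's metric names may differ.
          REGISTRY: Dict[str, Callable[[Sequence[int], Sequence[int], int], float]] = {
              "accuracy": accuracy,
              "precision": macro_precision,
          }

          def __init__(self, *metrics: str, num_classes: int) -> None:
              unknown = [m for m in metrics if m not in self.REGISTRY]
              if unknown:
                  raise NotImplementedError(f"Unknown metric(s): {unknown}")
              self.metrics = {m: self.REGISTRY[m] for m in metrics}
              self.num_classes = num_classes
              self.resetmetric()

          def __call__(self, y_true: Sequence[int], y_logits: Sequence[Sequence[float]]) -> None:
              # Argmax over each row of logits gives the predicted class index.
              preds = [max(range(len(row)), key=row.__getitem__) for row in y_logits]
              self._y_true.extend(y_true)
              self._y_pred.extend(preds)

          def getmetrics(self, str_prefix: Optional[str] = None) -> Dict[str, float]:
              # Evaluate each metric over everything accumulated since the last reset.
              prefix = str_prefix or ""
              return {
                  prefix + name: fn(self._y_true, self._y_pred, self.num_classes)
                  for name, fn in self.metrics.items()
              }

          def resetmetric(self) -> None:
              # Drop all accumulated labels and predictions.
              self._y_true: List[int] = []
              self._y_pred: List[int] = []


      wrapper = SimpleMetricWrapper("accuracy", "precision", num_classes=2)
      wrapper([0, 1, 0, 1], [[0.8, -1.9], [0.1, 9.0], [-1.9, -0.1], [1.9, 1.8]])
      print(wrapper.getmetrics("val_"))  # {'val_accuracy': 0.5, 'val_precision': 0.5}

   Here the argmax predictions are ``[0, 1, 1, 0]``, so both metrics come out to 0.5. Storing raw labels and computing each metric once over all samples avoids the bias of averaging per-batch scores when batch sizes differ.
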

   It supports a variety of common metrics and provides methods to accumulate results and reset the state.

   Args
   ----
   *metrics : str
       Names of the metrics to evaluate.
   num_classes : int
       The number of classes in the classification task.
   macro_averaging : bool, default False
       Whether to compute macro-averaged metrics for multi-class classification.

   Attributes
   ----------
   metrics : dict
       A dictionary mapping metric names to their corresponding functions.
   num_classes : int
       The number of classes for the classification task.

   Methods
   -------
   __call__(y_true, y_pred)
       Passes the true labels and predicted logits to the metric functions.
   getmetrics(str_prefix: str = None)
       Retrieves the dictionary of computed metrics; if ``str_prefix`` is given, it is prepended to every key.
   resetmetric()
       Resets the state of all metric computations.

   Examples
   --------
   >>> from CollaborativeCoding import MetricWrapper
   >>> metrics = MetricWrapper("entropy", "f1", "precision", num_classes=2)
   >>> y_true = [0, 1, 0, 1]
   >>> y_pred = [[0.8, -1.9], [0.1, 9.0], [-1.9, -0.1], [1.9, 1.8]]
   >>> metrics(y_true, y_pred)
   >>> metrics.getmetrics()
   {'entropy': 0.3292665, 'f1': 0.5, 'precision': 0.5}
   >>> metrics.resetmetric()
   >>> metrics.getmetrics()
   {'entropy': [], 'f1': [], 'precision': []}


   .. py:attribute:: metrics


   .. py:attribute:: params


   .. py:method:: _get_metric(key)

      Retrieves the metric function based on the provided key.

      Args
      ----
      key : str
          The name of the metric.

      Returns
      -------
      callable
          The function that computes the metric.


   .. py:method:: __call__(y_true, y_pred)


   .. py:method:: getmetrics(str_prefix: str = None)


   .. py:method:: resetmetric()


.. py:function:: load_model(modelname: str, *args, **kwargs) -> torch.nn.Module

   Load the model based on the model name.

   Args
   ----
   modelname : str
       Name of the model to load.
   *args : list
       Additional positional arguments for the model class.
   **kwargs : dict
       Additional keyword arguments for the model class.

   Returns
   -------
   model : torch.nn.Module
       The instantiated model.
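
   A minimal sketch of how a name-based loader like this can work, assuming a plain dictionary registry. ``MODEL_REGISTRY``, ``TinyModel``, and ``load_model_sketch`` are hypothetical names for illustration, not part of the package, and torch is left out so the sketch runs standalone.

   .. code-block:: python

      from typing import Any, Callable, Dict


      class TinyModel:
          """Hypothetical placeholder for a torch.nn.Module subclass."""

          def __init__(self, num_classes: int, hidden: int = 100) -> None:
              self.num_classes = num_classes
              self.hidden = hidden


      # Hypothetical registry mapping model names to the classes that build them.
      MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {"tinymodel": TinyModel}


      def load_model_sketch(modelname: str, *args: Any, **kwargs: Any) -> Any:
          """Instantiate the model registered under modelname, forwarding all arguments."""
          if modelname not in MODEL_REGISTRY:
              # Unknown names raise NotImplementedError, matching the documented behaviour.
              raise NotImplementedError(f"Model {modelname!r} is not implemented")
          return MODEL_REGISTRY[modelname](*args, **kwargs)


      model = load_model_sketch("tinymodel", num_classes=10)
      print(type(model).__name__, model.num_classes)  # TinyModel 10

   Keeping the mapping in one dictionary makes registering a new model a one-line change, and any unregistered name fails loudly instead of silently returning a default.
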

   Raises
   ------
   NotImplementedError
       If the model is not implemented.

   Examples
   --------
   >>> from CollaborativeCoding import load_model
   >>> model = load_model("magnusmodel", num_classes=10)
   >>> model
   MagnusModel(
     (fc1): Linear(in_features=784, out_features=100, bias=True)
     (fc2): Linear(in_features=100, out_features=10, bias=True