Source code for vak.metrics.classification.functional
import torch
[docs]
def accuracy(y_pred, y_true):
"""standard supervised learning classification accuracy:
Sum of predicted labels that equal true labels, divided by number of true labels.
Parameters
----------
y_pred : torch.Tensor
y_true : torch.Tensor
Returns
-------
acc : float
between 0 and 1. Sum of predicted labels that equal true labels,
divided by number of true labels.
"""
correct = torch.eq(y_pred, y_true).view(-1)
return correct.sum().item() / correct.shape[0]