Source code for quaterion.eval.base_metric
from typing import Tuple
import torch
from quaterion.distances import Distance
class BaseMetric:
    """Base class for evaluation metrics

    Provides a default implementation for distance matrix calculation.

    Args:
        distance_metric_name: name of a distance metric to calculate distance or similarity
            matrices. Available names could be found in :class:`~quaterion.distances.Distance`.
    """

    def __init__(
        self,
        distance_metric_name: Distance = Distance.COSINE,
    ):
        # Resolve the configured name into the distance object used by `precompute`.
        self.distance = Distance.get_by_name(distance_metric_name)
        self._distance_metric_name = distance_metric_name

    def compute(self, *args, **kwargs) -> torch.Tensor:
        """Compute metric value

        Must be implemented by subclasses.

        Args:
            args, kwargs - contain embeddings and targets required to compute metric.

        Returns:
            torch.Tensor - computed metric

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def evaluate(self) -> torch.Tensor:
        """Perform metric computation with accumulated state

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def raw_compute(
        self, distance_matrix: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Perform metric computation on ready distance_matrix and labels

        This method does not make any data and labels preparation.
        It is assumed that `distance_matrix` has already been calculated, required changes
        such as masking the distance from an element to itself have already been applied,
        and corresponding `labels` have been prepared.

        Args:
            distance_matrix: distance matrix ready for metric computation
            labels: labels ready for metric computation with the same shape as
                `distance_matrix`. For `PairMetric` values are taken from
                `SimilarityPairSample.score`, for `GroupMetric` the possible values
                are in {0, 1}.

        Returns:
            torch.Tensor - calculated metric value

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def precompute(
        self,
        embeddings: torch.Tensor,
        **targets,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepares data for computation

        Compute distance matrix and final labels based on groups.

        The diagonal of the distance matrix (distance from each element to itself) is
        overwritten with a value strictly greater than every other entry.

        Args:
            embeddings: embeddings to compute metric value
            targets: objects to compute final labels, forwarded to `prepare_labels`

        Returns:
            torch.Tensor, torch.Tensor - labels (cast to float) and distance matrix
        """
        labels = self.prepare_labels(**targets)
        # Detached so the returned matrix carries no autograd history.
        distance_matrix = self.distance.distance_matrix(embeddings).detach()

        # torch.max raises on an empty tensor, so skip masking when there is
        # nothing to mask (e.g. empty `embeddings`).
        if distance_matrix.numel() > 0:
            # Build the mask on the matrix's own device so CUDA matrices are not
            # indexed with a CPU mask.
            self_mask = torch.eye(
                distance_matrix.shape[0],
                dtype=torch.bool,
                device=distance_matrix.device,
            )
            # max + 1 makes each self-distance larger than any real distance.
            distance_matrix[self_mask] = torch.max(distance_matrix) + 1

        return labels.float(), distance_matrix

    @staticmethod
    def prepare_labels(**targets) -> torch.Tensor:
        """Compute metric labels

        Must be implemented by subclasses.

        Args:
            **targets: objects to compute final labels. `**targets` in PairMetric consists
                of `labels`, `pairs` and `subgroups`, in GroupMetric - of `groups`.

        Returns:
            torch.Tensor - labels to be used during metric computation

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()