quaterion.loss.triplet_loss module
- class TripletLoss(margin: Optional[float] = 1.0, distance_metric_name: Distance = Distance.COSINE, mining: Optional[str] = 'hard')
Bases: GroupLoss
Implements Triplet Loss as defined in https://arxiv.org/abs/1503.03832
It supports batch-all and batch-hard strategies for online triplet mining.
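- Parameters:
margin (Optional[float]) – Margin enforced between anchor-positive and anchor-negative distances. Default: 1.0
distance_metric_name (Distance) – Distance metric used to compare embeddings. Default: Distance.COSINE
mining (Optional[str]) – Online triplet mining strategy. Default: 'hard'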
- forward(embeddings: Tensor, groups: LongTensor) → Tensor
Calculates Triplet Loss with specified embeddings and labels.
- Parameters:
embeddings (Tensor) – Batch of embeddings. Shape: (batch_size, embedding_dim)
groups (torch.LongTensor) – Batch of labels associated with embeddings. Shape: (batch_size,)
- Returns:
torch.Tensor – Scalar loss value.
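To make the computation concrete, here is a minimal pure-Python sketch of batch-hard triplet mining with cosine distance. It is not Quaterion's implementation: the function names `cosine_distance` and `batch_hard_triplet_loss` are hypothetical, plain lists stand in for tensors, and averaging over anchors is an assumed reduction. For each anchor it picks the farthest same-group sample and the closest other-group sample, then applies max(d(a, p) - d(a, n) + margin, 0).

```python
import math


def cosine_distance(u, v):
    """Cosine distance between two non-zero vectors: 1 - cos(u, v)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def batch_hard_triplet_loss(embeddings, groups, margin=1.0):
    """Illustrative batch-hard triplet loss (assumed mean reduction).

    embeddings: list of vectors, length batch_size
    groups: list of integer labels, length batch_size
    """
    losses = []
    for i, anchor in enumerate(embeddings):
        # Distances to every other sample in the same group (positives).
        positives = [
            cosine_distance(anchor, other)
            for j, other in enumerate(embeddings)
            if j != i and groups[j] == groups[i]
        ]
        # Distances to every sample in a different group (negatives).
        negatives = [
            cosine_distance(anchor, other)
            for j, other in enumerate(embeddings)
            if groups[j] != groups[i]
        ]
        if not positives or not negatives:
            continue  # this anchor cannot form a valid triplet
        # Hardest positive is the farthest; hardest negative is the closest.
        losses.append(max(max(positives) - min(negatives) + margin, 0.0))
    # Average over anchors that formed a triplet; 0.0 if none did.
    return sum(losses) / len(losses) if losses else 0.0


embeddings = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
groups = [0, 0, 1, 1]
loss = batch_hard_triplet_loss(embeddings, groups, margin=1.5)
```

In this toy batch every positive pair is at distance 0 and every negative pair at distance 1, so the loss is zero once the margin is at most 1 and grows linearly with the margin beyond that.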
- get_config_dict()
Returns the config used for saving and loading the loss.
The config object must be JSON-serializable.
- Returns:
Dict[str, Any] – JSON-serializable dict of params
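The JSON-serializability requirement matters mostly for non-primitive values such as enum members. The sketch below illustrates the idea with stand-ins only: the `Distance` enum defined here, the free function `build_config`, and the dictionary keys are all hypothetical and do not reflect the actual contents of the config that Quaterion produces.

```python
import json
from enum import Enum
from typing import Any, Dict


class Distance(Enum):
    """Hypothetical stand-in for a distance enum, not quaterion's Distance."""

    COSINE = "cosine"


def build_config(margin: float, distance: Distance, mining: str) -> Dict[str, Any]:
    """Build a JSON-serializable dict; the key names are assumptions."""
    return {
        "margin": margin,
        # Store the enum's plain string value, since json.dumps cannot
        # encode a plain Enum member.
        "distance_metric_name": distance.value,
        "mining": mining,
    }


config = build_config(1.0, Distance.COSINE, "hard")
# Round trip through JSON, as a save/load cycle would do.
restored = json.loads(json.dumps(config))
# The enum member can be rebuilt from the stored string.
restored_distance = Distance(restored["distance_metric_name"])
```

Passing the enum member itself to `json.dumps` raises `TypeError`, which is why the sketch stores its string value instead.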
- training: bool