quaterion.loss.triplet_loss module
- class TripletLoss(margin: Optional[float] = 1.0, distance_metric_name: Distance = Distance.COSINE, mining: Optional[str] = 'hard')
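- Bases: GroupLoss
- Implements Triplet Loss as defined in https://arxiv.org/abs/1503.03832.
- Supports the batch-all and batch-hard strategies for online triplet mining.
- Parameters:
- margin (Optional[float]) – Margin by which the anchor-negative distance should exceed the anchor-positive distance. Default: 1.0.
- distance_metric_name (Distance) – Distance metric used to compare embeddings. Default: Distance.COSINE.
- mining (Optional[str]) – Online triplet mining strategy: 'all' for batch-all or 'hard' for batch-hard. Default: 'hard'.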
 - forward(embeddings: Tensor, groups: LongTensor) → Tensor
- Calculates the triplet loss for the given embeddings and labels.
- Parameters:
- embeddings (Tensor) – Batch of embeddings. Shape: (batch_size, embedding_dim) 
- groups (torch.LongTensor) – Batch of labels associated with embeddings. Shape: (batch_size,) 
 
- Returns:
- torch.Tensor – Scalar loss value. 
 
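As a rough illustration of what `forward` computes, the following pure-Python sketch implements both mining strategies on plain lists instead of `torch` tensors. For an anchor a, a positive p (same group) and a negative n (different group), each triplet contributes max(d(a, p) - d(a, n) + margin, 0), following the FaceNet formulation linked above. Several details are assumptions, not taken from Quaterion's source: cosine distance is computed as 1 - cosine similarity, anchors with no valid positive or negative are skipped in batch-hard mode, and batch-all mode averages only over triplets with a positive loss. All function names are hypothetical.

```python
import math
from itertools import product
from typing import List, Sequence


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    # Assumption: cosine distance is defined as 1 - cosine similarity.
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norm


def pairwise_distances(embeddings: Sequence[Sequence[float]]) -> List[List[float]]:
    # Full (batch_size x batch_size) distance matrix.
    return [[cosine_distance(u, v) for v in embeddings] for u in embeddings]


def batch_hard_triplet_loss(embeddings, groups: Sequence[int], margin: float = 1.0) -> float:
    # Hypothetical helper: one triplet per anchor (hardest positive, hardest negative).
    dist = pairwise_distances(embeddings)
    n = len(embeddings)
    losses = []
    for a in range(n):
        pos = [dist[a][p] for p in range(n) if p != a and groups[p] == groups[a]]
        neg = [dist[a][k] for k in range(n) if groups[k] != groups[a]]
        if not pos or not neg:
            continue  # this anchor cannot form a valid triplet
        # Farthest same-group sample vs. closest other-group sample.
        losses.append(max(max(pos) - min(neg) + margin, 0.0))
    return sum(losses) / len(losses) if losses else 0.0


def batch_all_triplet_loss(embeddings, groups: Sequence[int], margin: float = 1.0) -> float:
    # Hypothetical helper: every valid (anchor, positive, negative) triplet in the batch.
    dist = pairwise_distances(embeddings)
    n = len(embeddings)
    positive_losses = []
    for a, p, k in product(range(n), repeat=3):
        if a == p or groups[a] != groups[p] or groups[a] == groups[k]:
            continue  # not a valid triplet
        loss = dist[a][p] - dist[a][k] + margin
        if loss > 0:
            positive_losses.append(loss)
    # Assumption: average over triplets that still violate the margin.
    return sum(positive_losses) / len(positive_losses) if positive_losses else 0.0


# Usage: positives are orthogonal, while each anchor has an identical negative.
emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
labels = [0, 0, 1, 1]
print(batch_hard_triplet_loss(emb, labels))  # 2.0
print(batch_all_triplet_loss(emb, labels))   # 1.5
```

Batch-hard keeps a single triplet per anchor, so it concentrates on the most difficult pairs in the batch. Batch-all uses every valid triplet, which this naive triple loop does at O(n³) cost. A real implementation would vectorize the distance matrix and use masks instead.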
 - get_config_dict()
- Config used for saving and loading purposes.
- The config object has to be JSON-serializable.
- Returns:
- Dict[str, Any] – JSON-serializable dict of params
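The JSON-serializability requirement of `get_config_dict` can be illustrated with a minimal stand-in class. `TripletLossConfigSketch` and its `from_config_dict` method are hypothetical, and the `Distance` enum below is a simplified stand-in for the library's own. The key names simply mirror the constructor parameters and are not confirmed against Quaterion's implementation.

```python
import json
from enum import Enum
from typing import Any, Dict, Optional


class Distance(str, Enum):
    # Simplified stand-in for the library's Distance enum (assumption).
    COSINE = "cosine"
    EUCLIDEAN = "euclidean"


class TripletLossConfigSketch:
    # Hypothetical class that only stores the constructor arguments.
    def __init__(
        self,
        margin: Optional[float] = 1.0,
        distance_metric_name: Distance = Distance.COSINE,
        mining: Optional[str] = "hard",
    ):
        self.margin = margin
        self.distance_metric_name = distance_metric_name
        self.mining = mining

    def get_config_dict(self) -> Dict[str, Any]:
        # Store the enum by its string value so json.dumps never sees a custom type.
        return {
            "margin": self.margin,
            "distance_metric_name": self.distance_metric_name.value,
            "mining": self.mining,
        }

    @classmethod
    def from_config_dict(cls, config: Dict[str, Any]) -> "TripletLossConfigSketch":
        # Rebuild the enum member from its stored string value.
        return cls(
            margin=config["margin"],
            distance_metric_name=Distance(config["distance_metric_name"]),
            mining=config["mining"],
        )


# Usage: round-trip the config through a JSON string.
original = TripletLossConfigSketch(margin=0.5, mining="all")
serialized = json.dumps(original.get_config_dict())
restored = TripletLossConfigSketch.from_config_dict(json.loads(serialized))
print(restored.margin, restored.distance_metric_name.value, restored.mining)  # 0.5 cosine all
```

Converting the enum to a plain string keeps the dict JSON-safe and makes the saved config readable without importing the enum type.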