quaterion.loss.online_contrastive_loss module
- class OnlineContrastiveLoss(margin: Optional[float] = 0.5, distance_metric_name: Distance = Distance.COSINE, mining: Optional[str] = 'hard')
Bases: GroupLoss
Implements Contrastive Loss as defined in http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
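For reference, here is the per-pair loss from the cited paper (Hadsell, Chopra and LeCun, 2006), written in the paper's notation:

```latex
L(W, Y, \vec{X}_1, \vec{X}_2) = (1 - Y)\,\frac{1}{2}\,(D_W)^2 + Y\,\frac{1}{2}\,\bigl\{\max(0,\; m - D_W)\bigr\}^2
```

In this formula, Y = 0 marks a similar (positive) pair and Y = 1 marks a dissimilar (negative) pair. D_W is the distance between the two embeddings, and m is the margin. Positive pairs are penalized by their squared distance. Negative pairs are penalized only while they are closer than m, and m plays the role of the margin parameter of this class.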
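Unlike ContrastiveLoss, this loss supports online pair mining: it forms positive and negative pairs on the fly, so you don’t need to build such pairs yourself. It first computes all possible pairs in a batch and then filters valid positive pairs and valid negative pairs separately. Both batch-all and batch-hard mining strategies are supported.
- Parameters:
margin (Optional[float]) – Margin for negative pairs: a negative pair contributes to the loss only while its distance is below this value. Default: 0.5.
distance_metric_name (Distance) – Distance function used to compare embeddings. Default: Distance.COSINE.
mining (Optional[str]) – Pair mining strategy: "all" for batch-all or "hard" for batch-hard. Default: "hard".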
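The pair-construction step (enumerate every pair in the batch, then separate valid positives from valid negatives) can be sketched in plain Python. This is an illustrative sketch, not Quaterion's implementation. The helper name `pair_masks` is hypothetical, and `groups` is a list of integer labels that stands in for the LongTensor passed to forward().

```python
from typing import List, Tuple


def pair_masks(groups: List[int]) -> Tuple[List[List[bool]], List[List[bool]]]:
    """Enumerate all (i, j) pairs in a batch and split them into two masks.

    positive[i][j] is True when samples i and j share a group and i != j.
    negative[i][j] is True when samples i and j belong to different groups.
    """
    n = len(groups)
    positive = [[i != j and groups[i] == groups[j] for j in range(n)] for i in range(n)]
    negative = [[groups[i] != groups[j] for j in range(n)] for i in range(n)]
    return positive, negative


# Samples 0 and 1 share group 0; sample 2 is alone in group 1.
pos, neg = pair_masks([0, 0, 1])
print(pos)  # [[False, True, False], [True, False, False], [False, False, False]]
print(neg)  # [[False, False, True], [False, False, True], [True, True, False]]
```

The diagonal is excluded from the positive mask because a sample paired with itself carries no training signal.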
- forward(embeddings: Tensor, groups: LongTensor) → Tensor
Calculates Contrastive Loss by making pairs on-the-fly.
- Parameters:
embeddings (Tensor) – Batch of embeddings. Shape: (batch_size, embedding_dim)
groups (torch.LongTensor) – Batch of labels associated with embeddings. Shape: (batch_size,)
- Returns:
torch.Tensor – Scalar loss value.
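The computation that forward() performs can be sketched in plain Python. This illustrates online contrastive loss under stated assumptions and is not Quaterion's code. `online_contrastive_loss_sketch` and `cosine_distance` are hypothetical helpers. Cosine distance is taken as 1 minus cosine similarity, and the constant ½ factors from the paper are dropped. Positive and negative terms are averaged separately and then summed. Quaterion's exact distance convention and reduction may differ.

```python
import math
from typing import List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity (assumed convention); inputs must be non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def online_contrastive_loss_sketch(
    embeddings: List[Sequence[float]],
    groups: List[int],
    margin: float = 0.5,
    mining: str = "hard",
) -> float:
    n = len(embeddings)
    # Full pairwise distance matrix: every possible pair in the batch.
    dist = [[cosine_distance(embeddings[i], embeddings[j]) for j in range(n)] for i in range(n)]

    pos_terms: List[float] = []
    neg_terms: List[float] = []
    if mining == "all":
        # Batch-all: keep every valid positive pair and every valid negative pair.
        for i in range(n):
            for j in range(n):
                if i != j and groups[i] == groups[j]:
                    pos_terms.append(dist[i][j] ** 2)
                elif groups[i] != groups[j]:
                    neg_terms.append(max(0.0, margin - dist[i][j]) ** 2)
    elif mining == "hard":
        # Batch-hard: per anchor, the farthest positive and the closest negative.
        for i in range(n):
            pos = [dist[i][j] for j in range(n) if j != i and groups[j] == groups[i]]
            neg = [dist[i][j] for j in range(n) if groups[j] != groups[i]]
            if pos:
                pos_terms.append(max(pos) ** 2)
            if neg:
                neg_terms.append(max(0.0, margin - min(neg)) ** 2)
    else:
        raise ValueError(f"unknown mining strategy: {mining!r}")

    # Average positive and negative terms separately (assumed reduction).
    pos_loss = sum(pos_terms) / len(pos_terms) if pos_terms else 0.0
    neg_loss = sum(neg_terms) / len(neg_terms) if neg_terms else 0.0
    return pos_loss + neg_loss


emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
grp = [0, 0, 1]
print(online_contrastive_loss_sketch(emb, grp, mining="all"))  # 1.125
```

Batch-hard mining keeps at most one positive term and one negative term per anchor, which focuses training on the most difficult pairs. Batch-all averages over every valid pair in the batch.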
- get_config_dict()
Config used for saving and loading purposes.
The config object has to be JSON-serializable.
- Returns:
Dict[str, Any] – JSON-serializable dict of params
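The following sketch shows the save/load round trip that a JSON-serializable config enables. `LossConfigSketch` and its dict keys are hypothetical. The keys mirror the constructor parameters in the signature, but the keys Quaterion actually emits may differ. Storing the distance as the string "cosine" is also an assumption about how Distance.COSINE is serialized.

```python
import json
from typing import Any, Dict


class LossConfigSketch:
    """Hypothetical stand-in showing a JSON-serializable config round trip."""

    def __init__(self, margin: float = 0.5, distance_metric_name: str = "cosine", mining: str = "hard"):
        self.margin = margin
        self.distance_metric_name = distance_metric_name
        self.mining = mining

    def get_config_dict(self) -> Dict[str, Any]:
        # Only plain JSON types (float, str), so json.dumps succeeds.
        return {
            "margin": self.margin,
            "distance_metric_name": self.distance_metric_name,
            "mining": self.mining,
        }


original = LossConfigSketch(margin=0.3, mining="all")
payload = json.dumps(original.get_config_dict())
restored = LossConfigSketch(**json.loads(payload))
print(restored.get_config_dict())  # {'margin': 0.3, 'distance_metric_name': 'cosine', 'mining': 'all'}
```

Keeping the dict keys identical to the constructor's keyword arguments lets a saved config be passed straight back with `**`.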
- training: bool