quaterion.loss.triplet_loss module

class TripletLoss(margin: Optional[float] = 1.0, distance_metric_name: Distance = Distance.COSINE, mining: Optional[str] = 'hard')

Bases: GroupLoss

Implements Triplet Loss as defined in https://arxiv.org/abs/1503.03832
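
For a single (anchor, positive, negative) triplet, the loss from the paper is the hinge max(d(a, p) - d(a, n) + margin, 0). The sketch below is a plain-Python illustration of that formula, not the library's implementation. It assumes cosine distance is 1 - cosine similarity and uses the same default margin as the signature (1.0); the helper names `cosine_distance` and `triplet_loss` are hypothetical.

```python
import math
from typing import Sequence


def cosine_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """Assumed definition for this sketch: 1 - cosine similarity."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / (norm_x * norm_y)


def triplet_loss(anchor: Sequence[float],
                 positive: Sequence[float],
                 negative: Sequence[float],
                 margin: float = 1.0) -> float:
    """Hinge loss for one triplet: max(d(a, p) - d(a, n) + margin, 0)."""
    d_ap = cosine_distance(anchor, positive)  # should become small
    d_an = cosine_distance(anchor, negative)  # should exceed d_ap by >= margin
    return max(d_ap - d_an + margin, 0.0)


# Negative points the opposite way: d(a, n) = 2, so the margin is satisfied.
print(triplet_loss([1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]))  # 0.0
```

The loss is zero once the negative is at least `margin` farther from the anchor than the positive, so only triplets that violate the margin produce gradients.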

It supports batch-all and batch-hard strategies for online triplet mining.

Parameters:
  • margin – Margin by which a negative example must be farther from the anchor than a positive example. Optional, defaults to 1.0.

  • distance_metric_name – Name of the distance function, given as a member of the Distance enum (e.g., Distance.COSINE). Optional, defaults to Distance.COSINE.

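  • mining (str, optional) – Triplet mining strategy. One of “all” (batch-all: use every valid triplet in the batch) or “hard” (batch-hard: use only the hardest positive and hardest negative for each anchor). Defaults to “hard”.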
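
The two mining strategies can be sketched in plain Python over a batch of embeddings and integer group labels, where embeddings sharing a label are positives for each other. This is an illustrative sketch, not the library's code. Assumptions: cosine distance is 1 - cosine similarity; batch-hard skips anchors that have no positive or no negative; batch-all averages only over triplets with positive loss (a common convention). The function name `batch_triplet_loss` is hypothetical.

```python
import math
from typing import List, Sequence


def cosine_distance(x: Sequence[float], y: Sequence[float]) -> float:
    """Assumed definition for this sketch: 1 - cosine similarity."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / (norm_x * norm_y)


def batch_triplet_loss(embeddings: List[Sequence[float]], groups: List[int],
                       margin: float = 1.0, mining: str = "hard") -> float:
    n = len(embeddings)
    if len(groups) != n:
        raise ValueError("embeddings and groups must have the same length")
    if mining not in ("all", "hard"):
        raise ValueError("mining must be 'all' or 'hard'")

    # Pairwise distance matrix, computed once and reused by both strategies.
    dist = [[cosine_distance(embeddings[i], embeddings[j]) for j in range(n)]
            for i in range(n)]

    if mining == "hard":
        per_anchor = []
        for a in range(n):
            pos = [dist[a][p] for p in range(n) if p != a and groups[p] == groups[a]]
            neg = [dist[a][k] for k in range(n) if groups[k] != groups[a]]
            if not pos or not neg:
                continue  # this anchor forms no valid triplet
            # Hardest positive = farthest same-group sample,
            # hardest negative = closest other-group sample.
            per_anchor.append(max(max(pos) - min(neg) + margin, 0.0))
        return sum(per_anchor) / len(per_anchor) if per_anchor else 0.0

    # Batch-all: enumerate every (anchor, positive, negative) triplet.
    losses = []
    for a in range(n):
        for p in range(n):
            if p == a or groups[p] != groups[a]:
                continue
            for k in range(n):
                if groups[k] != groups[a]:
                    losses.append(max(dist[a][p] - dist[a][k] + margin, 0.0))
    active = [value for value in losses if value > 0.0]  # margin-violating triplets
    return sum(active) / len(active) if active else 0.0


batch = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
labels = [0, 0, 1, 1]
print(batch_triplet_loss(batch, labels, mining="hard"))  # 0.25
print(batch_triplet_loss(batch, labels, mining="all"))   # 1.0
```

On this batch, batch-hard averages one term per anchor, while batch-all averages only over the two triplets that violate the margin, which is why the two values differ.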

forward(embeddings: Tensor, groups: LongTensor) → Tensor

Calculates Triplet Loss with specified embeddings and labels.

Parameters:
  • embeddings (Tensor) – Batch of embeddings. Shape: (batch_size, embedding_dim)

  • groups (torch.LongTensor) – Batch of labels associated with embeddings. Shape: (batch_size,)

Returns:

torch.Tensor – Scalar loss value.

get_config_dict()

Config used for saving and loading purposes.

The config object has to be JSON-serializable.

Returns:

Dict[str, Any] – JSON-serializable dict of params
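
One way to keep such a config JSON-serializable is to store enum members by their string value. The sketch below uses a hypothetical stand-in class, `TripletLossConfigSketch`, and a hypothetical single-member `Distance` enum; the key names simply mirror the constructor parameters and are an assumption, not a description of the dict the library actually returns.

```python
import json
from enum import Enum
from typing import Any, Dict


class Distance(str, Enum):
    # Hypothetical stand-in for the library's Distance enum; only COSINE is modeled.
    COSINE = "cosine"


class TripletLossConfigSketch:
    """Hypothetical holder for TripletLoss-style hyperparameters."""

    def __init__(self, margin: float = 1.0,
                 distance_metric_name: Distance = Distance.COSINE,
                 mining: str = "hard") -> None:
        self.margin = margin
        self.distance_metric_name = distance_metric_name
        self.mining = mining

    def get_config_dict(self) -> Dict[str, Any]:
        # Store the enum by its string value so the dict contains only JSON types.
        return {
            "margin": self.margin,
            "distance_metric_name": self.distance_metric_name.value,
            "mining": self.mining,
        }


config = TripletLossConfigSketch(margin=0.5, mining="all").get_config_dict()
restored = json.loads(json.dumps(config))  # round-trips without a custom encoder
print(restored)  # {'margin': 0.5, 'distance_metric_name': 'cosine', 'mining': 'all'}
```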

training: bool
