
quaterion.main module

class Quaterion

Bases: object

Fine-tuning entry point

Contains methods to launch the actual training and evaluation processes.

classmethod evaluate(evaluator: Evaluator, dataset: Union[Sized, Iterable, Dataset], model: SimilarityModel) → Dict[str, Tensor]

Compute metrics on a dataset

Parameters:
  • evaluator – Object which holds the configuration of which metrics to use and how to obtain samples for them

  • dataset – Sized object (list, tuple, torch.utils.data.Dataset, etc.) on which to compute metrics

  • model – SimilarityModel instance used to encode objects

Returns:

Dict[str, torch.Tensor]: dict of computed metrics, where each key is the name of a metric and each value holds that metric's estimated values.
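To make the return contract concrete, here is a minimal, framework-free sketch of what an evaluation of this shape does: encode every object with the model, compute each configured metric, and return a dict keyed by metric name. Everything in it is a hypothetical stand-in, not Quaterion's implementation: `toy_evaluate`, the toy encoder, the precision@1 metric, and plain floats in place of torch.Tensor values.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical stand-ins: a sample is (object, label); the "model" maps an object to a vector.
Sample = Tuple[str, int]
Vector = List[float]
Metric = Callable[[List[Vector], List[int]], float]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def precision_at_1(embeddings: List[Vector], labels: List[int]) -> float:
    """Fraction of samples whose nearest *other* embedding shares their label."""
    hits = 0
    for i, emb in enumerate(embeddings):
        others = [j for j in range(len(embeddings)) if j != i]
        nearest = min(others, key=lambda j: squared_distance(emb, embeddings[j]))
        hits += int(labels[nearest] == labels[i])
    return hits / len(embeddings)


def toy_evaluate(
    metrics: Dict[str, Metric],
    dataset: Sequence[Sample],
    encode: Callable[[str], Vector],
) -> Dict[str, float]:
    # 1. Encode every object of the sized dataset with the model.
    embeddings = [encode(obj) for obj, _ in dataset]
    labels = [label for _, label in dataset]
    # 2. Compute each configured metric: key = metric name, value = estimate.
    return {name: metric(embeddings, labels) for name, metric in metrics.items()}


def toy_encode(text: str) -> Vector:
    # Toy "model": a 2-D vector of (length, number of vowels).
    return [float(len(text)), float(sum(c in "aeiou" for c in text))]


data = [("cat", 0), ("bat", 0), ("elephant", 1), ("antelope", 1)]
print(toy_evaluate({"precision@1": precision_at_1}, data, toy_encode))  # prints {'precision@1': 1.0}
```

In the real method, the evaluator object plays the role of the `metrics` dict here: it decides which metrics are computed and how samples are obtained for them.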

classmethod fit(trainable_model: TrainableModel, trainer: Optional[Trainer], train_dataloader: SimilarityDataLoader, val_dataloader: Optional[SimilarityDataLoader] = None, ckpt_path: Optional[str] = None)

Handle training routine

Assembles data loaders, performs caching, and runs the whole training process.

Parameters:
  • trainable_model – model to fit

  • trainer – pytorch_lightning.Trainer instance that handles the fitting routine internally. If None is passed, a trainer will be created with Quaterion.trainer_defaults(). The default parameters are intended as a quick start for training the model; we encourage users to try different parameters if the defaults do not give satisfactory results.

  • train_dataloader – DataLoader instance to retrieve samples during training stage

  • val_dataloader – Optional DataLoader instance to retrieve samples during validation stage

  • ckpt_path – Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised. If resuming from a mid-epoch checkpoint, training will start from the beginning of the next epoch.
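Two behaviours described in the parameter list above, the fallback to Quaterion.trainer_defaults() when trainer is None and the resume rules for ckpt_path, can be modelled in a few lines of plain Python. This is a hypothetical sketch of the documented rules only: `resolve_trainer_kwargs` and `resume_plan` are made-up names, a checkpoint is reduced to a local file plus the epoch and batch count at which it was saved, and URLs are not handled. The actual resuming is done by PyTorch Lightning.

```python
import os
from typing import Any, Callable, Dict, Optional, Tuple


def resolve_trainer_kwargs(
    trainer_kwargs: Optional[Dict[str, Any]],
    defaults_fn: Callable[[], Dict[str, Any]],
) -> Dict[str, Any]:
    """Documented fallback: when no trainer is given, build one from the defaults."""
    if trainer_kwargs is None:
        return defaults_fn()
    return trainer_kwargs


def resume_plan(
    ckpt_path: Optional[str],
    saved_epoch: int,
    batches_done: int,
    batches_per_epoch: int,
) -> Tuple[int, int]:
    """Return (epoch to start from, batches of the interrupted epoch that are skipped).

    Hypothetical model of the documented rules; only local file paths are checked.
    """
    if ckpt_path is None:
        return 0, 0  # no checkpoint: fresh run from epoch 0
    if not os.path.isfile(ckpt_path):
        # Documented: a missing checkpoint file is an error, not a silent fresh start.
        raise FileNotFoundError(f"no checkpoint file at {ckpt_path!r}")
    if batches_done < batches_per_epoch:
        # Mid-epoch checkpoint: the rest of saved_epoch is not replayed.
        return saved_epoch + 1, batches_per_epoch - batches_done
    # End-of-epoch checkpoint: simply continue with the next epoch.
    return saved_epoch + 1, 0
```

For example, a checkpoint written after 30 of 100 batches of epoch 2 resumes at epoch 3, so the remaining 70 batches of epoch 2 are never seen. Under the documented rule, saving checkpoints at epoch boundaries avoids that gap.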

static trainer_defaults(trainable_model: Optional[TrainableModel] = None, train_dataloader: Optional[SimilarityDataLoader] = None)

Reasonable default parameters for pytorch_lightning.Trainer

This function generates a set of Trainer parameters that are considered “recommended” for most Quaterion use cases. Quaterion's similarity learning training process has characteristics that differentiate it from regular deep learning model training. These defaults may be overridden if your task needs special behaviour.

To adjust Trainer parameters, override the defaults before creating the trainer:

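Example:

import pytorch_lightning as pl
from quaterion import Quaterion

# `model` (a TrainableModel) and `train_dataloader` (a SimilarityDataLoader)
# are assumed to be defined already.
trainer_kwargs = Quaterion.trainer_defaults(
    trainable_model=model,
    train_dataloader=train_dataloader
)
# Replace the default logger (WandbLogger requires the `wandb` package).
trainer_kwargs['logger'] = pl.loggers.WandbLogger(
    name="example_model",
    project="example_project",
)
# YourCustomCallback is a placeholder for your own pl.Callback subclass.
trainer_kwargs['callbacks'].append(YourCustomCallback())
trainer = pl.Trainer(**trainer_kwargs)
Quaterion.fit(model, trainer, train_dataloader)
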
Parameters:
  • trainable_model – If provided, Quaterion tries to adjust the default params based on the model configuration

  • train_dataloader – If provided, trainer params are adjusted according to the dataset

Returns:

kwargs for pytorch_lightning.Trainer
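The pattern behind trainer_defaults(), building a dict of recommended kwargs, adapting it to the training data when that information is available, and letting the caller override individual keys, can be sketched in plain Python. The function name, the chosen keys and values, and the rule of capping log_every_n_steps at the number of batches are illustrative assumptions, not Quaterion's actual defaults; the sketch only covers the train_dataloader case, not adjustments based on the model.

```python
from typing import Any, Dict, Optional, Sized


def sketch_trainer_defaults(train_dataloader: Optional[Sized] = None) -> Dict[str, Any]:
    """Hypothetical stand-in for Quaterion.trainer_defaults(); all values are illustrative."""
    defaults: Dict[str, Any] = {
        "max_epochs": -1,         # assumption: rely on a stopping callback, not a fixed epoch count
        "callbacks": [],          # the real function fills this with pytorch_lightning callbacks
        "log_every_n_steps": 50,  # pytorch_lightning.Trainer's own default value
    }
    if train_dataloader is not None and len(train_dataloader) > 0:
        # Adjust to the data: with few batches per epoch, log at least once per epoch.
        defaults["log_every_n_steps"] = min(50, len(train_dataloader))
    return defaults


# A fresh dict is returned on every call, so mutating it before building the trainer is safe.
kwargs = sketch_trainer_defaults(train_dataloader=list(range(12)))
kwargs["max_epochs"] = 10  # caller override
print(kwargs)
```

Returning a new dict (and a new callbacks list) on every call is what makes the append-and-override style of usage safe: one caller's changes never leak into another trainer's defaults.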