quaterion.train.trainable_model module
- class TrainableModel(*args: Any, **kwargs: Any)
Bases: LightningModule, CacheMixin
Base class for models to be trained.
TrainableModel is used to describe how and which components of the model should be trained.
It assembles the model from building blocks like Encoder, EncoderHead, etc.

    ┌─────────┐ ┌─────────┐ ┌─────────┐
    │Encoder 1│ │Encoder 2│ │Encoder 3│
    └────┬────┘ └────┬────┘ └────┬────┘
         │           │           │
         └────────┐  │  ┌────────┘
                  │  │  │
              ┌───┴──┴──┴───┐
              │   concat    │
              └──────┬──────┘
                     │
              ┌──────┴──────┐
              │    Head     │
              └─────────────┘
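The data flow in the diagram can be sketched in plain Python: each encoder maps a record to a vector, the vectors are concatenated, and the head transforms the result. This is only an illustration with lists in place of tensors; `encoder_a`, `encoder_b`, `head` and `embed` are hypothetical stand-ins and not part of the Quaterion API.

```python
from typing import Callable, Dict, List

Vector = List[float]


def encoder_a(record: dict) -> Vector:
    # Hypothetical encoder with a 2-dimensional output.
    return [float(len(record["text"])), 1.0]


def encoder_b(record: dict) -> Vector:
    # Hypothetical encoder with a 3-dimensional output.
    return [float(record["price"]), 0.0, 0.5]


def head(vector: Vector) -> Vector:
    # Hypothetical head: simply scales the concatenated vector.
    return [2.0 * x for x in vector]


def embed(record: dict, encoders: Dict[str, Callable[[dict], Vector]]) -> Vector:
    concatenated: Vector = []
    for name in sorted(encoders):  # fixed order keeps the layout stable
        concatenated.extend(encoders[name](record))
    return head(concatenated)


encoders = {"a": encoder_a, "b": encoder_b}
embedding = embed({"text": "abc", "price": 4}, encoders)
print(len(embedding))  # 2 + 3 = 5 values reach the head
```

In this sketch the head receives a vector whose length is the sum of the encoder output sizes, which is what a concatenation step implies.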
TrainableModel also handles most of the training routine: training and validation steps, tensor device management, logging, and more. Most of this behavior is inherited from LightningModule, which is a direct ancestor of TrainableModel.
To train a model, inherit from TrainableModel and implement the required methods and attributes.
Minimal Example:
    class ExampleModel(TrainableModel):
        def __init__(self, lr=10e-5, *args, **kwargs):
            self.lr = lr
            super().__init__(*args, **kwargs)

        # backbone of the model
        def configure_encoders(self):
            return YourAwesomeEncoder()

        # top layer of the model
        def configure_head(self, input_embedding_size: int):
            return SkipConnectionHead(input_embedding_size)

        def configure_optimizers(self):
            return Adam(self.model.parameters(), lr=self.lr)

        def configure_loss(self):
            return ContrastiveLoss()
- configure_caches() → Optional[CacheConfig]
Method to provide cache configuration
Use this method to define which encoders should cache calculated embeddings and what kind of cache they should use.
- Returns:
Optional[CacheConfig] – cache configuration to be applied if provided, None otherwise.
Examples:
Do not use cache (default):
return None
Configure cache automatically for all non-trainable encoders:
return CacheConfig(CacheType.AUTO)
Specify cache type for each encoder individually:
    return CacheConfig(
        mapping={
            "text_encoder": CacheType.GPU,  # Store cache in GPU for `text_encoder`
            "image_encoder": CacheType.CPU,  # Store cache in RAM for `image_encoder`
        }
    )
Specify key for cache object disambiguation:
    return CacheConfig(
        cache_type=CacheType.AUTO,
        key_extractors={"text_encoder": hash},
    )
A key extractor can be useful if you need a more sophisticated way of associating cached vectors with the original objects. If no key extractor is specified, item numbers from the dataset are used as keys.
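To illustrate why a key extractor can matter, here is a minimal sketch of a cache keyed either by item number or by a key derived from the object itself. The `EmbeddingCache` class and the `fake_encode` function are hypothetical and only mimic the idea; they do not reflect how Quaterion implements caching.

```python
from typing import Any, Callable, Dict, List, Optional


def fake_encode(text: str) -> List[float]:
    # Hypothetical stand-in for an expensive encoder call.
    return [float(len(text)), float(text.count("a"))]


class EmbeddingCache:
    def __init__(self, key_extractor: Optional[Callable[[Any], Any]] = None):
        self.key_extractor = key_extractor
        self.store: Dict[Any, List[float]] = {}
        self.misses = 0  # number of times the encoder had to run

    def get(self, index: int, obj: str) -> List[float]:
        # Default key is the item number; a key extractor derives it from the object.
        key = index if self.key_extractor is None else self.key_extractor(obj)
        if key not in self.store:
            self.misses += 1
            self.store[key] = fake_encode(obj)
        return self.store[key]


records = ["apple", "banana", "apple"]

by_index = EmbeddingCache()
by_content = EmbeddingCache(key_extractor=hash)
for i, record in enumerate(records):
    by_index.get(i, record)
    by_content.get(i, record)

print(by_index.misses, by_content.misses)
```

In this sketch the index-keyed cache encodes all three records, while the content-keyed cache encodes the duplicated record only once.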
- configure_encoders() → Union[Encoder, Dict[str, Encoder]]
Method to provide encoders configuration
Use this function to define an initial state of encoders. This function should be used to assign initial values for encoders before training as well as during the checkpoint loading.
- configure_head(input_embedding_size: int) → EncoderHead
Use this function to define an initial state for the head layer of the model.
- Parameters:
input_embedding_size – size of embeddings produced by encoders
- Returns:
EncoderHead – head to be added on top of a model
- configure_loss() → SimilarityLoss
Method to configure loss function to use.
- configure_metrics() → Union[AttachedMetric, List[AttachedMetric]]
Method to configure batch-wise metrics for a training process
Use this method to attach batch-wise metrics to a training process. Provided metrics must have an interface similar to PairMetric or GroupMetric.
- Returns:
Union[AttachedMetric, List[AttachedMetric]] – metrics attached to the model
Examples:
    return [
        AttachedMetric(
            "RetrievalPrecision",
            RetrievalPrecision(k=1),
            prog_bar=True,
            on_epoch=True,
        ),
        AttachedMetric(
            "RetrievalReciprocalRank",
            RetrievalReciprocalRank(),
            prog_bar=True,
        ),
    ]
- process_results(embeddings: Tensor, targets: Dict[str, Any], batch_idx: int, stage: TrainStage, **kwargs)
Method to provide any additional evaluations of embeddings.
- Parameters:
embeddings – shape: (batch_size, embedding_size) - model’s output.
targets – output of batch target collate.
batch_idx – ID of the processing batch.
stage – train, validation or test stage.
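As an illustration of an additional evaluation, the sketch below computes the mean L2 norm of a batch of embeddings and tags it with the stage. It uses nested lists instead of a Tensor and a plain string for the stage so that it runs without dependencies. In a real model the same logic would presumably live in an overridden process_results and log the value rather than return it. All names here are hypothetical.

```python
import math
from typing import Any, Dict, List


def mean_embedding_norm(embeddings: List[List[float]]) -> float:
    # embeddings has shape (batch_size, embedding_size)
    norms = [math.sqrt(sum(x * x for x in row)) for row in embeddings]
    return sum(norms) / len(norms)


def process_results_sketch(
    embeddings: List[List[float]],
    targets: Dict[str, Any],
    batch_idx: int,
    stage: str,
) -> Dict[str, float]:
    # Prefix the value with the stage so that train and validation stay separate.
    return {f"{stage}_mean_norm": mean_embedding_norm(embeddings)}


result = process_results_sketch(
    [[3.0, 4.0], [0.0, 0.0]], targets={}, batch_idx=0, stage="validation"
)
print(result)
```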
- save_servable(path: str)
Save the model for serving, independent of PyTorch Lightning.
- Parameters:
path – path to save to
- setup_dataloader(dataloader: SimilarityDataLoader)
Set up the data loader with encoder-specific settings, such as the encoder-specific collate function.
Each encoder has its own way of transforming a list of records into an NN-compatible format. These transformations are usually done during the data pre-processing step.
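The idea of encoder-specific collate functions can be sketched as follows: each encoder contributes a function that turns a list of raw records into its own input format, and a combined collate applies all of them to the same batch. The names `text_collate`, `price_collate` and `combined_collate` are hypothetical, and plain lists stand in for tensors.

```python
from typing import Any, Callable, Dict, List

Record = Dict[str, Any]


def text_collate(batch: List[Record]) -> List[List[int]]:
    # Hypothetical text encoder input: character codes padded with zeros to equal length.
    max_len = max(len(r["text"]) for r in batch)
    return [
        [ord(c) for c in r["text"]] + [0] * (max_len - len(r["text"]))
        for r in batch
    ]


def price_collate(batch: List[Record]) -> List[float]:
    # Hypothetical numeric encoder input: one float per record.
    return [float(r["price"]) for r in batch]


def combined_collate(
    batch: List[Record],
    collates: Dict[str, Callable[[List[Record]], Any]],
) -> Dict[str, Any]:
    # Apply every encoder's collate function to the same batch of records.
    return {name: fn(batch) for name, fn in collates.items()}


batch = [{"text": "ab", "price": 1}, {"text": "c", "price": 2}]
features = combined_collate(
    batch, {"text_encoder": text_collate, "price_encoder": price_collate}
)
print(features)
```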
- property loss: SimilarityLoss
Property to get the loss function to use.
- property model: SimilarityModel
Original model to be trained
- Returns:
SimilarityModel – model to be trained
- training: bool