Shortcuts

Source code for quaterion.eval.attached_metric

from typing import Optional

from quaterion.eval.base_metric import BaseMetric
from quaterion.utils.enums import TrainStage


class AttachedMetric:
    """Attach batch-wise metric to :class:`~quaterion.train.trainable_model.TrainableModel`

    Contain required parameters to compute and log batch-wise metric during training process.

    Args:
        name: name of an attached metric to be used in log.
        metric: metric to be calculated.
        logger: Logs to the logger like Tensorboard, or any other custom logger passed to the
            Trainer (Default: True).
        prog_bar: Logs to the progress bar (Default: False).
        on_step: Logs the metric at the current step (Default: None).
        on_epoch: Automatically accumulates and logs at the end of the epoch (Default: None).
        **log_options: additional kwargs to be passed to model's log.

    The remaining options can be found at:
    https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html

    Attribute lookups that are not resolved on the ``AttachedMetric`` itself are forwarded to
    the wrapped ``metric``, except for special ``__dunder__`` names.
    """

    def __init__(
        self,
        name: str,
        metric: BaseMetric,
        logger: bool = True,
        prog_bar: bool = False,
        on_step: Optional[bool] = None,
        on_epoch: Optional[bool] = None,
        **log_options,
    ):
        self._metric = metric
        # NOTE(review): presumably read by the training code to decide in which stages the
        # metric is computed/logged - verify against the caller.
        self.stages = [TrainStage.TRAIN, TrainStage.VALIDATION]
        self.name = name
        # Explicit options first, so that **log_options can extend (or override) them.
        self.log_options = {
            "logger": logger,
            "prog_bar": prog_bar,
            "on_step": on_step,
            "on_epoch": on_epoch,
            **log_options,
        }

    def __getattr__(self, item: str):
        """Forward attribute lookups that failed on this object to the wrapped metric.

        Python calls ``__getattr__`` only after normal attribute lookup has failed, so this
        is reached solely for names not present on the instance or its class.

        Args:
            item: name of the requested attribute.

        Returns:
            The attribute ``item`` of the wrapped metric.

        Raises:
            AttributeError: if ``item`` is one of this wrapper's own required fields (they can
                only be missing on a not-yet-initialized instance), if it is a special
                ``__dunder__`` name, or if the wrapped metric does not have it either.
        """
        # `_metric` and `name` are always set by `__init__`. Reaching this point for them means
        # the instance is not initialized (e.g. while being unpickled or copied), and
        # delegating through `self._metric` would recurse into this method forever.
        prevent_lookup = {"_metric", "name"}
        if item in prevent_lookup:
            raise AttributeError(
                "Prevents recursion. "
                "Tried to access the field which has to be presented in an initialized instance."
            )
        # Protocols such as `copy.deepcopy` look up hooks like `__deepcopy__` on the instance.
        # Forwarding these would make copying/pickling operate on the wrapped metric instead
        # of this wrapper, so special names are never delegated.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(
                f"`AttachedMetric` object has no attribute <{item}>"
            )
        try:
            return getattr(self._metric, item)
        except AttributeError as ae:
            # Read `name` straight from the instance dict: on a partially initialized object,
            # `self.name` would re-enter `__getattr__` and mask this error with another one.
            raise AttributeError(
                f"`AttachedMetric` object (<{self.__dict__.get('name')}>) has no attribute <{item}>"
            ) from ae

Qdrant

Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem-solving with similarity learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community