Source code for vis4d.engine.callbacks.base

"""Base module for callbacks."""

from __future__ import annotations

import lightning.pytorch as pl
from torch import Tensor

from vis4d.common.typing import DictStrArrNested
from vis4d.data.typing import DictData
from vis4d.engine.connectors import CallbackConnector


class Callback(pl.Callback):
    """Base class for Callbacks.

    Stores the ``epoch_based`` flag and the optional train / test connectors,
    and provides helpers that apply those connectors to the model outputs and
    the batch.
    """

    def __init__(
        self,
        epoch_based: bool = True,
        train_connector: None | CallbackConnector = None,
        test_connector: None | CallbackConnector = None,
    ) -> None:
        """Init callback.

        Args:
            epoch_based (bool, optional): Whether the callback is epoch based.
                Defaults to True.
            train_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during training for different callbacks.
                Defaults to None.
            test_connector (None | CallbackConnector, optional): Defines which
                kwargs to use during testing for different callbacks.
                Defaults to None.
        """
        self.epoch_based = epoch_based
        self.train_connector = train_connector
        self.test_connector = test_connector

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Setup callback.

        The base implementation does nothing; subclasses may override it.

        Args:
            trainer (pl.Trainer): Trainer instance (unused here).
            pl_module (pl.LightningModule): Lightning module (unused here).
            stage (str): Stage identifier (unused here).
        """

    def get_train_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for training.

        It extracts the required data from the model outputs and the batch
        and passes it to the next component with the provided new key.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If train connector is None.
        """
        # Raise explicitly instead of using ``assert`` so the documented
        # error is still produced when Python runs with ``-O``.
        if self.train_connector is None:
            raise AssertionError("Train connector is None.")
        return self.train_connector(outputs, batch)

    def get_test_callback_inputs(
        self, outputs: DictData, batch: DictData
    ) -> dict[str, Tensor | DictStrArrNested]:
        """Returns the data connector results for inference.

        It extracts the required data from the model outputs and the batch
        and passes it to the next component with the provided new key.

        Args:
            outputs (DictData): Outputs of the model.
            batch (DictData): Batch data.

        Returns:
            dict[str, Tensor | DictStrArrNested]: Data connector results.

        Raises:
            AssertionError: If test connector is None.
        """
        # Raise explicitly instead of using ``assert`` so the documented
        # error is still produced when Python runs with ``-O``.
        if self.test_connector is None:
            raise AssertionError("Test connector is None.")
        return self.test_connector(outputs, batch)