Source code for vis4d.engine.callbacks.util

"""PyTorch Lightning callbacks utilities."""

from __future__ import annotations

import lightning.pytorch as pl
from torch import nn

from vis4d.engine.loss_module import LossModule
from vis4d.engine.training_module import TrainingModule


[docs] def get_model(model: pl.LightningModule) -> nn.Module: """Get model from pl module.""" if isinstance(model, TrainingModule): return model.model return model
[docs] def get_loss_module(loss_module: pl.LightningModule) -> LossModule: """Get loss_module from pl module.""" assert hasattr(loss_module, "loss_module") and isinstance( loss_module.loss_module, LossModule ), "Loss module is not set in the training module." return loss_module.loss_module