vis4d.engine.training_module
LightningModule that wraps around the models, losses and optimizers.
Classes
TrainingModule – LightningModule that wraps around the vis4d implementations.
- class TrainingModule(model_cfg, optimizers_cfg, loss_module, train_data_connector, test_data_connector, hyper_parameters=None, seed=-1, ckpt_path=None, compute_flops=False, check_unused_parameters=False)
LightningModule that wraps around the vis4d implementations.
This is a wrapper around the vis4d implementations that allows PyTorch Lightning to be used for training and testing.
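Conceptually, each step of the wrapper chains three pieces: a data connector picks batch entries and renames them to model inputs, the model runs on those inputs, and the loss module turns the outputs into a scalar loss. The sketch below models that flow with plain Python callables, assuming dict-shaped batches. `KeyConnector`, `training_step`, `toy_model` and `toy_loss` are hypothetical stand-ins, not vis4d's `DataConnector`, models or `LossModule`, and no PyTorch Lightning machinery is involved.

```python
from typing import Any, Callable, Dict


class KeyConnector:
    """Hypothetical connector: renames batch keys to model argument names."""

    def __init__(self, key_mapping: Dict[str, str]) -> None:
        # Maps model argument name -> key in the incoming batch dict.
        self.key_mapping = key_mapping

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return {arg: batch[key] for arg, key in self.key_mapping.items()}


def training_step(
    batch: Dict[str, Any],
    connector: KeyConnector,
    model: Callable[..., Dict[str, Any]],
    loss_fn: Callable[[Dict[str, Any], Dict[str, Any]], float],
) -> float:
    """Sketch of one training step: connect -> forward -> loss."""
    inputs = connector(batch)
    outputs = model(**inputs)
    return loss_fn(outputs, batch)


def toy_model(images: list) -> Dict[str, Any]:
    # Stand-in "model": predicts the mean of each input list.
    return {"pred": [sum(x) / len(x) for x in images]}


def toy_loss(outputs: Dict[str, Any], batch: Dict[str, Any]) -> float:
    # Mean squared error between predictions and targets.
    pairs = zip(outputs["pred"], batch["targets"])
    return sum((p - t) ** 2 for p, t in pairs) / len(batch["targets"])


batch = {"imgs": [[1.0, 3.0], [2.0, 2.0]], "targets": [2.0, 1.0]}
loss = training_step(batch, KeyConnector({"images": "imgs"}), toy_model, toy_loss)
```

Keeping the key mapping in a separate connector means the same model can be fed from datasets whose batches use different key names; only the mapping changes.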
- __init__(model_cfg, optimizers_cfg, loss_module, train_data_connector, test_data_connector, hyper_parameters=None, seed=-1, ckpt_path=None, compute_flops=False, check_unused_parameters=False)
Initialize the TrainingModule.
- Parameters:
model_cfg (ConfigDict) – The model config.
optimizers_cfg (list[OptimizerConfig]) – The optimizers config.
loss_module (None | LossModule) – The loss module.
train_data_connector (None | DataConnector) – The data connector to use for training.
test_data_connector (None | DataConnector) – The data connector to use for testing.
hyper_parameters (DictStrAny | None, optional) – The hyperparameters to use. Defaults to None.
seed (int, optional) – The integer seed for the global random state. Defaults to -1, in which case a random seed is generated.
ckpt_path (str, optional) – The path to the checkpoint to load. Defaults to None.
compute_flops (bool, optional) – Whether to compute the FLOPs of the model. Defaults to False.
check_unused_parameters (bool, optional) – Whether to check for unused parameters. Defaults to False.
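The `seed` parameter treats -1 as a sentinel meaning "generate one". A minimal sketch of that convention, assuming only Python's `random` module needs seeding; `resolve_seed` and `seed_python_rng` are hypothetical helpers, not vis4d functions. A real setup would also seed NumPy and PyTorch, which PyTorch Lightning's `seed_everything` covers.

```python
import random
from typing import Optional


def resolve_seed(seed: int = -1, rng: Optional[random.Random] = None) -> int:
    """Return `seed`, or draw a fresh one when `seed == -1` (hypothetical helper)."""
    if seed == -1:
        source = rng if rng is not None else random.SystemRandom()
        # Draw a non-negative 32-bit value, a range most seeding APIs accept.
        return source.randint(0, 2**32 - 1)
    return seed


def seed_python_rng(seed: int = -1) -> int:
    """Seed Python's global RNG and return the seed actually used."""
    resolved = resolve_seed(seed)
    random.seed(resolved)
    return resolved
```

Returning the resolved seed means that even a randomly drawn seed can be logged and passed back in later to reproduce the run.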
- validation_step(batch, batch_idx, dataloader_idx=0)
Perform a single validation step.
- Return type:
Dict[str, Any]
- test_step(batch, batch_idx, dataloader_idx=0)
Perform a single test step.
- Return type:
Dict[str, Any]
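Both hooks take the batch, its index, and a `dataloader_idx` that defaults to 0. In PyTorch Lightning, that index identifies which loader produced the batch when several validation or test dataloaders are configured. The sketch below mimics the two hooks in plain Python and returns a `Dict[str, Any]` as documented. `EvalLoop` and `eval_step` are hypothetical, and grouping outputs by `dataloader_idx` is an assumed aggregation strategy, not documented vis4d behaviour.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, List


class EvalLoop:
    """Hypothetical stand-in for the validation/test hooks."""

    def __init__(self, model: Callable[..., Dict[str, Any]], input_keys: List[str]) -> None:
        self.model = model
        self.input_keys = input_keys  # batch keys forwarded to the model
        # Outputs grouped by the index of the dataloader that produced them.
        self.collected: DefaultDict[int, List[Dict[str, Any]]] = defaultdict(list)

    def eval_step(
        self, batch: Dict[str, Any], batch_idx: int, dataloader_idx: int = 0
    ) -> Dict[str, Any]:
        inputs = {k: batch[k] for k in self.input_keys}
        outputs = self.model(**inputs)
        self.collected[dataloader_idx].append(outputs)
        return outputs

    # Validation and test share the same logic in this sketch.
    validation_step = eval_step
    test_step = eval_step


loop = EvalLoop(lambda x: {"double": [2 * v for v in x]}, ["x"])
out = loop.validation_step({"x": [1, 2], "y": 0}, batch_idx=0)
loop.test_step({"x": [5]}, batch_idx=0, dataloader_idx=1)
```

Keys not listed in `input_keys` (such as `"y"` above) are ignored, so ground-truth entries can travel in the same batch without reaching the model.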