
train_lens

train_lens(lens: LightningLens, data_module: DataModule, config: TrainConfig, callbacks: Optional[Union[list[Callback], Callback]] = None)

Trains the given lens model using the provided data module and training configuration.

Parameters:

  • lens (LightningLens) –

    The LightningLens model to be trained.

  • data_module (DataModule) –

    The DataModule providing the training and validation data.

  • config (TrainConfig) –

    The configuration settings for training.

  • callbacks (Optional[Union[list[Callback], Callback]], default: None ) –

    Optional single callback, or list of callbacks, to use during training.
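Because callbacks accepts None, a single Callback, or a list of them, code that consumes it usually normalizes it to a list first. The following is a minimal sketch of that normalization under that assumption; normalize_callbacks is a hypothetical helper for illustration and is not part of this API.

```python
from typing import List, Optional, TypeVar, Union

T = TypeVar("T")


def normalize_callbacks(callbacks: Optional[Union[List[T], T]]) -> List[T]:
    """Hypothetical helper: turn None, one callback, or a list into a list."""
    if callbacks is None:
        # No callbacks requested.
        return []
    if isinstance(callbacks, list):
        # Copy so later changes do not mutate the caller's list.
        return list(callbacks)
    # A single callback object is wrapped in a one-element list.
    return [callbacks]
```

For example, normalize_callbacks(checkpoint_callback) and normalize_callbacks([checkpoint_callback]) would both yield a one-element list.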

Notes:
  • Training uses mixed precision ("16-mixed") if config.mix_precision is True; otherwise, 32-bit precision is used.
  • Training uses a distributed data parallel (DDP) strategy with unused-parameter detection enabled. This strategy is required for GPU training and is incompatible with CPU training.
  • If no checkpoint to resume from is specified, the function searches the checkpoint directory for the most recent checkpoint.
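The notes above can be sketched as two small helpers. This is an illustrative sketch, not this library's implementation: trainer_settings and find_latest_checkpoint are hypothetical names, the strategy string assumes the Lightning 2.x alias "ddp_find_unused_parameters_true", "32-true" is assumed as the 32-bit precision value, and "most recent" is assumed to mean the *.ckpt file with the newest modification time.

```python
from pathlib import Path
from typing import Optional


def trainer_settings(mix_precision: bool) -> dict:
    """Hypothetical helper: Trainer keyword arguments implied by the Notes."""
    return {
        # Mixed precision when requested, otherwise full 32-bit precision.
        "precision": "16-mixed" if mix_precision else "32-true",
        # DDP with unused-parameter detection (assumed Lightning 2.x alias).
        "strategy": "ddp_find_unused_parameters_true",
    }


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Hypothetical helper: newest *.ckpt file in checkpoint_dir, or None."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    candidates = list(directory.glob("*.ckpt"))
    if not candidates:
        # Nothing to resume from; training would start from scratch.
        return None
    # Pick the checkpoint with the newest modification time.
    latest = max(candidates, key=lambda p: p.stat().st_mtime)
    return str(latest)
```

Under these assumptions, a None result from find_latest_checkpoint would mean training starts fresh, while a path would be passed on as the checkpoint to resume from.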

Returns:

  • The trainer fits the lens to the data provided by data_module.

Examples:

>>> lens = LightningLens(config.model_name, "lensa", config.layer_number, config.lr)
>>> data = DataModule()
>>> train_lens(lens, data, config, callbacks=[checkpoint_callback, early_stop_callback])