Training Attention Lens
This section guides you through the process of training a lens model using train_lens.py and train.py.
train_lens.py
Description:
The file AttentionLens/attention_lens/train/train_lens.py contains the train_lens() function, which trains the given lens model using the specified data module and training configuration.
Key Components:
- train_lens: Handles the training process for the given lens model.
- training_precision: Defines the training precision to use, based on config options.
- strategy: Enables the distributed data parallel (DDP) strategy with unused-parameter detection.
- Checkpoint Handling: Searches for the most recent checkpoint if no specific checkpoint is provided.
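The checkpoint-handling step can be illustrated with a small helper. This is a minimal sketch, assuming checkpoints are saved as `*.ckpt` files (PyTorch Lightning's default extension) somewhere under the checkpoint directory, and that "most recent" means the newest modification time. `find_latest_checkpoint` is a hypothetical name, not a function from AttentionLens, and the real train_lens() may search differently.

```python
from pathlib import Path
from typing import Optional, Union


def find_latest_checkpoint(checkpoint_dir: Union[str, Path]) -> Optional[Path]:
    """Return the most recently modified *.ckpt file, or None if there is none.

    Hypothetical helper: assumes Lightning-style *.ckpt files and uses
    modification time to decide which checkpoint is the most recent.
    """
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        # No checkpoint directory yet, so there is nothing to resume from.
        return None
    # rglob also finds checkpoints saved in per-run subdirectories.
    checkpoints = [p for p in directory.rglob("*.ckpt") if p.is_file()]
    if not checkpoints:
        return None
    # Pick the newest file by modification time.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
```

If train_lens() is built on PyTorch Lightning, the returned path could be passed as `ckpt_path` to `Trainer.fit()`. A `None` result means training starts from scratch.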
Notes:
- The training precision is set to mixed precision (16-mixed) if config.mixed_precision is True, otherwise 32-bit precision is used.
- The training uses a distributed data parallel strategy with unused-parameter detection enabled: strategy="ddp_find_unused_parameters_true". (Necessary for GPU training; incompatible with CPU training.)
- If no specific checkpoint to reload from is specified, the function searches the checkpoint directory for the most recent checkpoint.
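The strings "16-mixed" and "ddp_find_unused_parameters_true" are values that PyTorch Lightning's `Trainer` accepts for its `precision` and `strategy` arguments. The following is a minimal sketch of how a config could be turned into those arguments, assuming a config object with `mixed_precision`, `num_nodes`, and `epochs` attributes. `LensConfig` and `build_trainer_kwargs` are hypothetical names, and the real train_lens() may set additional or different options.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class LensConfig:
    """Hypothetical stand-in for the real training config object."""
    mixed_precision: bool = False
    num_nodes: int = 1
    epochs: int = 1


def build_trainer_kwargs(config: LensConfig) -> Dict[str, Any]:
    """Hypothetical helper: collect Lightning Trainer arguments from the config."""
    # Mixed precision when requested, otherwise full 32-bit precision.
    precision = "16-mixed" if config.mixed_precision else 32
    return {
        "precision": precision,
        # DDP with unused-parameter detection, as described in the notes.
        "strategy": "ddp_find_unused_parameters_true",
        "num_nodes": config.num_nodes,
        # Assumption: the epoch count maps onto Trainer's max_epochs.
        "max_epochs": config.epochs,
    }
```

With Lightning installed, the result could be used as `pl.Trainer(**build_trainer_kwargs(config), callbacks=callbacks)`. Keeping the mapping in a plain function makes it easy to check without starting a distributed run.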
Usage:
train_lens(lens, data, config, callbacks=callbacks)
train.py
Description:
The file AttentionLens/train.py sets up the training configuration, initializes the model, data module and lens, and calls the train_lens function to begin the training process.
Usage:
python train.py --lr 1e-3 --epochs 5 --batch_size 32 --num_nodes 2
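The four flags in the command above could be parsed with the standard `argparse` module. This is a minimal sketch: the defaults and help strings are placeholders, and train.py's actual parser may define more flags or different defaults. `parse_args` is a hypothetical name.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the flags shown in the usage line (hypothetical defaults)."""
    parser = argparse.ArgumentParser(description="Train an attention lens.")
    # Placeholder defaults; check train.py for the values it actually uses.
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of compute nodes")
    # argv=None makes argparse read sys.argv, as when run from the shell.
    return parser.parse_args(argv)
```

Passing an explicit list, for example `parse_args(["--epochs", "5"])`, lets the parsing be checked without launching a training run.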
See also: Running on Polaris.