Training Attention Lens

This section explains how to train a lens model with train_lens.py and train.py.

train_lens.py

Description:

The file AttentionLens/attention_lens/train/train_lens.py contains the train_lens() function, which trains the given lens model using the specified data module and training configuration.

Key Components:

  • train_lens: Handles the training process for the given lens model.
  • training_precision: Defines training precision to be used based on config options.
  • strategy: Parameter to enable distributed data parallel strategy with unused parameter detection.
  • Checkpoint Handling: Searches for the most recent checkpoint if no specific checkpoint is provided.

Notes:

  • The training precision is set to mixed precision (16-mixed) if config.mixed_precision is True, otherwise 32-bit precision is used.
  • Training uses the distributed data parallel (DDP) strategy with unused parameter detection enabled: strategy="ddp_find_unused_parameters_true". This setting is needed for GPU training and is incompatible with CPU training.
  • If no specific checkpoint to reload from is specified, the function searches for the most recent checkpoint in the checkpoint directory.
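The notes above can be illustrated with a minimal sketch. This is not the actual code in train_lens.py. The helper names (select_precision, find_latest_checkpoint, build_trainer_kwargs), the "32-true" string for 32-bit precision, and the .ckpt file extension are all assumptions made for illustration.

```python
import os
from pathlib import Path
from typing import Optional


def select_precision(mixed_precision: bool) -> str:
    """Pick a precision string from a boolean config flag.

    "16-mixed" comes from the notes above; "32-true" is an assumed
    spelling for plain 32-bit precision.
    """
    return "16-mixed" if mixed_precision else "32-true"


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the most recently modified .ckpt file, or None if none exist.

    Assumes checkpoints use the .ckpt extension and may sit in subdirectories.
    """
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    checkpoints = list(directory.rglob("*.ckpt"))
    if not checkpoints:
        return None
    # Newest file by modification time.
    return str(max(checkpoints, key=os.path.getmtime))


def build_trainer_kwargs(mixed_precision: bool) -> dict:
    """Collect the trainer options described in the notes (hypothetical helper)."""
    return {
        "precision": select_precision(mixed_precision),
        "strategy": "ddp_find_unused_parameters_true",
    }
```

In a sketch like this, the path returned by find_latest_checkpoint would only be used when no checkpoint to reload from is given explicitly, and a None result would mean training starts from scratch.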

Usage:

train_lens(lens, data, config, callbacks=callbacks)

train.py

Description:

The file AttentionLens/train.py sets up the training configuration, initializes the model, data module and lens, and calls the train_lens function to begin the training process.

Usage:

python train.py --lr 1e-3 --epochs 5 --batch_size 32 --num_nodes 2
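As a rough illustration of how these flags might be read, here is a minimal argparse sketch. It is not the actual parser in train.py. The flag names come from the command above, while the types, defaults, and help strings are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser for the flags shown in the usage line (illustrative only)."""
    parser = argparse.ArgumentParser(description="Train an attention lens (sketch).")
    # Defaults below are assumed values, not the project's real defaults.
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--epochs", type=int, default=1, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="samples per batch")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of compute nodes")
    return parser


# Parse the same arguments as the command above, passed as an explicit list.
args = build_parser().parse_args(
    ["--lr", "1e-3", "--epochs", "5", "--batch_size", "32", "--num_nodes", "2"]
)
print(args.lr, args.epochs, args.batch_size, args.num_nodes)
```

The parsed values would then feed the training configuration that train.py passes to train_lens.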

See also: Running on Polaris.