Lens Definition
This section covers the definition of the AttentionLens
base.py
Description:
The file AttentionLens/attention_lens/lens/base.py
defines the base Lens class, which is used as a foundation for creating specific lens models.
Key Components:
get_lens()
: Retrieves a lens class from the registry by name.
Usage:
lens_cls = Lens.get_lense(lens_cls)
lensA.py
Description:
The file AttentionLens/attention_lens/lens/registry/lensA.py
defines the LensA lens model. This particular lens is configured to create a lens for each head in a LM layer.
Usage:
lensA = LightningLens('gpt2', 'lensa', layer_num=7, lr=1e-3)
lightning_lens.py
Description:
The file AttentionLens/attention_lens/train/lightning_lens.py
prepares the Lens for training with train_lens.py
by configuring the lens, loss function, forward passes and the optimizer.
Key Components:
kl_loss
: Computes the Kullback-Leibler divergence loss between model logits and lens logits.setup
: Sets up the model and tokenizer during the training setup.forward
: Computes a forward pass through the Attention Lens.training_step
: Defines a single step in the training loop, and returns the resultant loss.configure_optimizer
: Configures the optimizer for training.
Usage:
lensA = LightningLens('gpt2', 'lensa', layer_num=7, lr=1e-3)