
Lens Definition

This section describes how lenses are defined in AttentionLens: the base Lens class, the LensA lens model, and the LightningLens training wrapper.

base.py

Description:

The file AttentionLens/attention_lens/lens/base.py defines the base Lens class, from which specific lens models are derived.

Key Components:

  • get_lens(): Retrieves a lens class from the registry by name.

Usage:

from attention_lens.lens.base import Lens
lens_cls = Lens.get_lens('lensa')  # look up a lens class by its registry name
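A name-based registry like the one get_lens() reads from can be sketched as follows. This is a minimal, hypothetical illustration, not the actual base.py code: it assumes that subclasses register themselves through __init_subclass__ under their lower-cased class name (so LensA becomes 'lensa'), and the placeholder LensA class here contains no lens logic.

```python
from typing import Dict, Type


class Lens:
    """Hypothetical base class illustrating a name -> class registry."""

    # Shared mapping from lens name to lens class (assumed layout).
    registry: Dict[str, Type["Lens"]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Assumption: each subclass is registered under its lower-cased
        # class name, so LensA becomes "lensa".
        Lens.registry[cls.__name__.lower()] = cls

    @classmethod
    def get_lens(cls, name: str) -> Type["Lens"]:
        """Return the lens class registered under `name`."""
        try:
            return Lens.registry[name]
        except KeyError:
            known = ", ".join(sorted(Lens.registry))
            raise ValueError(f"unknown lens {name!r}; known lenses: {known}") from None


class LensA(Lens):
    """Placeholder subclass; defining it is enough to register it."""


lens_cls = Lens.get_lens("lensa")
print(lens_cls.__name__)  # LensA
```

With this pattern, adding a new lens only requires defining a new subclass; callers such as LightningLens can then select it by passing its name as a string.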

lensA.py

Description:

The file AttentionLens/attention_lens/lens/registry/lensA.py defines the LensA lens model. LensA creates a separate lens for each attention head in a single layer of the language model (LM).

Usage:

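from attention_lens.train.lightning_lens import LightningLens
lensA = LightningLens('gpt2', 'lensa', layer_num=7, lr=1e-3)  # select LensA by name for layer 7 of GPT-2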
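The one-lens-per-head design can be sketched in NumPy as below. The class name PerHeadLens, the input layout (batch, position, head, d_model), and initializing every head's lens from the model's unembedding matrix are assumptions made for illustration; the real LensA is a PyTorch module whose details may differ.

```python
import numpy as np


class PerHeadLens:
    """Hypothetical sketch: one linear lens (d_model -> d_vocab) per attention head."""

    def __init__(self, unembed: np.ndarray, bias: np.ndarray, n_head: int):
        # Assumption: every head's lens starts from the model's unembedding
        # matrix (d_model, d_vocab) and bias (d_vocab,). np.stack copies the
        # data, so each head's parameters can then be trained independently.
        self.weights = np.stack([unembed] * n_head)  # (n_head, d_model, d_vocab)
        self.biases = np.stack([bias] * n_head)      # (n_head, d_vocab)

    def __call__(self, head_outputs: np.ndarray) -> np.ndarray:
        """Map (batch, pos, n_head, d_model) head outputs to (batch, pos, n_head, d_vocab) logits."""
        # Head h is projected only by its own weight matrix self.weights[h].
        return np.einsum("bphd,hdv->bphv", head_outputs, self.weights) + self.biases


rng = np.random.default_rng(0)
lens = PerHeadLens(rng.standard_normal((8, 20)), np.zeros(20), n_head=4)
logits = lens(rng.standard_normal((2, 3, 4, 8)))
print(logits.shape)  # (2, 3, 4, 20)
```

Because each head has its own weight matrix, the lens for one head can be trained or inspected without affecting the lenses for the other heads.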

lightning_lens.py

Description:

The file AttentionLens/attention_lens/train/lightning_lens.py defines LightningLens, a PyTorch Lightning module that prepares the lens for training with train_lens.py by configuring the lens, the loss function, the forward pass, and the optimizer.

Key Components:

  • kl_loss: Computes the Kullback-Leibler (KL) divergence loss between the output distributions given by the model's logits and the lens's logits.
  • setup: Initializes the model and tokenizer when training starts.
  • forward: Computes a forward pass through the Attention Lens.
  • training_step: Defines a single step of the training loop and returns the resulting loss.
  • configure_optimizers: Configures the optimizer used for training.
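The loss that kl_loss computes can be sketched in NumPy as follows. This sketch assumes the divergence is taken from the model's distribution to the lens's distribution, KL(P_model || P_lens), and averaged over batch and position; the actual implementation works on PyTorch tensors, and its direction and reduction may differ.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def kl_loss(model_logits: np.ndarray, lens_logits: np.ndarray) -> float:
    """Mean KL(P_model || P_lens) over all leading (batch, position) dimensions."""
    log_p = log_softmax(model_logits)  # model's next-token distribution (target)
    log_q = log_softmax(lens_logits)   # lens's next-token distribution
    per_token = (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)
    return float(per_token.mean())


# p = [0.5, 0.5], q = [0.25, 0.75]  ->  KL = 0.5 * ln(4/3) ≈ 0.1438
print(kl_loss(np.array([[0.0, 0.0]]), np.array([[0.0, np.log(3.0)]])))
```

The loss is zero exactly when the lens reproduces the model's distribution, and adding a constant to either set of logits does not change it.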

Usage:

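from attention_lens.train.lightning_lens import LightningLens
lensA = LightningLens('gpt2', 'lensa', layer_num=7, lr=1e-3)  # model 'gpt2', lens 'lensa', layer 7, learning rate 1e-3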
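How training_step and the optimizer fit together can be sketched with a single linear lens and plain gradient descent. This is a hypothetical NumPy illustration, not the LightningLens code: it assumes the frozen model's logits are the targets, uses the closed-form gradient of the mean KL loss with respect to the lens logits, (q - p) / n, and substitutes plain SGD for whatever optimizer configure_optimizers actually returns. In the real module, PyTorch Lightning and autograd perform these steps.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def training_step(W, b, head_out, model_logits, lr):
    """One plain-SGD step for a single linear lens trained on KL(P_model || P_lens).

    head_out:     (n, d_model) activations fed to the lens
    model_logits: (n, d_vocab) the frozen model's logits, used as targets
    Returns the updated (W, b) and the loss measured before the update.
    """
    n = head_out.shape[0]
    p = softmax(model_logits)       # target distribution (treated as constant)
    q = softmax(head_out @ W + b)   # lens distribution
    loss = float((p * (np.log(p) - np.log(q))).sum(axis=-1).mean())
    # Gradient of the mean KL loss w.r.t. the lens logits is (q - p) / n.
    grad_logits = (q - p) / n
    W = W - lr * head_out.T @ grad_logits
    b = b - lr * grad_logits.sum(axis=0)
    return W, b, loss


rng = np.random.default_rng(0)
head_out = rng.standard_normal((32, 8))
model_logits = rng.standard_normal((32, 20))
W, b = np.zeros((8, 20)), np.zeros(20)
losses = []
for _ in range(300):
    W, b, loss = training_step(W, b, head_out, model_logits, lr=0.05)
    losses.append(loss)
print(losses[-1] < losses[0])  # True
```

Because the lens is linear and the loss is convex in its parameters, a small enough learning rate makes the loss decrease at every step in this toy setting.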