
lightning_lens

LightningLens

LightningLens(model_name: str, lens_cls: type[Lens] | str, layer_num: int, lr: float = 0.001, **kwargs)

Bases: LightningModule

kl_loss

kl_loss(logits, lens_logits) -> torch.Tensor

Compute the Kullback-Leibler divergence between tensors.

Quantifies the difference between the probability distribution of the model's output and the probability distribution of the attention lens.

\[ D_{KL} (\text{logits} \Vert \text{lens\_logits}) \]
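Written out over the vocabulary, and assuming both arguments are first normalized with a softmax (the docstring does not say whether kl_loss does this internally), the divergence for a single example is:

\[ D_{KL}(P \Vert Q) = \sum_{v=1}^{d_{\text{vocab}}} P(v) \log \frac{P(v)}{Q(v)}, \qquad P = \operatorname{softmax}(\text{logits}), \quad Q = \operatorname{softmax}(\text{lens\_logits}) \]

The value is zero exactly when the two distributions match, and it is not symmetric in its arguments: the model's distribution P is the reference and the lens distribution Q is the approximation.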

Parameters:

  • logits (Tensor[d_vocab]) –

    A probability distribution of the model's outputs.

  • lens_logits (Tensor[d_vocab]) –

    The output of the attention lens applied to the entire attention layer.

Returns:

  • loss (Tensor[bsz]) –

    The divergence between logits and lens_logits, one value per element of the batch.
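The per-example loss can be sketched in plain PyTorch as below. This is an illustration under stated assumptions, not the library's implementation: kl_loss_sketch is a hypothetical name, and it assumes both inputs are raw logits of shape [bsz, d_vocab] that still need a softmax.

```python
import torch
import torch.nn.functional as F


def kl_loss_sketch(logits: torch.Tensor, lens_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-example KL(P || Q), P from the model, Q from the lens.

    Assumption: both inputs are raw logits of shape [bsz, d_vocab], and the
    softmax is applied here rather than by the caller.
    """
    log_p = F.log_softmax(logits, dim=-1)       # log P(v), model distribution
    log_q = F.log_softmax(lens_logits, dim=-1)  # log Q(v), lens distribution
    # Sum P(v) * (log P(v) - log Q(v)) over the vocabulary: one value per row.
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1)
```

PyTorch's built-in `F.kl_div(log_q, log_p, reduction="none", log_target=True).sum(-1)` computes the same quantity; note that `kl_div` takes the approximating distribution (here the lens) as its first argument, the reverse of the mathematical notation.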

setup

setup(stage) -> None

Loads the model and tokenizer when training is set up.

Parameters:

  • stage

    The stage of the training process (Lightning passes "fit", "validate", "test", or "predict").

forward

forward(cache) -> torch.Tensor

Compute a forward pass through the attention lens.

Takes the information hooked from an entire attention layer of the TransformerLens model and computes the forward pass through the lens models for that layer.

Parameters:

  • cache (Tensor[bsz, q_len, d_model]) –

    The hooked information of an entire layer of the attention mechanism.

Returns:

  • lens_out (Tensor[bsz, d_vocab]) –

    The prediction of the attention lens models for that layer.
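The shape contract of forward, [bsz, q_len, d_model] in and [bsz, d_vocab] out, can be illustrated with a toy lens. ToyLens is hypothetical and is not the library's Lens class; in particular, decoding only the last query position is an assumption made here to collapse the q_len dimension.

```python
import torch
import torch.nn as nn


class ToyLens(nn.Module):
    """Hypothetical stand-in for a Lens: one linear map from d_model to d_vocab."""

    def __init__(self, d_model: int, d_vocab: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, d_vocab)

    def forward(self, cache: torch.Tensor) -> torch.Tensor:
        # cache: [bsz, q_len, d_model]. Assumption: only the final query
        # position is decoded, giving one vocabulary vector per batch row.
        last = cache[:, -1, :]  # [bsz, d_model]
        return self.proj(last)  # [bsz, d_vocab]


lens = ToyLens(d_model=16, d_vocab=50)
print(lens(torch.randn(2, 7, 16)).shape)  # torch.Size([2, 50])
```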

training_step

training_step(train_batch: torch.Tensor, batch_idx: int) -> torch.Tensor

Defines a single step in the training loop. Takes in an entire batch and computes the KL loss for that batch.

Parameters:

  • train_batch (Tensor) –

    The batch of data (of size bsz) for the current training step.

  • batch_idx (int) –

    The index of the batch.

Returns:

  • Tensor –

    The loss for the current training step.
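The documented flow of a training step (lens forward pass, then KL loss) can be sketched as a standalone function. training_step_sketch is hypothetical: it assumes the base model's logits and the hooked cache have already been computed from train_batch, and averaging the per-example losses with mean() is an assumption, not something the docstring states.

```python
from typing import Callable

import torch
import torch.nn.functional as F


def training_step_sketch(
    lens: Callable[[torch.Tensor], torch.Tensor],
    cache: torch.Tensor,
    model_logits: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical helper mirroring the documented flow of training_step.

    cache:        [bsz, q_len, d_model] hooked attention-layer activations
    model_logits: [bsz, d_vocab] logits of the frozen base model (assumed precomputed)
    """
    lens_logits = lens(cache)                                 # [bsz, d_vocab]
    log_p = F.log_softmax(model_logits, dim=-1)               # model distribution P
    log_q = F.log_softmax(lens_logits, dim=-1)                # lens distribution Q
    kl_per_row = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # [bsz]
    # Assumption: the per-example losses are averaged into one scalar.
    return kl_per_row.mean()
```

Returning a scalar lets Lightning call backward on it directly; only the lens receives gradients, because model_logits is treated as a fixed target.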

configure_optimizers

configure_optimizers() -> torch.optim.Optimizer

Configures the optimizer for training.

Returns:

  • Optimizer –

    The optimizer for training.
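A minimal sketch of what configure_optimizers might return, written against plain torch.nn rather than LightningModule so it runs on its own. LensModuleSketch is hypothetical, and the choice of Adam over the lens parameters is an assumption; only the default lr of 0.001 comes from the LightningLens signature.

```python
import torch
import torch.nn as nn


class LensModuleSketch(nn.Module):
    """Hypothetical plain-PyTorch stand-in for the optimizer part of LightningLens."""

    def __init__(self, lens: nn.Module, lr: float = 0.001) -> None:
        super().__init__()
        self.lens = lens  # trainable lens; the base model is kept out of the optimizer
        self.lr = lr

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Assumption: Adam over the lens parameters, using the constructor's lr.
        return torch.optim.Adam(self.lens.parameters(), lr=self.lr)
```

Restricting the optimizer to the lens parameters keeps the underlying language model frozen, which matches the lens being trained to imitate, rather than alter, the model's output distribution.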