lightning_lens
LightningLens
LightningLens(model_name: str, lens_cls: type[Lens] | str, layer_num: int, lr: float = 0.001, **kwargs)
Bases: LightningModule
kl_loss
Compute the Kullback-Leibler divergence between tensors.
Quantifies the difference between the probability distribution of the model's output and the probability distribution produced by the attention lens.
Parameters:
- logits (Tensor[d_vocab]) – The probability distribution over the vocabulary produced by the model.
- lens_logits (Tensor[d_vocab]) – The output of the AttentionLens model applied to an entire layer of the attention mechanism.
Returns:
- loss (Tensor[bsz]) – The KL divergence between logits and lens_logits.
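As a rough illustration of the quantity described above, the following is a minimal sketch rather than the library's implementation. It assumes both inputs are unnormalized scores with a leading batch dimension, and it assumes the direction KL(model || lens); the source does not state either. The name `kl_loss_sketch` is hypothetical.

```python
import torch
import torch.nn.functional as F


def kl_loss_sketch(logits: torch.Tensor, lens_logits: torch.Tensor) -> torch.Tensor:
    """Per-example KL(model || lens). Hypothetical helper, not the library's code."""
    # Assumption: inputs are raw scores, so normalize them to log-probabilities.
    log_p = F.log_softmax(logits, dim=-1)       # model distribution
    log_q = F.log_softmax(lens_logits, dim=-1)  # lens distribution
    # F.kl_div(input, target) computes target * (log target - input) elementwise;
    # with log_target=True the target is also given as log-probabilities.
    kl = F.kl_div(log_q, log_p, reduction="none", log_target=True)
    # Sum over the vocabulary dimension to get one value per example: [bsz].
    return kl.sum(dim=-1)


torch.manual_seed(0)
logits = torch.randn(4, 10)       # [bsz, d_vocab]
lens_logits = torch.randn(4, 10)  # [bsz, d_vocab]
loss = kl_loss_sketch(logits, lens_logits)
```

The result has one entry per example and is zero when the two inputs describe the same distribution.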
setup
Sets up the model and tokenizer during training setup.
Parameters:
- stage – The stage of the training process. PyTorch Lightning passes values such as "fit" or "test".
forward
Compute a forward pass through the Attention Lens.
Takes the hook information of an entire layer of the attention mechanism and computes the forward pass through that layer of Transformer Lens models.
Parameters:
- cache (Tensor[bsz, q_len, d_model]) – The hooked information of an entire layer of the attention mechanism.
Returns:
- lens_out (Tensor[bsz, d_vocab]) – The prediction of the attention lens models for that layer.
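The shape contract above ([bsz, q_len, d_model] in, [bsz, d_vocab] out) can be illustrated with a toy module. This is only a sketch under stated assumptions: `LinearLensSketch` is a hypothetical stand-in, and both the single linear projection and the choice to read only the last query position are assumptions, not the behavior of any actual Lens class.

```python
import torch


class LinearLensSketch(torch.nn.Module):
    """Hypothetical lens: maps the final position's activations to vocabulary scores."""

    def __init__(self, d_model: int, d_vocab: int):
        super().__init__()
        self.proj = torch.nn.Linear(d_model, d_vocab)

    def forward(self, cache: torch.Tensor) -> torch.Tensor:
        # cache: [bsz, q_len, d_model]
        # Assumption: only the last query position is used for the prediction.
        last = cache[:, -1, :]   # [bsz, d_model]
        return self.proj(last)   # [bsz, d_vocab]


lens = LinearLensSketch(d_model=16, d_vocab=32)
cache = torch.randn(2, 5, 16)  # random stand-in for hooked activations
lens_out = lens(cache)
```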
training_step
Defines a single step in the training loop. Takes in an entire batch and computes the KL-loss for that batch.
Parameters:
- train_batch (Tensor) – The batch (bsz) of data for the current training step.
- batch_idx (int) – The index of the batch.
Returns:
- Tensor – The loss for the current training step.
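To make the step concrete, here is a minimal plain-PyTorch sketch of one optimization step on a KL loss. It is not the class's actual `training_step`: the linear lens, the random tensors standing in for hooked activations and model outputs, the Adam optimizer, and the name `training_step_sketch` are all assumptions for illustration. Only the learning rate of 0.001 comes from the documented default.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, d_model, d_vocab = 4, 16, 32

# Hypothetical stand-ins for the real lens and the real data.
lens = torch.nn.Linear(d_model, d_vocab)
optimizer = torch.optim.Adam(lens.parameters(), lr=0.001)  # optimizer choice is an assumption

cache = torch.randn(bsz, d_model)         # stand-in for hooked layer activations
model_logits = torch.randn(bsz, d_vocab)  # stand-in for the frozen model's output


def training_step_sketch(cache: torch.Tensor, model_logits: torch.Tensor) -> torch.Tensor:
    """Compute a scalar KL loss for one batch (assumed direction: model || lens)."""
    lens_logits = lens(cache)
    log_p = F.log_softmax(model_logits, dim=-1)
    log_q = F.log_softmax(lens_logits, dim=-1)
    # "batchmean" sums over the vocabulary and averages over the batch.
    return F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)


loss_before = training_step_sketch(cache, model_logits)
optimizer.zero_grad()
loss_before.backward()
optimizer.step()

# Re-evaluate on the same batch to see the effect of the single update.
loss_after = training_step_sketch(cache, model_logits)
```

In a LightningModule the backward pass and optimizer step are handled by the trainer, so `training_step` itself only needs to return the loss.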