lensA
LensA
Bases: Lens
get_lens
classmethod
Takes the name of a lens and queries the registry for the corresponding Lens
subclass.
Parameters:
- name (str) – The name of the child Lens implementation.
Returns:
- type[Lens] – Subclass Lens implementation.
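The lookup described above can be illustrated with a minimal sketch of a name-to-class registry. This is not the library's actual implementation: the `_registry` attribute, the use of `__init_subclass__` for registration, the choice of the class name as the key, and the `ValueError` on an unknown name are all assumptions made for illustration.

```python
from __future__ import annotations


class Lens:
    """Hypothetical base class; the real Lens may register subclasses differently."""

    # Assumed registry layout: lens name -> Lens subclass.
    _registry: dict[str, type[Lens]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Assumption: each subclass is registered under its class name.
        Lens._registry[cls.__name__] = cls

    @classmethod
    def get_lens(cls, name: str) -> type[Lens]:
        """Return the Lens subclass registered under `name`."""
        try:
            return cls._registry[name]
        except KeyError:
            # Assumption: unknown names raise ValueError.
            raise ValueError(f"Unknown lens: {name!r}") from None


class LensA(Lens):
    """Stand-in subclass used only to exercise the lookup."""


# get_lens returns the class itself, not an instance.
lens_cls = Lens.get_lens("LensA")
```

Because the return type is `type[Lens]`, the caller still has to instantiate the result with whatever constructor arguments the chosen lens expects.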
forward
Performs a forward pass through the LensA model.
Parameters:
- input_tensor (Tensor) – Input tensor of shape (batch_size, pos, n_head, d_model).
Returns:
- Tensor – Output tensor of shape (batch_size, pos, d_vocab) after processing through the linear layers and summing across the attention heads.
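To make the shape contract concrete, here is a rough PyTorch sketch of a forward pass with the same input and output shapes. The class name `LensASketch`, the use of one bias-free `nn.Linear` per head, and the constructor arguments are assumptions; the actual LensA layers may be organised differently.

```python
import torch
from torch import nn


class LensASketch(nn.Module):
    """Illustrative stand-in for LensA; the layer layout is an assumption."""

    def __init__(self, n_head: int, d_model: int, d_vocab: int):
        super().__init__()
        # Assumption: one linear map from d_model to d_vocab per attention head.
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, d_vocab, bias=False) for _ in range(n_head)]
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # input_tensor: (batch_size, pos, n_head, d_model)
        per_head = [
            linear(input_tensor[:, :, h, :])  # (batch_size, pos, d_vocab)
            for h, linear in enumerate(self.linears)
        ]
        # Stack along a head axis, then sum the heads away.
        return torch.stack(per_head, dim=2).sum(dim=2)


model = LensASketch(n_head=4, d_model=8, d_vocab=16)
x = torch.randn(2, 5, 4, 8)  # (batch_size, pos, n_head, d_model)
logits = model(x)
print(logits.shape)  # torch.Size([2, 5, 16])
```

The head axis disappears in the output because the per-head projections are summed, which is why the result has three dimensions rather than four.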