
lensA

LensA

LensA(unembed, bias, n_head, d_model, d_vocab)

Bases: Lens

get_lens classmethod

get_lens(name: str) -> type[Lens]

Takes the name of a lens, looks it up in the registry, and returns the corresponding Lens subclass.

Parameters:

  • name (str) –

    The name of the child Lens implementation.

Returns:

  • type[Lens]

    The Lens subclass registered under name.
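The lookup described above follows a name-to-class registry pattern. The sketch below shows one way such a registry can work. It assumes that subclasses register themselves through `__init_subclass__` and are keyed by their class name. The `_registry` attribute, the keying scheme, and the `ValueError` on unknown names are all hypothetical and are not taken from this library.

```python
from typing import ClassVar


class Lens:
    # Hypothetical registry mapping a lens name to its subclass;
    # the real library's registration mechanism is not documented here.
    _registry: ClassVar[dict[str, type["Lens"]]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Assumption: subclasses are keyed by their class name.
        Lens._registry[cls.__name__] = cls

    @classmethod
    def get_lens(cls, name: str) -> type["Lens"]:
        # Return the registered subclass, or fail with a helpful message.
        try:
            return cls._registry[name]
        except KeyError:
            raise ValueError(
                f"Unknown lens {name!r}; available: {sorted(cls._registry)}"
            ) from None


class LensA(Lens):
    pass


print(Lens.get_lens("LensA"))  # prints <class '__main__.LensA'> when run as a script
```

Registering in `__init_subclass__` means that defining a subclass is enough to make it discoverable. No separate decorator or manual registration call is needed.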

forward

forward(input_tensor)

Performs a forward pass through the LensA model.

Parameters:

  • input_tensor (Tensor) –

    Input tensor of shape (batch_size, pos, n_head, d_model).

Returns:

  • Tensor

    Output tensor of shape (batch_size, pos, d_vocab), obtained by passing the input through the linear layers and summing across the attention heads.
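The forward pass described above maps each head's d_model-dimensional slice to vocabulary logits and then sums over heads. The following is a minimal sketch of that computation, assuming that `unembed` holds one matrix per head with shape (n_head, d_model, d_vocab) and that `bias` has shape (n_head, d_vocab). Neither shape is confirmed by these docs, and the class below is an illustrative stand-in rather than the library's implementation.

```python
import torch
from torch import nn


class LensA(nn.Module):
    """Illustrative per-head linear lens (assumed shapes, not the library's code).

    unembed: (n_head, d_model, d_vocab)  -- assumption
    bias:    (n_head, d_vocab)           -- assumption
    """

    def __init__(self, unembed, bias, n_head, d_model, d_vocab):
        super().__init__()
        if unembed.shape != (n_head, d_model, d_vocab):
            raise ValueError(f"unexpected unembed shape {tuple(unembed.shape)}")
        if bias.shape != (n_head, d_vocab):
            raise ValueError(f"unexpected bias shape {tuple(bias.shape)}")
        self.unembed = nn.Parameter(unembed.clone())
        self.bias = nn.Parameter(bias.clone())

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # input_tensor: (batch_size, pos, n_head, d_model)
        # Per-head linear map: (b, p, h, m) x (h, m, v) -> (b, p, h, v); bias broadcasts over (b, p).
        per_head = torch.einsum("bphm,hmv->bphv", input_tensor, self.unembed) + self.bias
        # Sum across attention heads: (b, p, h, v) -> (b, p, v)
        return per_head.sum(dim=2)


torch.manual_seed(0)
n_head, d_model, d_vocab = 4, 8, 10
lens = LensA(torch.randn(n_head, d_model, d_vocab), torch.zeros(n_head, d_vocab),
             n_head, d_model, d_vocab)
x = torch.randn(2, 5, n_head, d_model)
out = lens(x)
print(out.shape)  # torch.Size([2, 5, 10])
```

A single `einsum` handles every head in one batched contraction. This is equivalent to looping over heads, computing `x[:, :, h] @ unembed[h] + bias[h]`, and adding the results.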