PLM Attention Complexity

Note

Dominant cost: O(L^2 * d) per sequence per layer.

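Self-attention in a transformer computes a query-key dot-product matrix of shape (L, L) for each of the h attention heads. Each head works in dimension d/h, so across all heads this gives O(L^2 * d) time and O(h * L^2 + L * d) memory per layer (O(L^2 + L * d) once h is treated as a constant). The Q/K/V/output projections and the feed-forward block add a further O(L * d^2) per layer, which is only linear in L, so the quadratic attention term is the one that dominates as L grows. With num_layers layers, the total forward pass is O(num_layers * L^2 * d). Because num_layers and d are fixed per PLM checkpoint, the per-sequence cost grows quadratically in the sequence length L.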
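As a rough illustration of the quadratic term, the sketch below counts multiply-accumulates (MACs) in the two (L, L) matmuls of attention (the scores Q K^T and the weighted sum over V), plus the memory of one layer's score tensor. The helper names are hypothetical, float32 storage is assumed, num_heads=20 is an illustrative value, and the projections and feed-forward block are deliberately left out.

```python
def attention_matmul_macs(seq_len: int, d_model: int, num_layers: int = 1) -> int:
    """MACs in Q @ K^T and softmax(scores) @ V, summed over heads and layers.

    Each of the h heads does (L x d/h) @ (d/h x L) and (L x L) @ (L x d/h),
    i.e. 2 * L^2 * (d/h) MACs; summed over h heads this is 2 * L^2 * d.
    """
    return num_layers * 2 * seq_len * seq_len * d_model


def attention_score_bytes(seq_len: int, num_heads: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one layer's (h, L, L) score tensor if all heads are materialised at once."""
    return num_heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    # Fixed checkpoint shape (d = 1280, 33 layers): doubling L quadruples the attention term.
    base = attention_matmul_macs(512, 1280, num_layers=33)
    doubled = attention_matmul_macs(1024, 1280, num_layers=33)
    print(doubled / base)  # 4.0
    print(attention_score_bytes(1024, num_heads=20) / 2**20)  # 80.0 (MiB per layer)
```

The head count cancels out of the time term but not out of the memory term, which is why h reappears only in the score-tensor size.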

PROTEA supports four PLM families: ESM-2, ESM-C, T5/ProstT5, and Ankh. Each has a different sequence length cap and a different effective d:

| PLM family | Max L | Embed dim d | Notes |
|---|---|---|---|
| ESM-2 (150M, 650M) | 1022 residues | 640 / 1280 | CLS and EOS excluded from pooling |
| ESM-C 300M / 600M | 1022 residues | 960 / 1152 | HuggingFace EsmModel path |
| T5 / ProstT5-XL | configurable (max_length) | 1024 | Encoder-only forward; sequence3D prefix for ProstT5 |
| Ankh (base, large) | configurable | 768 / 1536 | T5-family encoder |
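Because a pooled output holds d values and a residue-level output holds L x d values, the embedding width also fixes how much each stored embedding occupies. A small lookup sketch, assuming float32 storage; the dictionary keys are hypothetical labels, not PROTEA configuration names.

```python
# Hidden sizes of the checkpoints in the table (keys are illustrative labels only).
EMBED_DIM: dict[str, int] = {
    "esm2_150m": 640,
    "esm2_650m": 1280,
    "esmc_300m": 960,
    "esmc_600m": 1152,
    "prostt5_xl": 1024,
    "ankh_base": 768,
    "ankh_large": 1536,
}


def pooled_bytes(model_key: str, bytes_per_elem: int = 4) -> int:
    """One sequence-level vector of length d."""
    return EMBED_DIM[model_key] * bytes_per_elem


def residue_bytes(model_key: str, num_residues: int, bytes_per_elem: int = 4) -> int:
    """One (num_residues, d) per-residue matrix; grows linearly with sequence length."""
    return num_residues * EMBED_DIM[model_key] * bytes_per_elem


if __name__ == "__main__":
    print(pooled_bytes("esm2_650m"))         # 5120
    print(residue_bytes("esm2_650m", 1022))  # 5232640
```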

Chunking to bound VRAM

For sequences longer than the model’s max_length or when VRAM is tight, PROTEA splits the sequence into overlapping chunks and pools the per-chunk embeddings. The chunk logic is in _compute_chunk_spans:

```python
def _compute_chunk_spans(length: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Compute (start, end) spans for overlapping chunks over a sequence of ``length`` residues.

    Raises ``ValueError`` if ``overlap >= chunk_size``; such a configuration
    gives a non-positive step, so ``start`` would never reach ``length``
    and the loop would not terminate.
    """
    if overlap >= chunk_size:
        raise ValueError(
            f"chunk_overlap ({overlap}) must be strictly less than chunk_size ({chunk_size})"
        )
    step = chunk_size - overlap
    spans: list[tuple[int, int]] = []
    start = 0
    while start < length:
        end = min(start + chunk_size, length)
        spans.append((start, end))
        start += step
    return spans
```
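To see what the spans look like and how per-chunk pooling could be applied to them, here is a minimal NumPy sketch. It reuses the stepping rule of _compute_chunk_spans (step = chunk_size - overlap); mean pooling and the name mean_pool_chunks are assumptions for illustration, and PROTEA's own _chunk_and_pool may pool differently.

```python
import numpy as np


def mean_pool_chunks(
    residues: np.ndarray, chunk_size: int, overlap: int
) -> list[tuple[int, int, np.ndarray]]:
    """Mean-pool an (L, d) residue matrix over overlapping (start, end) spans."""
    if chunk_size <= 0 or overlap < 0 or overlap >= chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    length = residues.shape[0]
    step = chunk_size - overlap
    # Same spans as _compute_chunk_spans, written as a comprehension.
    spans = [(s, min(s + chunk_size, length)) for s in range(0, length, step)]
    # One d-dimensional vector per span: the average over that span's residues.
    return [(start, end, residues[start:end].mean(axis=0)) for start, end in spans]


if __name__ == "__main__":
    toy = np.arange(10, dtype=np.float32).reshape(10, 1)  # L = 10 residues, d = 1
    for start, end, vec in mean_pool_chunks(toy, chunk_size=4, overlap=2):
        print(start, end, vec)
```

With L = 10, chunk_size = 4 and overlap = 2 the spans are (0, 4), (2, 6), (4, 8), (6, 10) and (8, 10). The last span lies entirely inside (6, 10): because the loop keeps stepping until start >= length, the tail can yield a short chunk whose residues are already covered.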

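When each chunk runs as an independent forward pass, it costs O(chunk_size^2 * d). There are ceil(L / (chunk_size - overlap)) chunks, so the total is O(L * chunk_size^2 * d / (chunk_size - overlap)): linear in L for fixed chunk_size and overlap, at the price of recomputing the overlapping residues. In the ESM path below, by contrast, the tokenizer truncates to config.max_length and _chunk_and_pool receives the residue embeddings of a single forward pass, so there the attention cost stays O(L^2 * d) with L capped at max_length.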
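The linear-versus-quadratic comparison can be checked numerically. The sketch below sums chunk_len^2 * d over the spans (so a short final chunk is charged only for its own length) and compares it with one unchunked pass. It assumes every chunk is its own forward pass, and the helper names are hypothetical.

```python
import math


def chunk_spans(length: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Same stepping rule as _compute_chunk_spans."""
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be strictly less than chunk_size")
    return [(s, min(s + chunk_size, length)) for s in range(0, length, step)]


def chunked_cost(length: int, chunk_size: int, overlap: int, d_model: int) -> int:
    """Quadratic attention term summed over independent per-chunk passes."""
    return sum(
        (end - start) ** 2 * d_model for start, end in chunk_spans(length, chunk_size, overlap)
    )


def unchunked_cost(length: int, d_model: int) -> int:
    """Quadratic attention term for a single pass over the whole sequence."""
    return length * length * d_model


if __name__ == "__main__":
    d = 1280
    for L in (4096, 8192):
        n = len(chunk_spans(L, 1024, 128))
        assert n == math.ceil(L / (1024 - 128))  # chunk-count formula from the text
        print(L, n, chunked_cost(L, 1024, 128, d) / unchunked_cost(L, d))
```

Going from L = 4096 to 8192 takes the chunk count from 5 to 10 and roughly doubles the chunked total, while the unchunked term quadruples.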

ESM-2 / ESM-C forward pass

The per-sequence helper, called from the per-batch loop in _embed_esm, is _embed_esm_one:

```python
def _embed_esm_one(
    model: Any,
    tokenizer: Any,
    seq_str: str,
    config: EmbeddingConfig,
    device: str,
) -> list[ChunkEmbedding]:
    """ESM-2 forward pass + pooling for one sequence.

    Used by :func:`_embed_esm` to keep the per-batch loop body
    readable. Excludes CLS (position 0) and EOS (last valid position)
    from residue-level operations. ``attention_mask.sum()`` covers
    CLS + content + EOS, so the residue slice is ``[1:actual_len-1]``.
    """
    import torch
    import torch.nn.functional as F

    tokens = tokenizer(
        seq_str,
        return_tensors="pt",
        truncation=True,
        max_length=config.max_length,
        add_special_tokens=True,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    outputs = model(**tokens, output_hidden_states=True)
    hidden_states = outputs.hidden_states
    valid_layers = _validate_layers(config.layer_indices, hidden_states, "ESM", seq_str[:20])
    if config.pooling == "cls":
        layer_tensors_1d = [hidden_states[-(li + 1)][0, 0, :].float() for li in valid_layers]
        pooled = _aggregate_1d(layer_tensors_1d, config.layer_agg)
        if config.normalize:
            pooled = F.normalize(pooled.unsqueeze(0), p=2, dim=1).squeeze(0)
        chunks = [ChunkEmbedding(0, None, pooled.cpu().numpy())]
    else:
        actual_len = int(tokens["attention_mask"].sum().item())
        layer_tensors_2d = [
            hidden_states[-(li + 1)][0, 1 : actual_len - 1, :].float() for li in valid_layers
        ]
        residues = _aggregate_residue_layers(layer_tensors_2d, config.layer_agg)
        if config.normalize_residues:
            residues = F.normalize(residues, p=2, dim=1)
        chunks = _chunk_and_pool(residues, config)
    del outputs, hidden_states
    torch.cuda.empty_cache()
    return chunks
```
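The index arithmetic in _embed_esm_one (layer selection via hidden_states[-(li + 1)] and the [1:actual_len-1] residue slice) can be checked on toy arrays. In this NumPy sketch each hidden state is filled with its own layer index so the selected layer is visible; the trailing padding is hypothetical (a single tokenised sequence has none, but the mask sum handles it either way), and averaging the selected layers is only an assumed layer_agg.

```python
import numpy as np

# Toy stand-in for outputs.hidden_states: num_layers + 1 arrays of shape (1, T, d),
# each filled with its own index so we can see which layer was picked.
num_layers, T, d = 3, 8, 2
hidden_states = [np.full((1, T, d), float(i)) for i in range(num_layers + 1)]

# Token layout: [CLS, r1, r2, r3, r4, EOS, PAD, PAD]
attention_mask = np.array([[1, 1, 1, 1, 1, 1, 0, 0]])
actual_len = int(attention_mask.sum())  # CLS + 4 residues + EOS = 6

# li = 0 -> final layer, li = 1 -> the layer before it, as in hidden_states[-(li + 1)].
layer_indices = [0, 1]
per_layer = [hidden_states[-(li + 1)][0, 1 : actual_len - 1, :] for li in layer_indices]

# Hypothetical layer_agg="mean": average the selected layers residue by residue.
residues = np.mean(np.stack(per_layer), axis=0)
print(residues.shape)  # (4, 2)
```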


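The output_hidden_states=True flag makes the model return the hidden state after every layer (plus the embedding output), which multi-layer pooling needs. Those num_layers + 1 tensors of shape (1, L, d) stay alive until del outputs, hidden_states, adding O(num_layers * L * d) to peak VRAM compared with a pass that keeps only the last layer.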
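A back-of-the-envelope sketch of that extra term, assuming float32 hidden states and a tuple of num_layers + 1 tensors (the embedding output plus one per layer); the function name is hypothetical.

```python
def hidden_states_bytes(
    num_layers: int, num_tokens: int, d_model: int, bytes_per_elem: int = 4
) -> int:
    """Bytes retained by output_hidden_states=True for one sequence (batch size 1)."""
    return (num_layers + 1) * num_tokens * d_model * bytes_per_elem


if __name__ == "__main__":
    # ESM-2 650M: 33 layers, d = 1280; 1022 residues + CLS + EOS = 1024 tokens.
    total = hidden_states_bytes(num_layers=33, num_tokens=1024, d_model=1280)
    print(total / 2**20)  # 170.0 (MiB)
```

Storing the hidden states in half precision (bytes_per_elem=2) halves the figure; the term stays linear in L either way.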

Cross-reference

Thesis Ch. 5.1 derives the per-sequence VRAM formula and includes a measured L-vs-throughput curve for ESM-2 650M on the A100 used in FARM-EXP.13.