PLM Attention Complexity
Note
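Dominant cost at long sequence lengths: O(L^2 * d) per sequence per layer.
Self-attention in a transformer computes a query-key dot-product matrix of
shape (L, L) for each of its h attention heads. Each head works in d/h
dimensions, so across all h heads the score and weighted-value products cost
O(L^2 * d) time per layer, and storing the scores takes O(h * L^2) memory on
top of the O(L * d) activations. The Q/K/V/output projections and the
feed-forward block add O(L * d^2) time per layer, so the attention term
dominates only once L is large relative to d. With num_layers layers the
total forward pass is O(num_layers * (L^2 * d + L * d^2)). Because num_layers
and d are fixed per PLM checkpoint, the per-sequence cost grows quadratically
in the sequence length L.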
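For a rough sense of when the L^2 * d term actually dominates, the sketch below counts multiply-accumulates for one encoder layer on one sequence. It assumes a standard transformer block (four d x d projections for Q, K, V and output, plus a feed-forward block of width 4d) and ignores biases, softmax, and layer norm; the function name layer_macs and the ffn_mult default are illustrative assumptions, not PROTEA code.

```python
def layer_macs(seq_len: int, d_model: int, ffn_mult: int = 4) -> tuple[int, int]:
    """Return (attention_macs, dense_macs) for one encoder layer on one sequence.

    Assumes Q/K/V/output projections of size d_model x d_model and a
    feed-forward block d_model -> ffn_mult * d_model -> d_model.
    Biases, softmax, and layer norm are ignored.
    """
    # QK^T and (scores @ V): each is L x L x (d/h) per head, summed over h heads.
    attention = 2 * seq_len * seq_len * d_model
    # Four d x d projections applied to each of the L tokens.
    projections = 4 * seq_len * d_model * d_model
    # Two feed-forward matmuls: d -> ffn_mult * d and back to d.
    feed_forward = 2 * seq_len * d_model * (ffn_mult * d_model)
    return attention, projections + feed_forward


for L in (256, 1024, 4096, 7680):
    attn, dense = layer_macs(L, 1280)
    print(f"L={L:5d}  attention/dense = {attn / dense:.3f}")
```

Under these assumptions the attention/dense ratio is L / (6 * d_model), so for d_model = 1280 (ESM-2 650M) attention only matches the dense layers near L = 7680, well beyond the 1022-residue cap.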
PROTEA supports four PLM families: ESM-2, ESM-C, T5/ProstT5, and Ankh. Each
has a different sequence length cap and a different effective d:
| PLM family | Max L | Embed dim (d) | Notes |
|---|---|---|---|
| ESM-2 (150M, 650M) | 1022 residues | 640 / 1280 | CLS and EOS excluded from pooling |
| ESM-C (300M, 600M) | 1022 residues | 960 / 1152 | HuggingFace |
| T5 / ProstT5-XL | configurable (…) | 1024 | Encoder-only forward; sequence3D prefix for ProstT5 |
| Ankh (base, large) | configurable | 768 / 1536 | T5-family encoder |
Chunking to bound VRAM
For sequences longer than the model’s max_length or when VRAM is
tight, PROTEA splits the sequence into overlapping chunks and pools the
per-chunk embeddings. The chunk logic is in
_compute_chunk_spans:
```python
def _compute_chunk_spans(length: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Compute (start, end) spans for overlapping chunks over a sequence of ``length`` residues.

    Raises ``ValueError`` if ``overlap >= chunk_size``; such a configuration
    gives a non-positive step, so the loop below would never terminate.
    """
    if overlap >= chunk_size:
        raise ValueError(
            f"chunk_overlap ({overlap}) must be strictly less than chunk_size ({chunk_size})"
        )
    step = chunk_size - overlap  # stride between consecutive chunk starts
    spans: list[tuple[int, int]] = []
    start = 0
    while start < length:
        end = min(start + chunk_size, length)  # the last chunk may be shorter
        spans.append((start, end))
        start += step
    return spans
```
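Each chunk that is run through the model as its own forward pass costs O(chunk_size^2 * d). The number of chunks is ceil(L / (chunk_size - overlap)), so with chunk_size fixed the total cost grows linearly in L once chunking kicks in, instead of quadratically. Within _embed_esm_one (below), the tokenizer truncates the input to config.max_length and _chunk_and_pool is applied to the residue embeddings of that single forward pass, so at that level chunking determines how embeddings are pooled rather than the forward-pass cost.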
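The chunk count and the linear-growth claim can be checked numerically. This sketch re-declares the stepping rule of _compute_chunk_spans under the hypothetical name compute_chunk_spans so it runs on its own, and compares the summed per-chunk attention cost with a single unchunked pass. attention_cost and chunked_attention_cost are illustrative helpers that keep only the L^2 * d term (constants dropped) and assume each chunk is a separate forward pass.

```python
import math


def compute_chunk_spans(length: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Same stepping rule as _compute_chunk_spans: fixed stride chunk_size - overlap."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be strictly less than chunk_size")
    step = chunk_size - overlap
    spans: list[tuple[int, int]] = []
    start = 0
    while start < length:
        spans.append((start, min(start + chunk_size, length)))
        start += step
    return spans


def attention_cost(length: int, d_model: int) -> int:
    """Attention-score term of one forward pass, up to a constant factor: L^2 * d."""
    return length * length * d_model


def chunked_attention_cost(length: int, chunk_size: int, overlap: int, d_model: int) -> int:
    """Sum of per-chunk costs, treating every chunk as its own forward pass."""
    spans = compute_chunk_spans(length, chunk_size, overlap)
    return sum(attention_cost(end - start, d_model) for start, end in spans)


for L in (2_000, 4_000, 8_000):
    n_chunks = len(compute_chunk_spans(L, 1_000, 100))
    assert n_chunks == math.ceil(L / (1_000 - 100))  # matches the closed-form count
    ratio = chunked_attention_cost(L, 1_000, 100, 1) / attention_cost(L, 1)
    print(f"L={L}: {n_chunks} chunks, chunked/unchunked attention cost = {ratio:.2f}")
```

Each doubling of L roughly halves the printed ratio: the chunked total grows linearly in L while the unchunked cost grows quadratically.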
ESM-2 / ESM-C forward pass
The per-sequence helper, called from the batch loop in _embed_esm, is _embed_esm_one:
```python
def _embed_esm_one(
    model: Any,
    tokenizer: Any,
    seq_str: str,
    config: EmbeddingConfig,
    device: str,
) -> list[ChunkEmbedding]:
    """ESM-2 forward pass + pooling for one sequence.

    Used by :func:`_embed_esm` to keep the per-batch loop body
    readable. Excludes CLS (position 0) and EOS (last valid position)
    from residue-level operations. ``attention_mask.sum()`` covers
    CLS + content + EOS, so the residue slice is ``[1:actual_len-1]``.
    """
    import torch
    import torch.nn.functional as F

    tokens = tokenizer(
        seq_str,
        return_tensors="pt",
        truncation=True,
        max_length=config.max_length,
        add_special_tokens=True,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    outputs = model(**tokens, output_hidden_states=True)
    hidden_states = outputs.hidden_states
    valid_layers = _validate_layers(config.layer_indices, hidden_states, "ESM", seq_str[:20])

    if config.pooling == "cls":
        # Layer index li counts back from the final layer: li = 0 -> hidden_states[-1].
        layer_tensors_1d = [hidden_states[-(li + 1)][0, 0, :].float() for li in valid_layers]
        pooled = _aggregate_1d(layer_tensors_1d, config.layer_agg)
        if config.normalize:
            pooled = F.normalize(pooled.unsqueeze(0), p=2, dim=1).squeeze(0)
        chunks = [ChunkEmbedding(0, None, pooled.cpu().numpy())]
    else:
        # Drop CLS (index 0) and EOS (index actual_len - 1); keep residue rows only.
        actual_len = int(tokens["attention_mask"].sum().item())
        layer_tensors_2d = [
            hidden_states[-(li + 1)][0, 1 : actual_len - 1, :].float() for li in valid_layers
        ]
        residues = _aggregate_residue_layers(layer_tensors_2d, config.layer_agg)
        if config.normalize_residues:
            residues = F.normalize(residues, p=2, dim=1)
        chunks = _chunk_and_pool(residues, config)

    # Drop references to the model outputs and release cached CUDA blocks.
    del outputs, hidden_states
    torch.cuda.empty_cache()
    return chunks
```
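The two indexing conventions in _embed_esm_one can be sketched with plain Python lists. The token strings below stand in for token ids and are illustrative only; the sketch assumes the tokenizer emits exactly one CLS and one EOS per sequence and pads on the right. The helper names residue_slice and layer_position are hypothetical. For a 33-layer model such as ESM-2 650M, Hugging Face returns 34 hidden states (the embedding output plus one per layer).

```python
def residue_slice(tokens: list[str], attention_mask: list[int]) -> list[str]:
    """Mimic the [1:actual_len-1] slice: drop CLS at 0 and EOS at actual_len - 1."""
    actual_len = sum(attention_mask)  # counts CLS + residues + EOS, not padding
    return tokens[1 : actual_len - 1]


def layer_position(li: int, num_hidden_states: int) -> int:
    """Map a from-the-end layer index li to a non-negative index into hidden_states."""
    # hidden_states[-(li + 1)] is the same element as hidden_states[num_hidden_states - li - 1].
    return num_hidden_states - (li + 1)


padded_tokens = ["<cls>", "M", "K", "T", "<eos>", "<pad>"]
padded_mask = [1, 1, 1, 1, 1, 0]
print(residue_slice(padded_tokens, padded_mask))  # ['M', 'K', 'T']
print(layer_position(0, 34), layer_position(33, 34))  # 33 0
```

Because _embed_esm_one tokenizes a single sequence, its attention mask is all ones; the mask-based slice still gives the right residues if the same logic is applied to right-padded batches.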
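The output_hidden_states=True flag makes the model return, and keep alive,
the output of the embedding layer and of every transformer layer
(num_layers + 1 tensors of shape (L, d)), which multi-layer pooling needs.
This adds O(num_layers * L * d) memory compared with a pass that keeps only
the last layer.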
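To put the retained-activation term in numbers, the sketch below estimates the bytes held by output_hidden_states=True and, for comparison, one layer's attention-score tensor. It assumes ESM-2 650M dimensions (33 layers, 20 heads, d = 1280), batch size 1, 1024 tokens (1022 residues plus CLS and EOS), and fp32 activations; the score estimate applies only if the attention implementation materialises the full (h, L, L) matrix. hidden_state_bytes and attention_score_bytes are hypothetical helpers.

```python
MIB = 2**20


def hidden_state_bytes(num_layers: int, seq_len: int, d_model: int, bytes_per_elem: int = 4) -> int:
    """Bytes held by output_hidden_states: embedding output + one (L, d) tensor per layer."""
    return (num_layers + 1) * seq_len * d_model * bytes_per_elem


def attention_score_bytes(num_heads: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one layer's (h, L, L) attention-score tensor at batch size 1."""
    return num_heads * seq_len * seq_len * bytes_per_elem


# Assumed ESM-2 650M shape: 33 layers, 20 heads, d_model = 1280; 1024 tokens.
hs = hidden_state_bytes(33, 1024, 1280)
scores = attention_score_bytes(20, 1024)
print(f"retained hidden states: {hs / MIB:.0f} MiB, per-layer scores: {scores / MIB:.0f} MiB")
```

With these assumptions it prints 170 MiB of retained hidden states against an 80 MiB per-layer score tensor; pass bytes_per_elem=2 for fp16 or bf16 activations.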
Cross-reference
Thesis Ch. 5.1 derives the per-sequence VRAM formula and includes a measured L-vs-throughput curve for ESM-2 650M on the A100 used in FARM-EXP.13.