Backend plugin guide
A backend plugin provides a protein language model (PLM) embedding
implementation. protea-core dispatches embedding jobs to the
backend whose name attribute matches the model_backend field
of the EmbeddingConfig DB row.
Existing backends shipped in protea-backends: esm, t5,
ankh, esm3c.
The ABC
Your class must subclass protea_contracts.EmbeddingBackend, set the name
class attribute (the key protea-core matches against model_backend), and
implement two abstract methods:
from protea_contracts.embedding_backend import EmbeddingBackend
import numpy as np
from typing import Any
class EmbeddingBackend(ABC):
name: str
@abstractmethod
def load_model(
self,
model_name: str,
device: str,
emit: Any,
) -> tuple[Any, Any]:
... # Returns (model, tokenizer)
@abstractmethod
def embed_batch(
self,
model: Any,
tokenizer: Any,
sequences: list[str],
*,
emit: Any,
layers: list[int] | None = None,
layer_agg: str = "mean",
pooling: str = "mean",
) -> np.ndarray[Any, Any]:
... # Returns (batch_size, dim) float16 matrix
Key invariants:
load_model returns a (model, tokenizer) pair. The types are Any because protea-contracts does not depend on torch.
embed_batch returns an (N, D) numpy.float16 matrix. Special tokens (CLS, EOS, BOS) must be stripped before pooling. N equals len(sequences); D is backend-specific.
Both methods receive an emit callback. Use it to report structured events:
emit("backend.mymodel.load_done", None, {"model_name": model_name}, "info")
The signature is emit(event: str, payload, fields: dict, level: str).
Heavy ML imports belong inside load_model and embed_batch, not at module top (lazy-import rule).
Packaging snippet
pyproject.toml for protea-backends-mymodel (or a new entry
inside the shared protea-backends package):
[tool.poetry]
name = "protea-backends-mymodel"
version = "0.1.0"
packages = [{ include = "protea_backends_mymodel", from = "src" }]

[tool.poetry.dependencies]
python = ">=3.12,<4.0"
protea-contracts = ">=0.2"
# NumPy 1.26 is the first release supporting Python 3.12; an older lower
# bound cannot be satisfied under the python constraint above.
numpy = ">=1.26"
# Heavy deps as extras so CI stays import-cheap.
# PyTorch 2.2 is the first release with Python 3.12 wheels.
torch = { version = ">=2.2", optional = true }
transformers = { version = ">=4.40", optional = true }

[tool.poetry.extras]
# Installing the "mymodel" extra pulls in the heavy runtime deps.
mymodel = ["torch", "transformers"]

# Entry point: the key must equal the backend's `name` attribute; the value
# is "<module>:<attribute>" pointing at the module-level plugin instance.
[tool.poetry.plugins."protea.backends"]
mymodel = "protea_backends_mymodel:plugin"
Test scaffold
Copy and adapt the pattern from protea-backends/tests/test_esm.py:
"""Smoke tests for the mymodel backend plugin."""
from importlib.metadata import entry_points
from protea_contracts import EmbeddingBackend
from protea_backends_mymodel import MyModelBackend, plugin
def test_plugin_is_mymodel_instance() -> None:
assert isinstance(plugin, MyModelBackend)
def test_plugin_implements_embedding_backend_abc() -> None:
assert isinstance(plugin, EmbeddingBackend)
def test_plugin_name_matches_entry_point_key() -> None:
assert plugin.name == "mymodel"
def test_plugin_resolvable_via_entry_points() -> None:
eps = entry_points(group="protea.backends")
matches = [ep for ep in eps if ep.name == "mymodel"]
assert len(matches) == 1
assert matches[0].load() is plugin
def test_load_model_and_embed_batch_are_callable() -> None:
assert callable(plugin.load_model)
assert callable(plugin.embed_batch)
These five tests run without torch or any other heavy dependency installed, because none of them actually calls load_model or embed_batch.
Integration tests that actually call load_model / embed_batch
belong in protea-core’s test suite or in a separate extras-gated
job in CI.
Worked example: toy backend
The toy backend returns deterministic random vectors (seeded from
the input sequence). It has no heavy dependencies and can be run in
any environment. It is useful as a template and as a CI speed check
(zero GPU, instant embedding).
# src/protea_backends_toy/__init__.py
"""Toy backend: deterministic random embeddings seeded from sequence.
No torch, no transformers. Returns float16 vectors of a fixed
dimension (64). Useful as a template and in CI.
Install:
pip install -e .
Use in protea-core (after pip install so the entry-point registers):
EmbeddingConfig(model_backend="toy", model_name="toy_v0", ...)
"""
from __future__ import annotations
import hashlib
from typing import Any
import numpy as np
from protea_contracts.embedding_backend import EmbeddingBackend
_DIM = 64  # fixed output dimension of every toy embedding row


def _seed_from_sequence(seq: str) -> int:
    """Derive a 32-bit seed from the MD5 of a sequence string.

    The first 8 hex digits (32 bits) of the MD5 digest of ``seq`` (UTF-8
    encoded) are parsed as an unsigned integer, so the result is always in
    ``[0, 2**32)``. The result is stable across processes, unlike the built-in
    ``hash``, which is randomized per process for ``str``.

    MD5 serves only as a cheap, stable fingerprint here, never for security.
    ``usedforsecurity=False`` keeps it usable on FIPS-restricted builds, where
    a plain ``hashlib.md5()`` call may be rejected.
    """
    digest = hashlib.md5(seq.encode("utf-8"), usedforsecurity=False)
    return int(digest.hexdigest()[:8], 16)
class ToyBackend(EmbeddingBackend):
    """Embedding backend that fabricates reproducible vectors without a model.

    A minimal, dependency-free implementation of the ``EmbeddingBackend``
    contract, meant to be copied as a template or used as a fast CI stand-in.

    * ``load_model`` loads nothing and hands back ``(None, None)``.
    * ``embed_batch`` draws one standard-normal row of length ``_DIM`` per
      sequence from a NumPy generator seeded by ``_seed_from_sequence``, then
      casts it to float16.
    """

    name = "toy"

    def load_model(
        self,
        model_name: str,
        device: str,
        emit: Any,
    ) -> tuple[Any, Any]:
        """Emit a load event and return the placeholder ``(None, None)`` pair.

        ``device`` is accepted for interface compatibility but unused.
        """
        emit("backend.toy.load", None, {"model_name": model_name}, "info")
        model = tokenizer = None
        return model, tokenizer

    def embed_batch(
        self,
        model: Any,
        tokenizer: Any,
        sequences: list[str],
        *,
        emit: Any,
        layers: list[int] | None = None,
        layer_agg: str = "mean",
        pooling: str = "mean",
    ) -> np.ndarray[Any, Any]:
        """Embed ``sequences`` as a ``(len(sequences), _DIM)`` float16 matrix.

        Row ``i`` depends only on ``sequences[i]``, so the same input yields
        the same row regardless of batch composition or call order.
        ``model``, ``tokenizer``, ``layers``, ``layer_agg`` and ``pooling`` are
        accepted to satisfy the interface but have no effect. An empty batch
        returns a ``(0, _DIM)`` matrix without emitting any event.
        """
        if not sequences:
            return np.zeros((0, _DIM), dtype=np.float16)

        def _row(seq: str) -> np.ndarray[Any, Any]:
            # A fresh generator per sequence makes each row self-contained.
            generator = np.random.default_rng(_seed_from_sequence(seq))
            return generator.standard_normal(_DIM).astype(np.float16)

        vectors = [_row(seq) for seq in sequences]
        emit(
            "backend.toy.embed_done",
            None,
            {"n_sequences": len(sequences)},
            "info",
        )
        return np.stack(vectors)


#: Module-level instance discovered via ``protea.backends`` entry_points.
plugin = ToyBackend()
The corresponding pyproject.toml entry-point stanza:
[tool.poetry.plugins."protea.backends"]
toy = "protea_backends_toy:plugin"
Verify the output directly:
import numpy as np

from protea_backends_toy import plugin


def noop(*args: object, **kwargs: object) -> None:
    """No-op emit callback: discard every structured event."""


model, tok = plugin.load_model("toy_v0", "cpu", emit=noop)
emb = plugin.embed_batch(model, tok, ["MKTII", "ACDEF"], emit=noop)
print(emb.shape)  # Expected output: (2, 64)
print(emb.dtype)  # Expected output: float16
assert emb.shape == (2, 64)
assert emb.dtype == np.float16

# Calling again with the same input produces identical results:
emb2 = plugin.embed_batch(model, tok, ["MKTII", "ACDEF"], emit=noop)
assert np.array_equal(emb, emb2)