2026-01-10 · 1 min read

Architecture of torchwisdom

Design decisions behind the torchwisdom PyTorch training framework and how it simplifies the training loop

Why Another Training Framework?

When I started building production ML models, I found myself writing the same boilerplate training code over and over. PyTorch Lightning was too opinionated for my workflow, and raw PyTorch loops meant duplicating validation, checkpointing, and logging logic across projects.

torchwisdom was born as a middle ground — a thin framework that handles the boring parts while staying out of your way.

Core Architecture

The framework revolves around three concepts:

  1. Trainer — orchestrates the training loop, handles device placement, and manages callbacks
  2. Callback — hook system for extending behavior (logging, early stopping, checkpointing)
  3. Metric — composable metric computation with automatic accumulation
trainer.py
class Trainer:
    def __init__(self, model, optimizer, callbacks=None):
        self.model = model
        self.optimizer = optimizer
        self.callbacks = callbacks or []

    def fit(self, train_loader, val_loader=None, epochs=10):
        # One training pass per epoch, an optional validation pass,
        # then the epoch-end callback hooks.
        for epoch in range(epochs):
            self._run_epoch(train_loader, mode="train")
            if val_loader:
                self._run_epoch(val_loader, mode="val")
            self._on_epoch_end(epoch)
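The `_on_epoch_end` call above fans an event out to every registered callback. A rough sketch of how that hook system might work follows; the `Callback` base class, hook names, and `dispatch` helper here are illustrative assumptions, not torchwisdom's actual API:

```python
class Callback:
    """Illustrative base class: override any hook; all are no-ops by default."""
    def on_epoch_start(self, trainer, epoch): pass
    def on_epoch_end(self, trainer, epoch): pass
    def on_batch_end(self, trainer, batch_idx, loss): pass

class PrintLogger(Callback):
    """Hypothetical callback that logs at each epoch boundary."""
    def on_epoch_end(self, trainer, epoch):
        print(f"epoch {epoch} finished")

def dispatch(callbacks, hook, *args):
    # The trainer invokes each registered callback's hook in order.
    for cb in callbacks:
        getattr(cb, hook)(*args)
```

Because every callback sees the same event stream, logging, checkpointing, and early stopping can coexist without knowing about each other.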

Design Decisions

No Magic

Every operation is explicit. Unlike frameworks that auto-detect your model's forward signature, torchwisdom requires you to define a training_step method. This makes debugging straightforward.
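As an illustration of that convention, a model with an explicit training_step might look like the sketch below; the exact signature torchwisdom expects isn't documented here, so the model class and batch format are my assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MnistClassifier(nn.Module):
    """Hypothetical model showing an explicit training_step convention."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch):
        # Explicit: you unpack the batch and return the loss yourself,
        # so there is no signature inspection to debug around.
        x, y = batch
        return F.cross_entropy(self(x), y)
```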

Callbacks Over Inheritance

Instead of subclassing a base trainer, you compose behavior through callbacks:

usage.py
trainer = Trainer(
    model=my_model,
    optimizer=optimizer,
    callbacks=[
        EarlyStopping(patience=5),
        ModelCheckpoint(path="checkpoints/"),
        WandbLogger(project="my-experiment"),
    ],
)
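To see why composition works here, consider what a minimal EarlyStopping callback might look like. This sketch assumes the trainer exposes a `should_stop` flag and passes the epoch's validation loss to `on_epoch_end`; both are my inventions for illustration:

```python
class EarlyStopping:
    """Illustrative early-stopping callback. Assumes the trainer checks a
    should_stop flag and calls on_epoch_end with the latest validation loss."""
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def on_epoch_end(self, trainer, val_loss):
        if val_loss < self.best:
            # New best: remember it and reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                trainer.should_stop = True
```

The callback holds all of its own state, so the Trainer needs no early-stopping logic of its own.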

Metric System

Metrics accumulate across batches and reset per epoch automatically:

metrics.py
accuracy = Accuracy()
for batch in loader:
    # preds and targets come from your own eval step for this batch
    accuracy.update(preds, targets)
print(f"Epoch accuracy: {accuracy.compute():.4f}")
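The update/compute/reset contract behind a metric like Accuracy can be sketched in plain Python. The real class presumably operates on tensors, so this list-based version is only illustrative:

```python
class Accuracy:
    """Illustrative metric: accumulates correct/total across update() calls."""
    def __init__(self):
        self.reset()

    def reset(self):
        # Called at the start of each epoch to clear accumulated state.
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # preds/targets: sequences of predicted and true class labels
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total if self.total else 0.0
```

Accumulating raw counts rather than per-batch averages keeps the epoch-level result exact even when the last batch is smaller than the rest.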

What I Learned

Building a framework taught me more about PyTorch internals than any tutorial. Key insights:

  • Keep the API surface small — every feature you add is a feature you maintain
  • Composition beats inheritance for extensibility
  • Good defaults matter more than configurability

The framework is used in several production projects including the TabLogs OCR pipeline.