Training a neural network with DALI and TorchData#

This notebook trains and validates a simple classifier on MNIST, using DALI in dynamic mode with torchdata.nodes for data loading.

[1]:
import os
from pathlib import Path

import nvidia.dali.experimental.dynamic as ndd
import nvidia.dali.types as types
import torch
import torch.nn as nn
import torchdata.nodes as tn

mnist_root = Path(os.environ["DALI_EXTRA_PATH"]) / "db" / "MNIST"
data_train = mnist_root / "training"
data_test = mnist_root / "testing"

BATCH_SIZE = 64
NUM_EPOCHS = 5

Image Processing#

We decode the grayscale images on the GPU and normalize them with the standard MNIST mean and standard deviation. Because the decoded pixels are in the [0, 255] range, both statistics are scaled by 255.

[2]:
def process_images(jpegs: ndd.Batch) -> ndd.Batch:
    images = ndd.decoders.image(jpegs, device="gpu", output_type=types.GRAY)
    images = ndd.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        mean=[0.1307 * 255],
        std=[0.3081 * 255],
    )
    return images
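Scaling the statistics by 255 is algebraically the same as first rescaling pixels to [0, 1] and then normalizing with the unscaled mean and standard deviation. A quick pure-Python check (the pixel value is arbitrary):

```python
mean, std = 0.1307, 0.3081

# Normalize a raw pixel in [0, 255] with the scaled statistics,
# as crop_mirror_normalize does above...
def normalize_255(pixel: float) -> float:
    return (pixel - mean * 255) / (std * 255)

# ...which matches rescaling to [0, 1] first, then normalizing.
def normalize_01(pixel: float) -> float:
    return (pixel / 255 - mean) / std

pixel = 200.0
assert abs(normalize_255(pixel) - normalize_01(pixel)) < 1e-9
```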

Data Loader Pipeline#

The data loading pipeline composes dynamic mode nodes with torchdata.nodes:

  1. Reader reads batches from an LMDB dataset.

  2. DictMapper applies our process_images function to the "data" key.

  3. ToTorch converts DALI batches to PyTorch tensors, moving CPU data to GPU if necessary.

  4. Prefetcher overlaps data loading with training.

[3]:
def build_loader(data_path: Path, batch_size: int, random_shuffle: bool = True):
    reader_node = ndd.pytorch.nodes.Reader(
        ndd.readers.Caffe2,
        batch_size=batch_size,
        path=data_path,
        random_shuffle=random_shuffle,
    )
    mapper_node = ndd.pytorch.nodes.DictMapper(
        source=reader_node,
        map_fn=process_images,
    )
    torch_node = ndd.pytorch.nodes.ToTorch(mapper_node)
    prefetch_node = tn.Prefetcher(torch_node, prefetch_factor=2)
    return tn.Loader(prefetch_node)
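The Prefetcher in step 4 is what lets data loading overlap with training: it pulls batches from the upstream node ahead of time, so the training loop rarely waits on I/O. A minimal stdlib-only sketch of the idea, using a background thread and a bounded queue (an illustration of the concept, not torchdata's actual implementation):

```python
import queue
import threading
from typing import Iterable, Iterator


def prefetch(source: Iterable, buffer_size: int = 2) -> Iterator:
    """Yield items from `source`, filled ahead of time by a worker thread."""
    q: queue.Queue = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the stream

    def worker() -> None:
        for item in source:
            q.put(item)  # blocks when the buffer is full
        q.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not sentinel:
        yield item


# The consumer sees the same items, but the producer runs ahead of it.
assert list(prefetch(range(5))) == [0, 1, 2, 3, 4]
```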

Model Definition#

A simple fully-connected network with a single 128-unit hidden layer: it flattens each 28×28 grayscale image and classifies it into one of 10 digit classes.

[4]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor):
        return self.model(x)


model = MNISTClassifier().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
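The two Linear layers account for all of the model's parameters: 784×128 weights plus 128 biases in the first, and 128×10 weights plus 10 biases in the second, 101,770 in total:

```python
hidden = 128
layer1 = 28 * 28 * hidden + hidden  # weights + biases: 100,480
layer2 = hidden * 10 + 10           # weights + biases: 1,290
assert layer1 + layer2 == 101_770
```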

Training Loop#

Train the model for a few epochs and report the average loss and accuracy after each epoch.

[5]:
train_loader = build_loader(data_train, BATCH_SIZE, random_shuffle=True)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        # Labels arrive as (batch, 1) tensors; flatten to (batch,)
        # and cast to int64, as CrossEntropyLoss expects
        labels = labels.squeeze(-1).long()

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * labels.size(0)
        correct += (output.argmax(1) == labels).sum().item()
        total += labels.size(0)

    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS} — "
        f"Loss: {total_loss / total:.4f}, "
        f"Accuracy: {correct / total:.1%}"
    )
Epoch 1/5 — Loss: 0.2596, Accuracy: 92.5%
Epoch 2/5 — Loss: 0.1115, Accuracy: 96.6%
Epoch 3/5 — Loss: 0.0767, Accuracy: 97.7%
Epoch 4/5 — Loss: 0.0577, Accuracy: 98.2%
Epoch 5/5 — Loss: 0.0453, Accuracy: 98.6%
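Weighting each batch loss by labels.size(0) before dividing by total yields the exact per-sample mean, even when the last batch of an epoch is smaller than BATCH_SIZE. A small pure-Python illustration with made-up batch losses:

```python
# Two full batches of 64 and a final partial batch of 32 (160 samples).
batch_sizes = [64, 64, 32]
batch_means = [0.30, 0.20, 0.10]  # hypothetical per-batch mean losses

# Weighting by batch size recovers the true per-sample mean...
weighted = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)

# ...while naively averaging batch means overweights the small final batch.
naive = sum(batch_means) / len(batch_means)

assert abs(weighted - 0.22) < 1e-9
assert abs(naive - 0.20) < 1e-9
```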

Validation#

Evaluate the trained model on the test set.

[6]:
val_loader = build_loader(data_test, BATCH_SIZE, random_shuffle=False)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        # Same label reshaping as in the training loop
        labels = labels.squeeze(-1).long()
        output = model(images)
        correct += (output.argmax(1) == labels).sum().item()
        total += labels.size(0)

print(f"Validation Accuracy: {100.0 * correct / total:.1f}%")
Validation Accuracy: 97.6%