Training a neural network with DALI and TorchData#
This notebook trains and validates a simple classifier on MNIST, using DALI in dynamic mode with torchdata.nodes for data loading.
[1]:
import os
from pathlib import Path
import nvidia.dali.experimental.dynamic as ndd
import nvidia.dali.types as types
import torch
import torch.nn as nn
import torchdata.nodes as tn
# Locate the MNIST LMDB datasets inside the DALI_EXTRA test-data checkout.
mnist_root = Path(os.environ["DALI_EXTRA_PATH"]) / "db" / "MNIST"
data_train = mnist_root / "training"
data_test = mnist_root / "testing"
BATCH_SIZE = 64  # samples per batch for both training and validation
NUM_EPOCHS = 5   # full passes over the training set
Image Processing#
We decode grayscale images on the GPU and normalize them using the standard MNIST mean and standard deviation.
[2]:
def process_images(jpegs: ndd.Batch) -> ndd.Batch:
    """Decode grayscale images on the GPU and normalize them to float CHW tensors.

    Uses the standard MNIST mean/std (0.1307 / 0.3081), scaled to the
    0-255 pixel range that the decoder produces.
    """
    decoded = ndd.decoders.image(jpegs, device="gpu", output_type=types.GRAY)
    normalized = ndd.crop_mirror_normalize(
        decoded,
        dtype=types.FLOAT,
        output_layout="CHW",
        mean=[0.1307 * 255],
        std=[0.3081 * 255],
    )
    return normalized
Data Loader Pipeline#
The data loading pipeline composes dynamic mode nodes with torchdata.nodes:
Reader reads batches from an LMDB dataset.
DictMapper applies our process_images function to the "data" key.
ToTorch converts DALI batches to PyTorch tensors, moving CPU data to GPU if necessary.
Prefetcher overlaps data loading with training.
[3]:
def build_loader(data_path: Path, batch_size: int, random_shuffle: bool = True):
    """Build a torchdata Loader yielding GPU-ready (image, label) batches.

    Composes a DALI Caffe2/LMDB reader, the process_images mapper on the
    "data" key, a DALI-to-torch conversion node, and a prefetcher.
    """
    reader = ndd.pytorch.nodes.Reader(
        ndd.readers.Caffe2,
        batch_size=batch_size,
        path=data_path,
        random_shuffle=random_shuffle,
    )
    # Decode/normalize only the "data" entry; labels pass through unchanged.
    mapped = ndd.pytorch.nodes.DictMapper(source=reader, map_fn=process_images)
    tensors = ndd.pytorch.nodes.ToTorch(mapped)
    # Prefetch so data loading overlaps with the training step.
    return tn.Loader(tn.Prefetcher(tensors, prefetch_factor=2))
Model Definition#
A simple fully-connected network for 28×28 grayscale images, classifying into 10 digit classes.
[4]:
class MNISTClassifier(nn.Module):
    """Two-layer fully-connected classifier for 28x28 grayscale digits."""

    def __init__(self):
        super().__init__()
        # Same module order as before, so parameter names and init match.
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return raw class logits of shape (batch, 10)."""
        return self.model(x)
# Move the model to the GPU to match the GPU-resident batches from the loader.
model = MNISTClassifier().cuda()
# CrossEntropyLoss expects raw logits and 1-D integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
Training Loop#
Train the model for a few epochs and report loss and accuracy.
[5]:
train_loader = build_loader(data_train, BATCH_SIZE, random_shuffle=True)
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        # Labels arrive with a trailing singleton dim; CrossEntropyLoss needs
        # a 1-D int64 class-index tensor.
        labels = labels.squeeze(-1).long()
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch figure is a
        # true per-sample average (last batch may be smaller).
        total_loss += loss.item() * labels.size(0)
        correct += (output.argmax(1) == labels).sum().item()
        total += labels.size(0)
    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS} — "
        f"Loss: {total_loss / total:.4f}, "
        f"Accuracy: {correct / total:.1%}"
    )
Epoch 1/5 — Loss: 0.2596, Accuracy: 92.5%
Epoch 2/5 — Loss: 0.1115, Accuracy: 96.6%
Epoch 3/5 — Loss: 0.0767, Accuracy: 97.7%
Epoch 4/5 — Loss: 0.0577, Accuracy: 98.2%
Epoch 5/5 — Loss: 0.0453, Accuracy: 98.6%
Validation#
Evaluate the trained model on the test set.
[6]:
# Evaluate on the held-out test set; no shuffling and no gradients needed.
val_loader = build_loader(data_test, BATCH_SIZE, random_shuffle=False)
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
    for images, labels in val_loader:
        targets = labels.squeeze(-1).long()
        preds = model(images).argmax(1)
        num_correct += (preds == targets).sum().item()
        num_samples += targets.size(0)
correct = num_correct
total = num_samples
print(f"Validation Accuracy: {100.0 * correct / total:.1f}%")
Validation Accuracy: 97.6%