Training a neural network with DALI and Paxml

This simple example shows how to train a neural network implemented in Paxml with DALI data preprocessing. It builds on the MNIST training example from the Paxml codebase.

We use the MNIST dataset in Caffe2 format from DALI_extra as the data source.

import os

training_data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/")
validation_data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/")
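The two paths above assume that the DALI_EXTRA_PATH environment variable points to a local checkout of DALI_extra. If the variable is unset, os.environ raises a bare KeyError. The following is a minimal sketch of a more explicit check. The helper name `dali_extra_subdir` is hypothetical and not part of DALI.

```python
import os


def dali_extra_subdir(relative_path: str) -> str:
    """Hypothetical helper: join a path inside DALI_extra, failing with a
    clear message if the DALI_EXTRA_PATH environment variable is not set."""
    root = os.environ.get("DALI_EXTRA_PATH")
    if root is None:
        raise RuntimeError(
            "Set DALI_EXTRA_PATH to the root of a local DALI_extra checkout"
        )
    return os.path.join(root, relative_path)


# When the variable is set, this usage is equivalent to the two assignments above:
# training_data_path = dali_extra_subdir("db/MNIST/training/")
# validation_data_path = dali_extra_subdir("db/MNIST/testing/")
```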

The first step is to create the iterator definition function that will later be used to create instances of DALI iterators. This function defines all the preprocessing steps. In this simple example, we use fn.readers.caffe2 to read data in Caffe2 format, fn.decoders.image to decode the images, fn.crop_mirror_normalize to normalize them, and fn.reshape to adjust the shape of the output tensors.
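The normalization used in this pipeline reduces to simple arithmetic. fn.crop_mirror_normalize computes (input - mean) / std, and mean defaults to 0. With std=[255.0], 8-bit pixel values in [0, 255] therefore become floats in [0, 1]. Passing output_layout="HWC" keeps the height, width, channel order instead of the operator's default CHW. The NumPy sketch below only approximates this behavior on the CPU without cropping or mirroring, and the function names are hypothetical. It also assumes that each label comes out of the reader as a one-element tensor, which fn.reshape(labels, shape=[]) turns into a scalar.

```python
import numpy as np


def normalize_like_cmn(image_hwc: np.ndarray, mean: float = 0.0, std: float = 255.0) -> np.ndarray:
    """Hypothetical NumPy approximation of crop_mirror_normalize with no crop
    or mirror: out = (in - mean) / std as float32, HWC layout unchanged."""
    return ((image_hwc.astype(np.float32) - mean) / std).astype(np.float32)


def reshape_label_to_scalar(label: np.ndarray) -> np.ndarray:
    """Hypothetical equivalent of fn.reshape(labels, shape=[]) for one sample."""
    return label.reshape(())


# A tiny 2x2 single-channel "image" in HWC layout
image = np.array([[0, 128], [255, 64]], dtype=np.uint8).reshape(2, 2, 1)
print(normalize_like_cmn(image)[..., 0])

label = np.array([7], dtype=np.int32)
print(reshape_label_to_scalar(label).shape)  # → ()
```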

This example focuses on how to use a DALI pipeline with Paxml. For more information on writing DALI iterators, see the DALI and JAX getting started tutorial and the pipeline documentation. To learn more about Paxml and how to write neural networks with it, see the Paxml GitHub page.

import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.jax import data_iterator

@data_iterator(output_map=["inputs", "labels"], reader_name="mnist_caffe2_reader", auto_reset=True)
def mnist_iterator(data_path, random_shuffle):
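    # Read encoded images and labels; the reader name matches reader_name in the decorator
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=random_shuffle, name="mnist_caffe2_reader"
    )
    # Decode with device="mixed" (CPU input, GPU output) into single-channel images
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    # Convert to float, scale pixel values by 1/255, and keep the HWC layout
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0], output_layout="HWC")

    # Move the labels to the GPU and turn each label into a scalar
    labels = labels.gpu()
    labels = fn.reshape(labels, shape=[])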

    return images, labels
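Each batch produced by the iterator is a dictionary whose keys are the names in output_map, so the model receives "inputs" and "labels". When you debug the model side on a machine without DALI or a GPU, a batch with the same structure can be faked. The sketch below makes several assumptions: 28x28 grayscale MNIST images in HWC layout, scalar labels after fn.reshape, and an int32 label dtype. The label dtype in particular is a guess and is not taken from DALI documentation. The function `fake_mnist_batch` is hypothetical.

```python
import numpy as np


def fake_mnist_batch(batch_size: int, seed: int = 0) -> dict:
    """Hypothetical stand-in for one batch from mnist_iterator.

    Keys follow output_map=["inputs", "labels"]. Shapes assume 28x28 grayscale
    images in HWC layout and one scalar label per sample; int32 is an assumed dtype.
    """
    rng = np.random.default_rng(seed)
    return {
        "inputs": rng.random((batch_size, 28, 28, 1), dtype=np.float32),
        "labels": rng.integers(0, 10, size=(batch_size,), dtype=np.int32),
    }


batch = fake_mnist_batch(batch_size=4)
print({name: array.shape for name, array in batch.items()})
```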

This example uses the Pax data input class BaseInput defined in Praxis. We create a simple subclass that uses the DALI iterator for JAX as its data source.

from praxis import base_input

class MnistDaliInput(base_input.BaseInput):
    def __post_init__(self):
        super().__post_init__()

        data_path = training_data_path if self.is_training else validation_data_path

        # mnist_iterator is decorated with @data_iterator, so calling it returns
        # a ready-to-use DALI iterator for JAX; no extra wrapping is needed
        self._iterator = mnist_iterator(
            data_path=data_path, random_shuffle=self.is_training, batch_size=self.batch_size
        )

    def get_next(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # With auto_reset=True the iterator has already rewound itself,
            # so this call returns the first batch of the next epoch
            return next(self._iterator)

    def reset(self) -> None:
        # Restart iteration from the beginning of the dataset
        self._iterator.reset()
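The get_next method catches StopIteration and calls next a second time. This works because an iterator created with auto_reset=True rewinds itself at the end of an epoch. Unlike a typical Python iterator, it yields data again after raising StopIteration, so training can continue past the end of an epoch. The pure-Python sketch below mimics that behavior with a hypothetical `ToyAutoResetIterator` class that stands in for the DALI iterator.

```python
class ToyAutoResetIterator:
    """Hypothetical stand-in for an iterator created with auto_reset=True:
    it raises StopIteration once at the end of an epoch, then starts over."""

    def __init__(self, batches):
        self._batches = list(batches)
        self._position = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._position == len(self._batches):
            self._position = 0  # rewind so the next call starts a new epoch
            raise StopIteration
        batch = self._batches[self._position]
        self._position += 1
        return batch


def get_next(iterator):
    """Same retry logic as MnistDaliInput.get_next."""
    try:
        return next(iterator)
    except StopIteration:
        # The epoch ended and the iterator has already rewound itself
        return next(iterator)


it = ToyAutoResetIterator(["batch0", "batch1"])
print([get_next(it) for _ in range(5)])
# → ['batch0', 'batch1', 'batch0', 'batch1', 'batch0']
```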

MnistDaliInput can be used as the data source of a Pax experiment. The code sample below shows how to connect the two classes by defining the datasets method of the experiment class.

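from praxis import base_input, pax_fiddle

# Method of the Pax experiment class; each list entry configures one input
def datasets(self) -> list[pax_fiddle.Config[base_input.BaseInput]]:
    return [
        pax_fiddle.Config(MnistDaliInput, batch_size=self.BATCH_SIZE, is_training=True)
    ]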
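The datasets method does not construct MnistDaliInput directly. It returns pax_fiddle.Config objects that record the class and its keyword arguments, and Paxml builds the actual input objects later. The sketch below uses functools.partial from the standard library as a rough analogy for this deferred construction. It is not the Fiddle API. `FakeInput` and the batch size of 32 are purely illustrative.

```python
import functools
from dataclasses import dataclass


@dataclass
class FakeInput:
    """Hypothetical stand-in for MnistDaliInput with the same two fields."""
    batch_size: int
    is_training: bool


# Record the class and its arguments now; nothing is constructed yet
deferred = functools.partial(FakeInput, batch_size=32, is_training=True)

# Later, the object is built from the recorded recipe
train_input = deferred()
print(train_input)
```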

For the full working example, see docs/examples/frameworks/jax/pax_examples. The code in this folder can be tested by running the command below.

!python -m paxml.main --job_log_dir=/tmp/dali_pax_logs --exp pax_examples.dali_pax_example.MnistExperiment 2>/dev/null

The command writes TensorBoard-compatible logs to /tmp/dali_pax_logs. We use a helper function that reads the training accuracy from these logs and prints it to the terminal.

from tensorflow.core.util import event_pb2
from tensorflow.python.lib.io import tf_record
from tensorflow.python.framework import tensor_util

def print_logs(path):
    """Print the training accuracy stored in an event file created by the Paxml example."""

    def summary_iterator():
        # Each record in the event file is a serialized Event protocol buffer
        for r in tf_record.tf_record_iterator(path):
            yield event_pb2.Event.FromString(r)

    for summary in summary_iterator():
        for value in summary.summary.value:
            if value.tag == "Metrics/accuracy":
                # The metric is stored as a tensor proto; convert it to a NumPy value
                t = tensor_util.MakeNdarray(value.tensor)
                print(f"Iteration: {summary.step}, accuracy: {t}")

With this helper function, we can print the training accuracy from within Python code.

for file in os.listdir("/tmp/dali_pax_logs/summaries/train/"):
    print_logs(os.path.join("/tmp/dali_pax_logs/summaries/train/", file))
Iteration: 100, accuracy: 0.3935546875
Iteration: 200, accuracy: 0.5634765625
Iteration: 300, accuracy: 0.728515625
Iteration: 400, accuracy: 0.8369140625
Iteration: 500, accuracy: 0.87109375
Iteration: 600, accuracy: 0.87890625
Iteration: 700, accuracy: 0.884765625
Iteration: 800, accuracy: 0.8994140625
Iteration: 900, accuracy: 0.8994140625
Iteration: 1000, accuracy: 0.90625
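The printed lines are easy to post-process, for example to find the best accuracy reached during training. The stdlib-only sketch below parses lines in the format printed by print_logs. The helper name `parse_accuracy_lines` is hypothetical, and the two sample lines are taken from the output above.

```python
import re

LINE_PATTERN = re.compile(r"Iteration: (\d+), accuracy: ([0-9.]+)")


def parse_accuracy_lines(lines):
    """Hypothetical helper: turn 'Iteration: N, accuracy: X' lines into
    (step, accuracy) pairs, skipping lines that do not match."""
    points = []
    for line in lines:
        match = LINE_PATTERN.search(line)
        if match:
            points.append((int(match.group(1)), float(match.group(2))))
    return points


log_text = """Iteration: 100, accuracy: 0.3935546875
Iteration: 1000, accuracy: 0.90625"""
points = parse_accuracy_lines(log_text.splitlines())
best_step, best_accuracy = max(points, key=lambda point: point[1])
print(best_step, best_accuracy)  # → 1000 0.90625
```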