Training a neural network with DALI and Pax

This simple example shows how to train a neural network implemented in JAX using DALI for data preprocessing. It builds on the MNIST training example from the Pax codebase.

We will use the MNIST dataset in Caffe2 format from DALI_extra.

[2]:
import os

training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')

The first step is to create a pipeline definition function, which will later be used to create instances of DALI pipelines. It defines all of the preprocessing steps. In this simple example, we use fn.readers.caffe2 to read data in Caffe2 format, fn.decoders.image to decode the images, fn.crop_mirror_normalize to normalize them, and fn.reshape to adjust the shape of the label tensors.
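The per-sample arithmetic of the normalization and reshape steps can be sketched in NumPy. This is an illustration only: the pixel values and the label 7 are made up, and the `(1,)`-shaped label is an assumption implied by the reshape to `shape=[]`. With `mean` left at its default of 0 and `std=[255.]`, the normalization reduces to dividing by 255.

```python
import numpy as np

# Hypothetical stand-in for one decoded grayscale MNIST image: HWC layout, uint8.
image = np.array([[[0], [128]], [[255], [64]]], dtype=np.uint8)  # shape (2, 2, 1)

# crop_mirror_normalize computes (x - mean) / std; with mean left at its
# default of 0 and std=255, pixel values are scaled into [0, 1].
mean, std = 0.0, 255.0
normalized = (image.astype(np.float32) - mean) / std

# Assumption: each label arrives as a 1-element tensor of shape (1,);
# reshaping to an empty shape (DALI's shape=[]) turns it into a scalar.
label = np.array([7], dtype=np.int32)
scalar_label = label.reshape(())

print(normalized.shape, normalized.min(), normalized.max())
print(scalar_label.shape, int(scalar_label))
```

DALI applies the same operations batch-wide on the GPU; the NumPy version only makes the numbers easy to inspect.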

This example focuses on how to use a DALI pipeline with Pax. For more information on writing DALI pipelines, see the Getting started guide and the pipeline documentation.

[3]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types


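@pipeline_def(device_id=0, num_threads=4, seed=0)
def mnist_pipeline(data_path, random_shuffle):
    # Read encoded images and labels from the Caffe2 database. The reader name
    # lets the iterator query the reader for the epoch size.
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader")
    # Decode on the GPU ("mixed": CPU input, GPU output) into single-channel images.
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    # Convert to float and divide by 255 to scale pixels into [0, 1].
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="HWC")

    # Move the labels to the GPU and reshape each one to a scalar.
    labels = labels.gpu()
    labels = fn.reshape(labels, shape=[])

    return images, labels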

This example uses the Pax data input base class (base_input.BaseInput) defined in Praxis. We create a simple wrapper around DALIGenericIterator for JAX. This iterator runs the DALI pipeline and converts its outputs to JAX arrays. To learn more about how DALI interfaces with JAX, see the basic DALI and JAX tutorial.
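The `output_map` argument names the pipeline outputs by position: the first output of mnist_pipeline (images) is exposed as "inputs" and the second (labels) as "labels". A minimal pure-Python sketch of that naming step follows; `name_outputs` is a hypothetical helper, not part of DALI, and its length check is an illustrative choice.

```python
def name_outputs(outputs, output_map):
    """Pair pipeline outputs with output_map names, by position (hypothetical helper)."""
    if len(outputs) != len(output_map):
        raise ValueError(
            f"pipeline returned {len(outputs)} outputs, "
            f"but output_map has {len(output_map)} names")
    return dict(zip(output_map, outputs))


# mnist_pipeline returns (images, labels); strings stand in for the batches.
batch = name_outputs(("images_batch", "labels_batch"), ["inputs", "labels"])
print(batch["inputs"], batch["labels"])
```

Because the mapping is positional, reordering the `return` statement of the pipeline without updating `output_map` would silently swap the keys.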

[4]:
from praxis import base_input
from nvidia.dali.plugin import jax as dax


class MnistDaliInput(base_input.BaseInput):
    def __post_init__(self):
      super().__post_init__()

      # Pick the dataset split from the is_training field of the input config.
      data_path = training_data_path if self.is_training else validation_data_path

      # Shuffle only the training data.
      pipeline = mnist_pipeline(
          data_path=data_path,
          random_shuffle=self.is_training,
          batch_size=self.batch_size)
      # Expose the two pipeline outputs under the keys "inputs" and "labels".
      self._iterator = dax.DALIGenericIterator(
        pipeline,
        output_map=["inputs", "labels"],
        reader_name="mnist_caffe2_reader",
        auto_reset=True)

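    def get_next(self):
      try:
        return next(self._iterator)
      except StopIteration:
        # With auto_reset=True the iterator has already rewound itself at the
        # end of the epoch, so the next call starts a new epoch.
        return next(self._iterator)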


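    def reset(self) -> None:
      super().reset()
      # reset() rewinds the iterator in place and returns None, so its result
      # must not be assigned back to self._iterator.
      self._iterator.reset()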
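The try/except in get_next keeps training going across epoch boundaries. The pure-Python stand-in below illustrates the behavior assumed for `auto_reset=True`: at the end of an epoch the iterator rewinds and then raises StopIteration, so a single retry of `next()` returns the first batch of the next epoch. `FakeAutoResetIterator` is hypothetical and is not DALI's class.

```python
class FakeAutoResetIterator:
    """Hypothetical stand-in for an iterator created with auto_reset=True.

    It yields the batches of one epoch, then rewinds itself and raises
    StopIteration, so the next call starts a new epoch.
    """

    def __init__(self, batches):
        self._batches = list(batches)
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos == len(self._batches):
            self._pos = 0  # auto reset: rewind before signalling the end of the epoch
            raise StopIteration
        batch = self._batches[self._pos]
        self._pos += 1
        return batch


def get_next(iterator):
    """Mirrors the try/except in MnistDaliInput.get_next."""
    try:
        return next(iterator)
    except StopIteration:
        return next(iterator)


it = FakeAutoResetIterator(["b0", "b1"])
print([get_next(it) for _ in range(5)])  # ['b0', 'b1', 'b0', 'b1', 'b0']
```

From the trainer's point of view, the data source never runs dry, which matches how Pax repeatedly calls get_next during training.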

MnistDaliInput can be used as a data source in a Pax experiment. The code sample below shows how the two are connected by defining the datasets method of the experiment class.

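# Excerpt: a method of the Pax experiment class; self.BATCH_SIZE is defined on that class.
from praxis import base_input
from praxis import pax_fiddle


def datasets(self) -> list[pax_fiddle.Config[base_input.BaseInput]]:
  # Return a config rather than an instance; Pax builds MnistDaliInput from it.
  return [
      pax_fiddle.Config(
          MnistDaliInput, batch_size=self.BATCH_SIZE, is_training=True
      )
  ]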
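`pax_fiddle.Config(MnistDaliInput, ...)` does not construct the input immediately; it records the class and its arguments so that Pax can instantiate it later. The `LazyConfig` and `ToyInput` classes below are hypothetical and only sketch that deferred-construction idea in plain Python; they are not the fiddle API.

```python
from dataclasses import dataclass


class LazyConfig:
    """Hypothetical analogy for pax_fiddle.Config: store a class plus kwargs, build later."""

    def __init__(self, cls, **kwargs):
        self.cls = cls
        self.kwargs = kwargs

    def build(self):
        return self.cls(**self.kwargs)


@dataclass
class ToyInput:
    """Hypothetical stand-in for MnistDaliInput; no DALI pipeline is created."""
    batch_size: int
    is_training: bool


config = LazyConfig(ToyInput, batch_size=256, is_training=True)
config.kwargs["batch_size"] = 128  # settings can still be changed before building
toy_input = config.build()
print(toy_input)
```

Deferring construction means the DALI pipeline in `__post_init__` is only created once the framework actually builds the input.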

For the full working example, see docs/examples/frameworks/jax/pax_examples. The code in this folder can be tested by running the command below.

[5]:
!python -m paxml.main --job_log_dir=/tmp/dali_pax_logs --exp pax_examples.dali_pax_example.MnistExperiment 2>/dev/null

It writes TensorBoard-compatible logs under /tmp/dali_pax_logs, so the training can be visualized with TensorBoard.

To read these logs in the console instead, we create a helper function that prints the training accuracy recorded in them.

[6]:
import os

from tensorflow.core.util import event_pb2
from tensorflow.python.lib.io import tf_record
from tensorflow.python.framework import tensor_util

def print_logs(path):
    "Helper function to print logs from logs directory created by paxml example"
    def summary_iterator():
        for r in tf_record.tf_record_iterator(path):
            yield event_pb2.Event.FromString(r)

    for summary in summary_iterator():
        for value in summary.summary.value:
            if value.tag == 'Metrics/accuracy':
                t = tensor_util.MakeNdarray(value.tensor)
                print(f"Iteration: {summary.step}, accuracy: {t}")
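tf_record_iterator walks the TFRecord framing of the event file and hands each raw payload to `event_pb2.Event.FromString`. To show what that framing looks like, here is a minimal pure-Python reader. It is an illustrative sketch: `read_tfrecords` and `frame` are hypothetical helpers, and unlike TensorFlow's reader they do not verify the masked CRC32C checksums.

```python
import io
import struct


def read_tfrecords(stream):
    """Yield raw record payloads from a TFRecord byte stream.

    Each record is framed as (little-endian):
      uint64 length | uint32 masked CRC of length | data | uint32 masked CRC of data
    The CRC fields are skipped, not verified.
    """
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        if len(header) < 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        stream.read(4)  # masked CRC32C of the length field (not verified)
        data = stream.read(length)
        if len(data) < length:
            raise ValueError("truncated record payload")
        stream.read(4)  # masked CRC32C of the data (not verified)
        yield data


def frame(payload):
    """Build one record with zeroed CRC fields, enough for the reader above."""
    return struct.pack("<Q", len(payload)) + b"\0" * 4 + payload + b"\0" * 4


buf = io.BytesIO(frame(b"first event") + frame(b"second event"))
print(list(read_tfrecords(buf)))
```

In the real helper, each yielded payload would be an encoded Event protobuf rather than a plain byte string.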

With this helper function, we can print the training accuracy from Python.

[7]:
for file in os.listdir('/tmp/dali_pax_logs/summaries/train/'):
    print_logs(os.path.join('/tmp/dali_pax_logs/summaries/train/', file))
Iteration: 100, accuracy: 0.3935546875
Iteration: 200, accuracy: 0.5634765625
Iteration: 300, accuracy: 0.7275390625
Iteration: 400, accuracy: 0.8369140625
Iteration: 500, accuracy: 0.87109375
Iteration: 600, accuracy: 0.87890625
Iteration: 700, accuracy: 0.884765625
Iteration: 800, accuracy: 0.8994140625
Iteration: 900, accuracy: 0.8994140625
Iteration: 1000, accuracy: 0.90625
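The loop above prints each file's records in whatever order os.listdir returns the files, which is arbitrary. If a run writes several event files (for example after a restart), collecting the (step, accuracy) pairs first and sorting them by step keeps the history in order. The sketch below assumes print_logs were changed to yield such pairs instead of printing them; `merge_by_step` is a hypothetical helper, and the values are taken from the output above.

```python
def merge_by_step(*runs):
    """Merge (step, accuracy) pairs from several event files into one step-ordered list.

    If the same step appears more than once, the value seen last wins.
    """
    merged = {}
    for run in runs:
        for step, accuracy in run:
            merged[step] = accuracy
    return sorted(merged.items())


# Two hypothetical event files, listed out of order.
later_file = [(600, 0.87890625), (1000, 0.90625)]
earlier_file = [(100, 0.3935546875), (500, 0.87109375)]
for step, accuracy in merge_by_step(later_file, earlier_file):
    print(f"Iteration: {step}, accuracy: {accuracy}")
```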