{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Training a neural network with DALI and JAX\n", "\n", "This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the JAX codebase, which can be found [here](https://github.com/google/jax/blob/jax-v0.4.13/examples/mnist_classifier_fromscratch.py).\n", "\n", "We will use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2023-07-28T07:43:41.850101Z", "iopub.status.busy": "2023-07-28T07:43:41.849672Z", "iopub.status.idle": "2023-07-28T07:43:41.853520Z", "shell.execute_reply": "2023-07-28T07:43:41.852990Z" } }, "outputs": [], "source": [ "import os\n", "\n", "training_data_path = os.path.join(os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\")\n", "validation_data_path = os.path.join(os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/testing/\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The first step is to create a definition function that will later be used to create instances of DALI iterators. It defines all the preprocessing steps.\n", "\n", "In this simple example we have `fn.readers.caffe2` for reading data in Caffe2 format, `fn.decoders.image` for image decoding, `fn.crop_mirror_normalize` to normalize the images and `fn.reshape` to adjust the shape of the output tensors. We also move the labels from the CPU to the GPU memory with `labels.gpu()`. Our model expects training labels in one-hot encoding, so we use `fn.one_hot` to convert them when `random_shuffle` is set, that is, for the training iterator.\n", "\n", "This example focuses on how to use DALI to train a model defined in JAX. 
For more information on DALI and JAX integration, see [Getting started with JAX and DALI](jax-getting_started.ipynb) and the [pipeline documentation](../../../pipeline.rst)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-07-28T07:43:41.855441Z", "iopub.status.busy": "2023-07-28T07:43:41.855301Z", "iopub.status.idle": "2023-07-28T07:43:41.986406Z", "shell.execute_reply": "2023-07-28T07:43:41.985500Z" } }, "outputs": [], "source": [ "from nvidia.dali.plugin.jax import data_iterator\n", "import nvidia.dali.fn as fn\n", "import nvidia.dali.types as types\n", "\n", "\n", "batch_size = 200\n", "image_size = 28\n", "num_classes = 10\n", "\n", "\n", "@data_iterator(output_map=[\"images\", \"labels\"], reader_name=\"caffe2_reader\")\n", "def mnist_iterator(data_path, random_shuffle):\n", "    jpegs, labels = fn.readers.caffe2(\n", "        path=data_path, random_shuffle=random_shuffle, name=\"caffe2_reader\"\n", "    )\n", "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n", "    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\")\n", "    images = fn.reshape(images, shape=[image_size * image_size])\n", "\n", "    labels = labels.gpu()\n", "\n", "    # One-hot encode the labels only when shuffling, i.e. for the training iterator\n", "    if random_shuffle:\n", "        labels = fn.one_hot(labels, num_classes=num_classes)\n", "\n", "    return images, labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we use the function to create DALI iterators for training and validation." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-07-28T07:43:41.989183Z", "iopub.status.busy": "2023-07-28T07:43:41.988964Z", "iopub.status.idle": "2023-07-28T07:43:42.104446Z", "shell.execute_reply": "2023-07-28T07:43:42.103668Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating iterators\n", "\n", "\n" ] } ], "source": [ "print(\"Creating iterators\")\n", "\n", "training_iterator = mnist_iterator(\n", "    data_path=training_data_path, random_shuffle=True, batch_size=batch_size\n", ")\n", "\n", "validation_iterator = mnist_iterator(\n", "    data_path=validation_data_path, random_shuffle=False, batch_size=batch_size\n", ")\n", "\n", "print(training_iterator)\n", "print(validation_iterator)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "With the setup above, the DALI iterators are ready for training.\n", "\n", "Finally, we import training utilities implemented in JAX. `init_model` creates the model instance and initializes its parameters. In this simple example it is an MLP model with two hidden layers. `update` performs one iteration of the training. `accuracy` is a helper function that runs validation on the test set after each epoch and returns the current accuracy of the model." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-07-28T07:43:43.559575Z", "iopub.status.busy": "2023-07-28T07:43:43.559420Z", "iopub.status.idle": "2023-07-28T07:43:43.618221Z", "shell.execute_reply": "2023-07-28T07:43:43.617532Z" } }, "outputs": [], "source": [ "from model import init_model, update, accuracy" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "At this point, everything is ready to run the training." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-07-28T07:43:43.622376Z", "iopub.status.busy": "2023-07-28T07:43:43.621205Z", "iopub.status.idle": "2023-07-28T07:43:58.016073Z", "shell.execute_reply": "2023-07-28T07:43:58.015333Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting training\n", "Epoch 0 sec\n", "Test set accuracy 0.67330002784729\n", "Epoch 1 sec\n", "Test set accuracy 0.7855000495910645\n", "Epoch 2 sec\n", "Test set accuracy 0.8251000642776489\n", "Epoch 3 sec\n", "Test set accuracy 0.8469000458717346\n", "Epoch 4 sec\n", "Test set accuracy 0.8616000413894653\n" ] } ], "source": [ "print(\"Starting training\")\n", "\n", "model = init_model()\n", "num_epochs = 5\n", "\n", "for epoch in range(num_epochs):\n", "    for batch in training_iterator:\n", "        model = update(model, batch)\n", "\n", "    test_acc = accuracy(model, validation_iterator)\n", "    print(f\"Epoch {epoch} sec\")\n", "    print(f\"Test set accuracy {test_acc}\")" ] } ], "metadata": { "celltoolbar": "Raw Cell Format", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }