{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using DALI in PyTorch Lightning\n", "\n", "### Overview\n", "\n", "This example shows how to use DALI in PyTorch Lightning.\n", "\n", "Let us grab [a toy example](https://pytorch-lightning.readthedocs.io/en/1.6.1/starter/core_guide.html) showcasing a classification network and see how DALI can accelerate it.\n", "\n", "The DALI_EXTRA_PATH environment variable should point to a [DALI extra](https://github.com/NVIDIA/DALI_extra) copy. Please make sure that the proper release tag, the one associated with your DALI version, is checked out." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.nn import functional as F\n", "from torch import nn\n", "from pytorch_lightning import Trainer, LightningModule\n", "from torch.optim import Adam\n", "from torchvision.datasets import MNIST\n", "from torchvision import datasets, transforms\n", "from torch.utils.data import DataLoader\n", "\n", "import os\n", "\n", "BATCH_SIZE = 64\n", "\n", "# workaround for https://github.com/pytorch/vision/issues/1938 - error 403 when downloading mnist dataset\n", "import urllib\n", "\n", "opener = urllib.request.build_opener()\n", "opener.addheaders = [(\"User-agent\", \"Mozilla/5.0\")]\n", "urllib.request.install_opener(opener)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will start by implement a training class that uses the native data loader" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class LitMNIST(LightningModule):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " # mnist images are (1, 28, 28) (channels, width, height)\n", " self.layer_1 = torch.nn.Linear(28 * 28, 128)\n", " self.layer_2 = torch.nn.Linear(128, 256)\n", " self.layer_3 = torch.nn.Linear(256, 10)\n", "\n", " def forward(self, x):\n", " batch_size, channels, width, height = x.size()\n", "\n", " # (b, 1, 28, 28) -> (b, 1*28*28)\n", " x = x.view(batch_size, -1)\n", " x = self.layer_1(x)\n", " x = F.relu(x)\n", " x = self.layer_2(x)\n", " x = F.relu(x)\n", " x = self.layer_3(x)\n", "\n", " x = F.log_softmax(x, dim=1)\n", " return x\n", "\n", " def process_batch(self, batch):\n", " return batch\n", "\n", " def training_step(self, batch, batch_idx):\n", " x, y = self.process_batch(batch)\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " return loss\n", "\n", " def cross_entropy_loss(self, logits, labels):\n", " return F.nll_loss(logits, labels)\n", "\n", " def configure_optimizers(self):\n", " return Adam(self.parameters(), lr=1e-3)\n", "\n", " def prepare_data(self):\n", " # download data only\n", " self.mnist_train = MNIST(\n", " os.getcwd(), train=True, download=True, transform=transforms.ToTensor()\n", " )\n", "\n", " def setup(self, stage=None):\n", " # transforms for images\n", " transform = transforms.Compose(\n", " [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", " )\n", " self.mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.mnist_train, batch_size=64, num_workers=8, pin_memory=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And see how it works" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU 
"IPU available: False, using: 0 IPUs\n" ] } ], "source": [ "model = LitMNIST()\n", "trainer = Trainer(max_epochs=5, devices=1, accelerator=\"gpu\")\n", "# DDP works only in non-interactive mode; to test it, uncomment the line below and run this as a script\n", "# trainer = Trainer(devices=8, accelerator=\"gpu\", strategy=\"ddp\", max_epochs=5)\n", "## The MNIST dataset is not always available for download due to network issues;\n", "## to run this part of the example, uncomment the line below\n", "# trainer.fit(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step is to define a DALI pipeline that will be used for loading and pre-processing the data." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import nvidia.dali as dali\n", "from nvidia.dali import pipeline_def\n", "import nvidia.dali.fn as fn\n", "import nvidia.dali.types as types\n", "from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy\n", "\n", "# Path to MNIST dataset\n", "data_path = os.path.join(os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\")\n", "\n", "\n", "@pipeline_def\n", "def GetMnistPipeline(device, shard_id=0, num_shards=1):\n", "    jpegs, labels = fn.readers.caffe2(\n", "        path=data_path, shard_id=shard_id, num_shards=num_shards, random_shuffle=True, name=\"Reader\"\n", "    )\n", "    images = fn.decoders.image(\n", "        jpegs, device=\"mixed\" if device == \"gpu\" else \"cpu\", output_type=types.GRAY\n", "    )\n", "    images = fn.crop_mirror_normalize(\n", "        images, dtype=types.FLOAT, std=[0.3081 * 255], mean=[0.1307 * 255], output_layout=\"CHW\"\n", "    )\n", "    if device == \"gpu\":\n", "        labels = labels.gpu()\n", "    # PyTorch expects labels as INT64\n", "    labels = fn.cast(labels, dtype=types.INT64)\n", "    return images, labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are ready to modify the training class to use the DALI pipeline we have just defined. Because we want to integrate with PyTorch, we wrap our pipeline with a PyTorch DALI iterator, which can replace the native data loader with only minor changes in the code. The DALI iterator returns a list of dictionaries, where each element in the list corresponds to a pipeline instance, and the entries in the dictionary map to the outputs of the pipeline.\n", "\n", "For more information, check the documentation of `DALIGenericIterator`."
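, "\n", "\n", "For illustration, a single batch yielded by the iterator can be unpacked like this (a minimal sketch; `train_loader` stands for the `DALIClassificationIterator` instance created in the next cell):\n", "\n", "```python\n", "batch = next(iter(train_loader))  # list with one dict per pipeline instance\n", "images = batch[0][\"data\"]         # GPU float tensor of shape (BATCH_SIZE, 1, 28, 28)\n", "labels = batch[0][\"label\"]        # GPU int64 tensor of shape (BATCH_SIZE, 1)\n", "```"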
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class DALILitMNIST(LitMNIST):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def prepare_data(self):\n", " # no preparation is needed in DALI\n", " pass\n", "\n", " def setup(self, stage=None):\n", " device_id = self.local_rank\n", " shard_id = self.global_rank\n", " num_shards = self.trainer.world_size\n", " mnist_pipeline = GetMnistPipeline(\n", " batch_size=BATCH_SIZE,\n", " device=\"gpu\",\n", " device_id=device_id,\n", " shard_id=shard_id,\n", " num_shards=num_shards,\n", " num_threads=8,\n", " )\n", " self.train_loader = DALIClassificationIterator(\n", " mnist_pipeline, reader_name=\"Reader\", last_batch_policy=LastBatchPolicy.PARTIAL\n", " )\n", "\n", " def train_dataloader(self):\n", " return self.train_loader\n", "\n", " def process_batch(self, batch):\n", " x = batch[0][\"data\"]\n", " y = batch[0][\"label\"].squeeze(-1)\n", " return (x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now run the training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "-----------------------------------\n", "0 | layer_1 | Linear | 100 K \n", "1 | layer_2 | Linear | 33.0 K\n", "2 | layer_3 | Linear | 2.6 K \n", "-----------------------------------\n", "136 K Trainable params\n", "0 Non-trainable params\n", "136 K Total params\n", "0.544 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a16d571e3bcb4db7aadfa06000859610", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Even if previous Trainer finished his work it still keeps the GPU booked, force it to release the device.\n", "if \"PL_TRAINER_GPUS\" in os.environ:\n", " os.environ.pop(\"PL_TRAINER_GPUS\")\n", "model = DALILitMNIST()\n", "trainer = Trainer(max_epochs=5, devices=1, accelerator=\"gpu\", num_sanity_val_steps=0)\n", "# ddp work only in no-interactive mode, to test it unncoment and run as a script\n", "# trainer = Trainer(devices=8, accelerator=\"gpu\", strategy=\"ddp\", max_epochs=5)\n", "trainer.fit(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For even better integration, we can provide a custom DALI iterator wrapper so that no extra processing is required inside `LitMNIST.process_batch`. Also, PyTorch can learn the size of the dataset this way." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class BetterDALILitMNIST(LitMNIST):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def prepare_data(self):\n", " # no preparation is needed in DALI\n", " pass\n", "\n", " def setup(self, stage=None):\n", " device_id = self.local_rank\n", " shard_id = self.global_rank\n", " num_shards = self.trainer.world_size\n", " mnist_pipeline = GetMnistPipeline(\n", " batch_size=BATCH_SIZE,\n", " device=\"gpu\",\n", " device_id=device_id,\n", " shard_id=shard_id,\n", " num_shards=num_shards,\n", " num_threads=8,\n", " )\n", "\n", " class LightningWrapper(DALIClassificationIterator):\n", " def __init__(self, *kargs, **kvargs):\n", " super().__init__(*kargs, **kvargs)\n", "\n", " def __next__(self):\n", " out = super().__next__()\n", " # DDP is used so only one pipeline per process\n", " # also we need to transform dict returned by DALIClassificationIterator to iterable\n", " # and squeeze the lables\n", " out = out[0]\n", " return [out[k] if k != \"label\" else torch.squeeze(out[k]) for k in self.output_map]\n", "\n", " self.train_loader = LightningWrapper(\n", " mnist_pipeline, reader_name=\"Reader\", last_batch_policy=LastBatchPolicy.PARTIAL\n", " )\n", "\n", " def train_dataloader(self):\n", " return self.train_loader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us run the training one more time" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "-----------------------------------\n", "0 | layer_1 | Linear | 100 K \n", "1 | layer_2 | Linear | 33.0 K\n", "2 | layer_3 | Linear | 2.6 K \n", "-----------------------------------\n", "136 K Trainable params\n", "0 Non-trainable params\n", "136 K Total params\n", "0.544 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d58bd060e32a44ae868a96cb67fcacc3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Even if previous Trainer finished his work it still keeps the GPU booked, force it to release the device.\n", "if \"PL_TRAINER_GPUS\" in os.environ:\n", " os.environ.pop(\"PL_TRAINER_GPUS\")\n", "model = BetterDALILitMNIST()\n", "trainer = Trainer(max_epochs=5, devices=1, accelerator=\"gpu\", num_sanity_val_steps=0)\n", "# ddp work only in no-interactive mode, to test it unncoment and run as a script\n", "# trainer = Trainer(devices=8, accelerator=\"gpu\", strategy=\"ddp\", max_epochs=5)\n", "trainer.fit(model)" ] } ], "metadata": { "interpreter": { "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }