{ "cells": [ { "cell_type": "markdown", "id": "228afd48", "metadata": {}, "source": [ "# WebDataset integration using External Source\n", "This notebook shows how to combine [webdataset](https://github.com/webdataset/webdataset) with a DALI pipeline by using the external source operator." ] }, { "cell_type": "markdown", "id": "8e37d740", "metadata": {}, "source": [ "## Introduction\n", "### Data Representation\n", "WebDataset is a dataset representation optimized for network-accessed storage. At its simplest, it stores the whole dataset in a single tar file, where each sample is represented by one or more entries that share the same name but have different extensions. Because the data is stored sequentially, this layout improves the caching of drive accesses in RAM." ] }, { "cell_type": "markdown", "id": "5380a878", "metadata": {}, "source": [ "### Sharding\n", "To improve distributed storage access and network data transfer, webdataset employs a strategy called sharding. In this approach, the tar file holding the data is split into several smaller ones, called shards. This allows fetching from several storage drives at once and reduces the size of each file that has to be transferred over the network." ] }, { "cell_type": "markdown", "id": "6810b08b", "metadata": {}, "source": [ "## Sample Implementation\n", "First, let's import the necessary modules and define the locations of the datasets that will be needed later.\n", "\n", "The `DALI_EXTRA_PATH` environment variable should point to the location where the data from the [DALI extra repository](https://github.com/NVIDIA/DALI_extra) is downloaded. 
Please make sure that the proper release tag is checked out.\n", "\n", "`tar_dataset_paths` holds the paths to the shards that will be loaded while showing and testing the webdataset loader.\n", "\n", "`batch_size` is the batch size shared by both loaders." ] }, { "cell_type": "code", "execution_count": 1, "id": "c6ca5ce2", "metadata": {}, "outputs": [], "source": [ "import nvidia.dali.fn as fn\n", "import nvidia.dali as dali\n", "import nvidia.dali.types as types\n", "import webdataset as wds\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import glob\n", "import os\n", "import random\n", "import tempfile\n", "import tarfile\n", "\n", "root_path = os.path.join(os.environ[\"DALI_EXTRA_PATH\"], \"db\", \"webdataset\", \"MNIST\")\n", "tar_dataset_paths = [os.path.join(root_path, data_file)\n", "                     for data_file in [\"devel-0.tar\", \"devel-1.tar\", \"devel-2.tar\"]]\n", "batch_size = 16" ] }, { "cell_type": "markdown", "id": "c4df68f6", "metadata": {}, "source": [ "Next, let's extract the files that will later be used for comparing the file reader to our custom one.\n", "\n", "`folder_dataset_files` holds the paths to the extracted files." ] }, { "cell_type": "code", "execution_count": 2, "id": "c4667b45", "metadata": {}, "outputs": [], "source": [ "folder_dataset_root_dir = tempfile.TemporaryDirectory()\n", "folder_dataset_dirs = [tempfile.TemporaryDirectory(dir=folder_dataset_root_dir.name)\n", "                       for dataset in tar_dataset_paths]\n", "folder_dataset_tars = [tarfile.open(dataset) for dataset in tar_dataset_paths]\n", "\n", "for folder_dataset_tar, folder_dataset_subdir in zip(folder_dataset_tars, folder_dataset_dirs):\n", "    folder_dataset_tar.extractall(path=folder_dataset_subdir.name)\n", "\n", "folder_dataset_files = [\n", "    filepath\n", "    for folder_dataset_subdir in folder_dataset_dirs\n", "    for filepath in sorted(\n", "        glob.glob(os.path.join(folder_dataset_subdir.name, \"*.jpg\")),\n", "        key=lambda s: int(s[s.rfind('/') + 
1:s.rfind(\".jpg\")])\n", "    )\n", "]" ] }, { "cell_type": "markdown", "id": "8870c432", "metadata": {}, "source": [ "The function below randomizes the order of the samples coming from the dataset. The samples are first stored in a prefetch buffer; the generator then repeatedly yields a randomly chosen sample from the buffer and replaces it with a new one." ] }, { "cell_type": "code", "execution_count": 3, "id": "42b9852e", "metadata": {}, "outputs": [], "source": [ "def buffered_shuffle(generator_factory, initial_fill, seed):\n", "    def buffered_shuffle_generator():\n", "        nonlocal generator_factory, initial_fill, seed\n", "        generator = generator_factory()\n", "        # The buffer size must be positive\n", "        assert initial_fill > 0\n", "\n", "        # The buffer that will hold the randomized samples\n", "        buffer = []\n", "\n", "        # The random context for preventing side effects\n", "        random_context = random.Random(seed)\n", "\n", "        try:\n", "            while len(buffer) < initial_fill: # Fills in the random buffer\n", "                buffer.append(next(generator))\n", "\n", "            while True: # Selects a random sample from the buffer and then fills it back in with a new one\n", "                idx = random_context.randint(0, initial_fill-1)\n", "\n", "                yield buffer[idx]\n", "                buffer[idx] = None\n", "                buffer[idx] = next(generator)\n", "\n", "        except StopIteration: # When the generator runs out of samples, flushes out the buffer\n", "            random_context.shuffle(buffer)\n", "\n", "            while buffer:\n", "                if buffer[-1] is not None: # Prevents the one sample that was not filled from being duplicated\n", "                    yield buffer[-1]\n", "                buffer.pop()\n", "    return buffered_shuffle_generator" ] }, { "cell_type": "markdown", "id": "921072ab", "metadata": {}, "source": [ "The next function pads the last batch by repeating the last sample, so that it has the same size as all the other batches." 
] }, { "cell_type": "code", "execution_count": 4, "id": "62b05312", "metadata": {}, "outputs": [], "source": [ "def last_batch_padding(generator_factory, batch_size):\n", "    def last_batch_padding_generator():\n", "        nonlocal generator_factory, batch_size\n", "        generator = generator_factory()\n", "        in_batch_idx = 0\n", "        last_item = None\n", "        try:\n", "            while True: # Keeps track of the last sample and the sample number mod batch_size\n", "                if in_batch_idx >= batch_size:\n", "                    in_batch_idx -= batch_size\n", "                last_item = next(generator)\n", "                in_batch_idx += 1\n", "                yield last_item\n", "        except StopIteration: # Repeats the last sample the necessary number of times\n", "            while in_batch_idx < batch_size:\n", "                yield last_item\n", "                in_batch_idx += 1\n", "    return last_batch_padding_generator" ] }, { "cell_type": "markdown", "id": "8911700d", "metadata": {}, "source": [ "The final function collects the samples into batches, which makes it possible for the last batch to be shorter than the others." ] }, { "cell_type": "code", "execution_count": 5, "id": "4b53d6cf", "metadata": {}, "outputs": [], "source": [ "def collect_batches(generator_factory, batch_size):\n", "    def collect_batches_generator():\n", "        nonlocal generator_factory, batch_size\n", "        generator = generator_factory()\n", "        batch = []\n", "        try:\n", "            while True:\n", "                batch.append(next(generator))\n", "                if len(batch) == batch_size:\n", "                    # Converts tuples of samples into tuples of batches of samples\n", "                    yield tuple(map(list, zip(*batch)))\n", "                    batch = []\n", "        except StopIteration:\n", "            if batch: # Yields the remaining partial batch, if there is one\n", "                # Converts tuples of samples into tuples of batches of samples\n", "                yield tuple(map(list, zip(*batch)))\n", "    return collect_batches_generator" ] }, { "cell_type": "markdown", "id": "4fe0049f", "metadata": {}, "source": [ "Finally, we define the data loader, which configures and returns an 
[ExternalSource](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html) node.\n", "\n", "### Keyword Arguments:\n", "\n", "`paths`: the path or paths to the files containing the webdataset, in any form accepted by `WebDataset`\n", "\n", "`extensions`: the extensions of the entries to be output by the dataset. By default, all image format extensions supported by `WebDataset` are used\n", "\n", "`random_shuffle`: whether to shuffle the data read by the `WebDataset`\n", "\n", "`initial_fill`: the size of the shuffle buffer, used only when `random_shuffle` is True. Set to 256 by default\n", "\n", "`seed`: the seed for shuffling the data. Useful for getting consistent results. Set to 0 by default\n", "\n", "`pad_last_batch`: whether to pad the last batch with the final sample to match the regular batch size\n", "\n", "`read_ahead`: whether to load all the data into memory up front\n", "\n", "`cycle`: either `\"raise\"` or `\"quiet\"` (the default). With `\"raise\"`, the data loader raises `StopIteration` once it reaches the end of the data, and the user has to invoke `pipeline.reset()` before the next epoch. With `\"quiet\"`, it keeps looping over the data indefinitely. The value `\"no\"` is also accepted and is passed on to `external_source` unchanged" ] }, { "cell_type": "code", "execution_count": 6, "id": "2479f399", "metadata": {}, "outputs": [], "source": [ "def read_webdataset(\n", "    paths,\n", "    extensions=None,\n", "    random_shuffle=False,\n", "    initial_fill=256,\n", "    seed=0,\n", "    pad_last_batch=False,\n", "    read_ahead=False,\n", "    cycle=\"quiet\"\n", "):\n", "    # Parsing the input data\n", "    assert cycle in {\"quiet\", \"raise\", \"no\"}\n", "    if extensions is None:\n", "        extensions = ';'.join([\"jpg\", \"jpeg\", \"img\", \"image\", \"pbm\", \"pgm\", \"png\"]) # All supported image formats\n", "    if isinstance(extensions, str):\n", "        extensions = (extensions,)\n", "\n", "    # For later 
information for batch collection and padding\n", "    max_batch_size = dali.pipeline.Pipeline.current().max_batch_size\n", "\n", "    def webdataset_generator():\n", "        bytes_np_mapper = (lambda data: np.frombuffer(data, dtype=np.uint8),)*len(extensions)\n", "        dataset_instance = (wds.WebDataset(paths)\n", "                            .to_tuple(*extensions)\n", "                            .map_tuple(*bytes_np_mapper))\n", "\n", "        for sample in dataset_instance:\n", "            yield sample\n", "\n", "    dataset = webdataset_generator\n", "\n", "    # Adding the buffered shuffling\n", "    if random_shuffle:\n", "        dataset = buffered_shuffle(dataset, initial_fill, seed)\n", "\n", "    # Adding the batch padding\n", "    if pad_last_batch:\n", "        dataset = last_batch_padding(dataset, max_batch_size)\n", "\n", "    # Collecting the data into batches (possibly underfull)\n", "    # Handled by a custom function only when `cycle` is not \"quiet\"\n", "    if cycle != \"quiet\":\n", "        dataset = collect_batches(dataset, max_batch_size)\n", "\n", "    # Prefetching the data\n", "    if read_ahead:\n", "        dataset = list(dataset())\n", "\n", "    return fn.external_source(\n", "        source=dataset,\n", "        num_outputs=len(extensions),\n", "        batch=(cycle != \"quiet\"), # If `cycle` is \"quiet\" then batching is handled by the external source\n", "        cycle=cycle,\n", "        dtype=types.UINT8\n", "    )" ] }, { "cell_type": "markdown", "id": "6c08a04b", "metadata": {}, "source": [ "We also define a sample data augmentation function which decodes an image, applies a jitter to it and resizes it to 224x224." 
] }, { "cell_type": "code", "execution_count": 7, "id": "a018552f", "metadata": {}, "outputs": [], "source": [ "def decode_augment(img, seed=0):\n", "    img = fn.decoders.image(img)\n", "    img = fn.jitter(img.gpu(), seed=seed)\n", "    img = fn.resize(img, size=(224, 224))\n", "    return img" ] }, { "cell_type": "markdown", "id": "15b9bd5c", "metadata": {}, "source": [ "## Usage presentation\n", "Below we define the sample webdataset pipeline with our `external_source`-based loader, which simply chains the previously defined reader and augmentation function together." ] }, { "cell_type": "code", "execution_count": 8, "id": "b97ffab9", "metadata": {}, "outputs": [], "source": [ "@dali.pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)\n", "def webdataset_pipeline(\n", "    paths,\n", "    random_shuffle=False,\n", "    initial_fill=256,\n", "    seed=0,\n", "    pad_last_batch=False,\n", "    read_ahead=False,\n", "    cycle=\"quiet\"\n", "):\n", "    img, label = read_webdataset(paths=paths,\n", "                                 extensions=(\"jpg\", \"cls\"),\n", "                                 random_shuffle=random_shuffle,\n", "                                 initial_fill=initial_fill,\n", "                                 seed=seed,\n", "                                 pad_last_batch=pad_last_batch,\n", "                                 read_ahead=read_ahead,\n", "                                 cycle=cycle)\n", "    return decode_augment(img, seed=seed), label" ] }, { "cell_type": "markdown", "id": "db4c395c", "metadata": {}, "source": [ "The pipeline can then be built with the desired arguments passed through to the data loader." ] }, { "cell_type": "code", "execution_count": 9, "id": "7b128aae", "metadata": {}, "outputs": [], "source": [ "pipeline = webdataset_pipeline(\n", "    tar_dataset_paths,    # Paths for the sharded dataset\n", "    random_shuffle=True,  # Random buffered shuffling on\n", "    pad_last_batch=False, # Last batch is not padded to the full size\n", "    read_ahead=False,     # The data is not preloaded into the memory\n", "    cycle=\"raise\")        # StopIteration is raised at the end of the data\n", "pipeline.build()" ] }, { "cell_type": "markdown", "id": "1c1fb518", "metadata": {}, "source": [ "The pipeline can then be run, and an example image displayed using matplotlib." 
] }, { "cell_type": "code", "execution_count": 10, "id": "fc3a842d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img, c = pipeline.run() # If StopIteration is raised, use pipeline.reset() to start a new epoch\n", "img = img.as_cpu()\n", "print(int(bytes(c.as_array()[0]))) # Conversion from an array of bytes back to bytes and then to int\n", "plt.imshow(img.as_array()[0])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "ee203442", "metadata": {}, "source": [ "## Checking consistency\n", "Here we check whether the custom webdataset pipeline matches an equivalent pipeline that reads the files from an untarred directory using the `fn.readers.file` reader.\n", "\n", "First, let's define the pipeline to compare against. It is the same as the webdataset pipeline, except that it uses the `fn.readers.file` reader." ] }, { "cell_type": "code", "execution_count": 11, "id": "abd839be", "metadata": {}, "outputs": [], "source": [ "@dali.pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)\n", "def file_pipeline(files):\n", "    img, _ = fn.readers.file(files=files)\n", "    return decode_augment(img)" ] }, { "cell_type": "markdown", "id": "c8018b06", "metadata": {}, "source": [ "Then let's instantiate and build both pipelines." ] }, { "cell_type": "code", "execution_count": 12, "id": "804bce07", "metadata": {}, "outputs": [], "source": [ "webdataset_pipeline_instance = webdataset_pipeline(tar_dataset_paths)\n", "webdataset_pipeline_instance.build()\n", "file_pipeline_instance = file_pipeline(folder_dataset_files)\n", "file_pipeline_instance.build()" ] }, { "cell_type": "markdown", "id": "2b370da0", "metadata": {}, "source": [ "\n", "Finally, let's run the comparison loop." 
] }, { "cell_type": "code", "execution_count": 13, "id": "57e1a773", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No difference found!\n" ] } ], "source": [ "# The number of batches to sample between the two pipelines\n", "num_batches = 10\n", "\n", "for _ in range(num_batches):\n", "    webdataset_pipeline_threw_exception = False\n", "    file_pipeline_threw_exception = False\n", "\n", "    # Try running the webdataset pipeline and check if it has run out of the samples\n", "    try:\n", "        web_img, _ = webdataset_pipeline_instance.run()\n", "    except StopIteration:\n", "        webdataset_pipeline_threw_exception = True\n", "\n", "    # Try running the file pipeline and check if it has run out of the samples\n", "    try:\n", "        (file_img,) = file_pipeline_instance.run()\n", "    except StopIteration:\n", "        file_pipeline_threw_exception = True\n", "\n", "    # In case of different number of batches\n", "    assert webdataset_pipeline_threw_exception == file_pipeline_threw_exception\n", "\n", "    # Both pipelines ran out of samples, so there is nothing left to compare\n", "    if webdataset_pipeline_threw_exception:\n", "        break\n", "\n", "    web_img = web_img.as_cpu().as_array()\n", "    file_img = file_img.as_cpu().as_array()\n", "\n", "    # In case the pipelines give different outputs\n", "    np.testing.assert_equal(web_img, file_img)\n", "\n", "# Reached only if none of the assertions above failed\n", "print(\"No difference found!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 5 }