WebDataset integration using External Source#

In this notebook is an example of how one may combine the webdataset with a DALI pipeline, using an external source operator

Introduction#

Data Representation#

Web Dataset is a dataset representation that heavily optimizes networked accessed storage performance. At its simplest, it stores the whole dataset in one tarball file, where each sample is represented by one or more entries with the same name but different extensions. This approach improves drive access caching in RAM, since the data is represented sequentially.

Sharding#

In order to improve distributed storage access and network data transfer, the webdataset employs a strategy called sharding. In this approach, the tarball holding the data is split into several smaller ones, called shards, which allows for fetching from several storage drives at once, and reduces the packet size that has to be transferred via the network.

Sample Implementation#

First, let’s import the necessary modules and define the locations of the datasets that will be needed later.

DALI_EXTRA_PATH environment variable should point to the place where the data from DALI extra repository is downloaded. Please make sure that the proper release tag is checked out.

The tar_dataset_paths holds the paths to the shards that will be loaded while showing and testing the webdataset loader.

batch_size is the common batch size for both loaders

[1]:
import nvidia.dali.fn as fn
import nvidia.dali as dali
import nvidia.dali.types as types
import webdataset as wds
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import random
import tempfile
import tarfile

root_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db", "webdataset",
                         "MNIST")
tar_dataset_paths = [os.path.join(root_path, data_file)
                        for data_file in ["devel-0.tar", "devel-1.tar",
                                          "devel-2.tar"]]
batch_size = 16

Next, let’s extract the files that will later be used for comparing the file reader to our custom one.

The folder_dataset_files holds the paths to the files

[2]:
folder_dataset_root_dir = tempfile.TemporaryDirectory()
folder_dataset_dirs = [tempfile.TemporaryDirectory(dir=folder_dataset_root_dir.name)
                     for dataset in tar_dataset_paths]
folder_dataset_tars = [tarfile.open(dataset) for dataset in tar_dataset_paths]

for folder_dataset_tar, folder_dataset_subdir in zip(folder_dataset_tars,
                                                     folder_dataset_dirs):
    folder_dataset_tar.extractall(path=folder_dataset_subdir.name)

folder_dataset_files = [
    filepath
    for folder_dataset_subdir in folder_dataset_dirs
    for filepath in sorted(
        glob.glob(os.path.join(folder_dataset_subdir.name, "*.jpg")),
        key=lambda s: int(s[s.rfind('/') + 1:s.rfind(".jpg")])
    )
]

The function below is used to later randomize the output from the dataset. The samples are first stored in a prefetch buffer, and then they’re randomly yielded in a generator and replaced by a new sample.

[3]:
def buffered_shuffle(generator_factory, initial_fill, seed):
    def buffered_shuffle_generator():
        nonlocal generator_factory, initial_fill, seed
        generator = generator_factory()
        # The buffer size must be positive
        assert(initial_fill > 0)

        # The buffer that will hold the randomized samples
        buffer = []

        # The random context for preventing side effects
        random_context = random.Random(seed)

        try:
            while len(buffer) < initial_fill: # Fills in the random buffer
                buffer.append(next(generator))

            # Selects a random sample from the buffer and then fills it back
            # in with a new one
            while True:
                idx = random_context.randint(0, initial_fill-1)

                yield buffer[idx]
                buffer[idx] = None
                buffer[idx] = next(generator)

        # When the generator runs out of the samples flushes our the buffer
        except StopIteration:
            random_context.shuffle(buffer)

            while buffer:
                # Prevents the one sample that was not filled from being duplicated
                if buffer[-1] != None:
                    yield buffer[-1]
                buffer.pop()
    return buffered_shuffle_generator

The next function is used for padding the last batch with the last sample, in order to make it the same size as all the other ones.

[4]:
def last_batch_padding(generator_factory, batch_size):
    def last_batch_padding_generator():
        nonlocal generator_factory, batch_size
        generator = generator_factory()
        in_batch_idx = 0
        last_item = None
        try:
            # Keeps track of the last sample and the sample number mod batch_size
            while True:
                if in_batch_idx >= batch_size:
                    in_batch_idx -= batch_size
                last_item = next(generator)
                in_batch_idx += 1
                yield last_item
        # Repeats the last sample the necessary number of times
        except StopIteration:
            while in_batch_idx < batch_size:
                yield last_item
                in_batch_idx += 1
    return last_batch_padding_generator

The final function collects all the data into batches in order to be able to have a variable length batch for the last sample

[5]:
def collect_batches(generator_factory, batch_size):
    def collect_batches_generator():
        nonlocal generator_factory, batch_size
        generator = generator_factory()
        batch = []
        try:
            while True:
                batch.append(next(generator))
                if len(batch) == batch_size:
                    # Converts tuples of samples into tuples of batches of samples
                    yield tuple(map(list, zip(*batch)))
                    batch = []
        except StopIteration:
            if batch is not []:
                # Converts tuples of samples into tuples of batches of samples
                yield tuple(map(list, zip(*batch)))
    return collect_batches_generator

And finally the data loader, that configures and returns an ExternalSource node.

Keyword Arguments:#

paths: describes the paths to the file/files containing the webdataset, and can be formatted as any data accepted by the WebDataset

extensions: describes the extensions containing the data to be output through the dataset. By default, all image format extensions supported by WebDataset are used

random_shuffle: describes whether to shuffle the data read by the WebDataset

initial_fill: if random_shuffle is True describes the buffer size of the data shuffler. Set to 256 by default.

seed: describes the seed for shuffling the data. Useful for getting consistent results. Set to 0 by default

pad_last_batch: describes whether to pad the last batch with the final sample to match the regular batch size

read_ahead: describes whether to prefetch the data into the memory

cycle: can be either "raise", in which case the data loader will throw StopIteration once it reaches the end of the data, in which case the user has to invoke pipeline.reset() before the next epoch, or "quiet"(Default), in which case it will keep looping over the data over and over

[6]:
def read_webdataset(
    paths,
    extensions=None,
    random_shuffle=False,
    initial_fill=256,
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet"
):
    # Parsing the input data
    assert(cycle in {"quiet", "raise", "no"})
    if extensions == None:
        # All supported image formats
        extensions = ';'.join(["jpg", "jpeg", "img", "image", "pbm", "pgm", "png"])
    if type(extensions) == str:
        extensions = (extensions,)

    # For later information for batch collection and padding
    max_batch_size = dali.pipeline.Pipeline.current().max_batch_size

    def webdataset_generator():
        bytes_np_mapper = (lambda data: np.frombuffer(data, dtype=np.uint8),
                           )*len(extensions)
        dataset_instance = (wds.WebDataset(paths)
                            .to_tuple(*extensions)
                            .map_tuple(*bytes_np_mapper))

        for sample in dataset_instance:
            yield sample

    dataset = webdataset_generator

    # Adding the buffered shuffling
    if random_shuffle:
        dataset = buffered_shuffle(dataset, initial_fill, seed)

    # Adding the batch padding
    if pad_last_batch:
        dataset = last_batch_padding(dataset, max_batch_size)

    # Collecting the data into batches (possibly undefull)
    # Handled by a custom function only when `silent_cycle` is False
    if cycle != "quiet":
        dataset = collect_batches(dataset, max_batch_size)

    # Prefetching the data
    if read_ahead:
        dataset=list(dataset())

    return fn.external_source(
        source=dataset,
        num_outputs=len(extensions),
        # If `cycle` is "quiet" then batching is handled by the external source
        batch=(cycle != "quiet"),
        cycle=cycle,
        dtype=types.UINT8
    )

We also define a sample data augmentation function which decodes an image, applies a jitter to it and resizes it to 244x244.

[7]:
def decode_augment(img, seed=0):
    img = fn.decoders.image(img)
    img = fn.jitter(img.gpu(), seed=seed)
    img = fn.resize(img, size=(224, 224))
    return img

Usage presentation#

Below we define the sample webdataset pipeline with our external_source-based loader, that just chains the previously defined reader and augmentation function together.

[8]:
@dali.pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def webdataset_pipeline(
    paths,
    random_shuffle=False,
    initial_fill=256,
    seed=0,
    pad_last_batch=False,
    read_ahead=False,
    cycle="quiet"
):
    img, label = read_webdataset(paths=paths,
                                 extensions=("jpg", "cls"),
                                 random_shuffle=random_shuffle,
                                 initial_fill=initial_fill,
                                 seed=seed,
                                 pad_last_batch=pad_last_batch,
                                 read_ahead=read_ahead,
                                 cycle=cycle)
    return decode_augment(img, seed=seed), label

The pipeline can then be build with the desired arguments passed through to the data loader

[9]:
pipeline = webdataset_pipeline(
    tar_dataset_paths,   # Paths for the sharded dataset
    random_shuffle=True, # Random buffered shuffling on
    pad_last_batch=False, # Last batch is filled to the full size
    read_ahead=False,
    cycle="raise")     # All the data is preloaded into the memory
pipeline.build()

And executed, printing the example image using matplotlib

[10]:
# If StopIteration is raised, use pipeline.reset() to start a new epoch
img, c = pipeline.run()
img = img.as_cpu()
# Conversion from an array of bytes back to bytes and then to int
print(int(bytes(c.as_array()[0])))
plt.imshow(img.as_array()[0])
plt.show()
1
../../_images/examples_use_cases_webdataset-externalsource_22_1.png

Checking consistency#

Here we will check if the custom pipeline for the webdataset matches an equivalent pipeline reading the files from an untarred directory, with fn.readers.file reader.

First let’s define the pipeline to compare against. This is the same pipeline as the one for the webdataset, but instead uses the fn.readers.file reader.

[11]:
@dali.pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def file_pipeline(files):
    img, _ = fn.readers.file(files=files)
    return decode_augment(img)

Then let’s instantiate and build both pipelines

[12]:
webdataset_pipeline_instance = webdataset_pipeline(tar_dataset_paths)
webdataset_pipeline_instance.build()
file_pipeline_instance = file_pipeline(folder_dataset_files)
file_pipeline_instance.build()

And run the comparison loop.

[13]:
# The number of batches to sample between the two pipelines
num_batches = 10

for _ in range(num_batches):
    webdataset_pipeline_threw_exception = False
    file_pipeline_threw_exception = False

    # Try running the webdataset pipeline and check if it has run out of
    # the samples
    try:
        web_img, _ = webdataset_pipeline_instance.run()
    except StopIteration:
        webdataset_pipeline_threw_exception = True

    # Try running the file pipeline and check if it has run out of the samples
    try:
        (file_img,) = file_pipeline_instance.run()
    except StopIteration:
        file_pipeline_threw_exception = True

    # In case of different number of batches
    assert(webdataset_pipeline_threw_exception==file_pipeline_threw_exception)

    web_img = web_img.as_cpu().as_array()
    file_img = file_img.as_cpu().as_array()

    # In case the pipelines give different outputs
    np.testing.assert_equal(web_img, file_img)
else:
    print("No difference found!")
No difference found!