Data Loading: Webdataset

Overview

This example shows how to use data stored in the WebDataset format with DALI.

Using the readers.webdataset operator

Data stored in the WebDataset format can be read with the readers.webdataset operator. The operator takes the following arguments:

  • paths: the path (or list of paths) to the tar archives that contain the WebDataset data

  • index_paths: the path (or list of paths) to the index files that describe the contents of the corresponding tar archives. Index files are created with wds2idx, a utility included with DALI; for usage details, refer to wds2idx -h. If not provided, the index is inferred from the tar file automatically, which can take considerable time for big datasets.

  • ext: the extension set (or list of extension sets) that specifies the outputs of the operator. Each set is a string of extensions separated by ‘;’ and determines which sample components are fed into the corresponding output.

  • missing_component_behavior: the behavior of the reader when it encounters a sample that can’t provide any component to a certain output. There are 3 options:

    • empty (default): returns an empty tensor for that output

    • skip: the samples with missing components are skipped

    • error: an error is raised

  • dtypes: the data types of the outputs of the operator. If the size of the output component is not divisible by the size of the type, an error is raised. If not provided, data is returned as UINT8.

In addition to these arguments, the operator accepts arguments common to all readers, which configure the seed of the random number generator, shuffling, sharding, and the handling of incomplete batches at the end of an epoch.
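The interplay of ext and missing_component_behavior can be illustrated without DALI. The following is only a sketch in plain Python that mimics the semantics described above; it is not DALI's implementation. The helper names (make_archive, group_samples, read_outputs) and the sample file names are hypothetical, and the sketch assumes that files in the archive sharing a base name form one sample, with everything after the first dot treated as the component extension.

```python
import io
import tarfile


def make_archive(files):
    """Build an in-memory tar archive from a {member name: bytes} mapping."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as tar:
        for name, payload in files.items():
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    buffer.seek(0)
    return buffer


def group_samples(archive):
    """Group members that share a base name into one sample: {extension: bytes}."""
    samples = {}
    with tarfile.open(fileobj=archive, mode="r") as tar:
        for member in tar.getmembers():
            base, _, extension = member.name.partition(".")
            samples.setdefault(base, {})[extension] = tar.extractfile(member).read()
    return samples


def read_outputs(samples, ext, missing_component_behavior="empty"):
    """Return one tuple of outputs per sample, following the documented options."""
    results = []
    for base, components in samples.items():
        outputs = []
        for ext_set in ext:
            # An extension set such as "jpg;png" accepts either component.
            candidates = (components[e] for e in ext_set.split(";") if e in components)
            outputs.append(next(candidates, None))
        if None in outputs:
            if missing_component_behavior == "error":
                raise RuntimeError(f"sample {base!r} lacks a component for an output")
            if missing_component_behavior == "skip":
                continue
            # "empty": substitute an empty payload for the missing component.
            outputs = [b"" if out is None else out for out in outputs]
        results.append(tuple(outputs))
    return results


archive = make_archive({
    "0001.jpg": b"<jpeg bytes>",
    "0001.cls": b"3",
    "0002.png": b"<png bytes>",
    "0002.cls": b"7",
    "0003.cls": b"1",  # this sample has no image component
})
samples = group_samples(archive)
kept = read_outputs(samples, ext=["jpg;png", "cls"], missing_component_behavior="skip")
print(len(kept))
```

With "skip", the third sample is dropped; with "empty", it would yield an empty first output; with "error", the call would raise.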

Creating an index

The index file(s), the paths to which are passed in the index_paths argument, can be generated with wds2idx, a tool bundled with DALI.
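As a rough sketch, the tool could be invoked from Python as shown here. The argument order (archive first, then the index file) is an assumption that should be checked against wds2idx -h, and the helper names and the train.tar / train.idx file names are hypothetical.

```python
import shutil
import subprocess


def build_wds2idx_command(archive_path, index_path):
    """Return the argument list for wds2idx (assumed order: archive, then index)."""
    return ["wds2idx", archive_path, index_path]


def create_index(archive_path, index_path):
    """Run wds2idx if it is on PATH; return True when the command succeeded."""
    if shutil.which("wds2idx") is None:
        # Tool not found; the reader can still infer the index on its own.
        return False
    subprocess.run(build_wds2idx_command(archive_path, index_path), check=True)
    return True


command = build_wds2idx_command("train.tar", "train.idx")
print(" ".join(command))
```

The resulting index file would then be passed to the reader through index_paths, next to the matching archive in paths.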

Note: The DALI_EXTRA_PATH environment variable should point to the location where data from DALI extra repository is downloaded.

Important: Ensure that you check out the correct release tag that corresponds to the installed version of DALI.

[1]:
import os

# Location of the data downloaded from the DALI extra repository
test_data_root = os.environ['DALI_EXTRA_PATH']
# Tar archive in WebDataset format used throughout this example
wds_data = os.path.join(test_data_root, 'db', 'webdataset', 'train.tar')
batch_size = 16

Defining and running the pipeline

  1. Define a simple pipeline that takes the images stored in Webdataset format and decodes them.

    In this example, the images are decoded, resized, cropped, and normalized, and their layout is converted from HWC to CHW.

[2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np

@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def wds_pipeline(wds_data=wds_data):
    img_raw, cls = fn.readers.webdataset(
        paths=wds_data,
        ext=["jpg", "cls"],
        missing_component_behavior="error")
    img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB)
    resized = fn.resize(img, device="gpu", resize_shorter=256.)
    output = fn.crop_mirror_normalize(
        resized,
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[0., 0., 0.],
        std=[1., 1., 1.])
    return output, cls

  2. Build and run the pipeline:

[3]:
pipe = wds_pipeline()
pipe.build()
pipe_out = pipe.run()

  3. To visualize the results, use the matplotlib library. It expects images in HWC format, whereas the output of the pipeline is in CHW.

  4. For visualization purposes, transpose the images back to the HWC layout.
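The layout change can be illustrated on a tiny image without any libraries. This is a minimal sketch using nested lists; chw_to_hwc is a hypothetical helper that performs the same axis permutation as np.transpose(img_chw, (1, 2, 0)).

```python
def chw_to_hwc(image):
    """Reorder a nested-list image from [channel][row][col] to [row][col][channel]."""
    channels = len(image)
    height = len(image[0])
    width = len(image[0][0])
    return [
        [[image[c][y][x] for c in range(channels)] for x in range(width)]
        for y in range(height)
    ]


# A 3-channel image with 1 row and 2 columns, stored channel-first (CHW).
chw = [
    [[10, 11]],  # R plane
    [[20, 21]],  # G plane
    [[30, 31]],  # B plane
]
hwc = chw_to_hwc(chw)
print(hwc)  # [[[10, 20, 30], [11, 21, 31]]]
```

After the conversion, each pixel holds its three channel values together, which is the layout matplotlib expects.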

[4]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

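def show_images(image_batch, labels):
    columns = 4
    # Round up so that a partially filled last row is still shown
    rows = (batch_size + columns - 1) // columns
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(batch_size):
        plt.subplot(gs[j])
        plt.axis("off")
        # The label is returned as raw bytes; convert them to characters
        label_bytes = labels.at(j)
        plt.title("".join(chr(item) for item in label_bytes), fontdict={"fontsize": 25})
        # matplotlib expects the HWC layout and float values in the [0, 1] range
        img_chw = image_batch.at(j)
        img_hwc = np.transpose(img_chw, (1, 2, 0)) / 255.0
        plt.imshow(img_hwc)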
[5]:
images, labels = pipe_out
show_images(images.as_cpu(), labels)
[Output: a grid of the decoded images, each titled with its label]
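The titles in show_images are built from the cls component, which the reader returns as UINT8 values because no dtypes were specified. A minimal sketch of that conversion, assuming the label files contain ASCII text (decode_label is a hypothetical helper, and the byte values are made up for illustration):

```python
def decode_label(raw):
    """Turn a sequence of byte values (0-255) into text, assuming ASCII content."""
    return bytes(raw).decode("ascii")


label = decode_label([52, 50])  # byte values of the characters "4" and "2"
print(label)  # 42
```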