Data Loading: Webdataset#
Overview#
This example shows how to use data stored in the WebDataset format with DALI.
Using readers.webdataset operator#
Data stored in WebDataset format can be read with the readers.webdataset operator. The operator takes the following arguments:
- paths: the path (or list of paths) to the tar archives containing the WebDataset data.
- index_paths: the path (or list of paths) to the respective index files, which describe the layout of the tar files. They are created with wds2idx, a utility included with DALI; for usage details, refer to wds2idx -h. If not provided, the index is inferred automatically from the tar file, although this takes considerable time for big datasets.
- ext: the extension set (or list of extension sets), with the extensions in a set separated by ‘;’, that specifies the outputs of the operator and which sample components are fed into each output.
- missing_component_behavior: the behavior of the reader when it encounters a sample that cannot provide any component to a certain output. There are 3 options:
  - empty (default): an empty tensor is returned for that output
  - skip: samples with missing components are skipped
  - error: an error is raised
- dtypes: the data types of the outputs of the operator. If the size of the output component is not divisible by the size of the type, an error is raised. If not provided, data is returned as UINT8.

In addition to these arguments, the operator accepts arguments common to all readers, which configure the seed of the random number generator, shuffling, sharding, and handling of incomplete batches at the end of an epoch.
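To make the roles of ext and missing_component_behavior more concrete, the following is a minimal pure-Python sketch of the documented behavior. It is not how DALI implements the reader. The helpers make_tar and read_samples are hypothetical names introduced here for illustration, and the sketch assumes the usual WebDataset convention that files sharing a name up to the first dot belong to the same sample.

```python
import io
import tarfile


def make_tar(files):
    """Build an in-memory tar archive from a {name: bytes} mapping."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, data in files.items():
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()


def read_samples(tar_bytes, ext, missing_component_behavior="empty"):
    """Group tar members into samples and select one component per output.

    `ext` is a list of extension sets such as ["jpg;png", "cls"]; each set
    defines one output. Illustrative only, not the DALI implementation.
    """
    samples = {}
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
        for member in tar:
            if not member.isfile():
                continue
            # Sample key = directory + file name up to the first dot.
            directory, _, filename = member.name.rpartition("/")
            stem, _, extension = filename.partition(".")
            key = directory + "/" + stem
            samples.setdefault(key, {})[extension] = tar.extractfile(member).read()

    outputs = []
    for key, components in samples.items():
        row = []
        for ext_set in ext:
            # First extension of the set that is present in the sample, if any.
            match = next(
                (components[e] for e in ext_set.split(";") if e in components),
                None,
            )
            row.append(match)
        if None in row:
            if missing_component_behavior == "error":
                raise RuntimeError(f"sample {key!r} is missing a component")
            if missing_component_behavior == "skip":
                continue
            # "empty": substitute an empty payload for the missing output.
            row = [b"" if c is None else c for c in row]
        outputs.append(tuple(row))
    return outputs


archive = make_tar({"0001.jpg": b"img-a", "0001.cls": b"3", "0002.png": b"img-b"})
for image, label in read_samples(archive, ["jpg;png", "cls"]):
    print(len(image), label)
```

Here the second sample has no cls component, so with the default setting its label output is empty; with skip the sample is dropped, and with error an exception is raised.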
Creating an index#
The index file(s), the paths to which are passed in the index_paths argument, can be generated with the wds2idx tool, bundled with DALI.
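As a rough sketch, the tool could be invoked from Python as shown below. The helper names build_wds2idx_command and create_index are hypothetical, the .idx suffix is only a naming choice made here, and the sketch assumes that wds2idx accepts the archive path followed by the index path as positional arguments; check wds2idx -h for the authoritative usage.

```python
import os.path
import shutil
import subprocess


def build_wds2idx_command(archive, index=None):
    """Return the argument list for indexing one tar archive.

    Assumption: wds2idx takes `<archive> <index>` as positional arguments.
    """
    if index is None:
        # Hypothetical naming choice: put the index next to the archive.
        index = os.path.splitext(archive)[0] + ".idx"
    return ["wds2idx", archive, index]


def create_index(archive, index=None):
    """Run wds2idx if it is on PATH; return the index path, or None otherwise."""
    command = build_wds2idx_command(archive, index)
    if shutil.which(command[0]) is None:
        return None  # the DALI tools are not installed in this environment
    subprocess.run(command, check=True)
    return command[2]


print(build_wds2idx_command("train.tar"))
```

The returned index path can then be passed to the reader through the index_paths argument, alongside the matching archive in paths.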
Note: The DALI_EXTRA_PATH environment variable should point to the location where data from DALI extra repository is downloaded.
Important: Ensure that you check out the correct release tag that corresponds to the installed version of DALI.
[1]:
from subprocess import call
import os.path
test_data_root = os.environ["DALI_EXTRA_PATH"]
wds_data = os.path.join(test_data_root, "db", "webdataset", "train.tar")
batch_size = 16
Defining and running the pipeline#
- Define a simple pipeline that takes the images stored in WebDataset format and decodes them.
- In this example, the images are processed by resizing, cropping, normalizing, and converting from the HWC to the CHW layout.
[2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np
@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0)
def wds_pipeline(wds_data=wds_data):
    img_raw, cls = fn.readers.webdataset(
        paths=wds_data, ext=["jpg", "cls"], missing_component_behavior="error"
    )
    img = fn.decoders.image(img_raw, device="mixed", output_type=types.RGB)
    resized = fn.resize(img, device="gpu", resize_shorter=256.0)
    output = fn.crop_mirror_normalize(
        resized,
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
    )
    return output, cls
- Build and run the pipeline: 
[3]:
pipe = wds_pipeline()
pipe.build()
pipe_out = pipe.run()
- To visualize the results, use the matplotlib library, which expects images in HWC format, but the output of the pipeline is in CHW.
- For visualization purposes, transpose the images back to the HWC layout.
[4]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline
def show_images(image_batch, labels):
    columns = 4
    # Ceiling division, so a partially filled last row is still shown.
    rows = (batch_size + columns - 1) // columns
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(batch_size):
        plt.subplot(gs[j])
        plt.axis("off")
        # Labels arrive as arrays of ASCII codes; turn them into a string.
        label_codes = labels.at(j)
        plt.title(
            "".join([chr(item) for item in label_codes]), fontdict={"fontsize": 25}
        )
        img_chw = image_batch.at(j)
        # matplotlib expects HWC images with float values in [0, 1].
        img_hwc = np.transpose(img_chw, (1, 2, 0)) / 255.0
        plt.imshow(img_hwc)
[5]:
images, labels = pipe_out
show_images(images.as_cpu(), labels)
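The two conversions used for plotting can also be tried in isolation with NumPy alone. The sketch below uses synthetic arrays in place of pipeline outputs, and the helper names chw_to_hwc and decode_label are hypothetical; it assumes, as in this example, that labels are returned as UINT8 arrays of ASCII codes.

```python
import numpy as np


def chw_to_hwc(img_chw):
    """Move the channel axis last: (C, H, W) -> (H, W, C)."""
    return np.transpose(img_chw, (1, 2, 0))


def decode_label(label_codes):
    """Turn a UINT8 array of ASCII codes into a Python string."""
    return np.asarray(label_codes, dtype=np.uint8).tobytes().decode("ascii")


# Synthetic stand-ins for one image and one label from the pipeline.
fake_image = np.zeros((3, 224, 224), dtype=np.float32)
fake_label = np.array([49, 50], dtype=np.uint8)  # ASCII codes for "1" and "2"

print(chw_to_hwc(fake_image).shape)
print(decode_label(fake_label))
```

If the labels are numeric class ids, the decoded string can be passed to int() to obtain the class index.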
