Using PyTorch DALI plugin: using various readers
Overview
This example shows how different readers could be used to interact with PyTorch. It shows how flexible DALI is.
The following readers are used in this example:
readers.mxnet
readers.caffe
readers.file
readers.tfrecord
For details on how to use them please see other examples.
Let us start by defining some global constants.
DALI_EXTRA_PATH
environment variable should point to the place where data from DALI extra repository is downloaded. Please make sure that the proper release tag is checked out.
[1]:
import os.path
# Root of the DALI_EXTRA test-data checkout. Raises KeyError at import time
# if the DALI_EXTRA_PATH environment variable is not set.
test_data_root = os.environ["DALI_EXTRA_PATH"]
# MXNet RecordIO
# NOTE: the trailing slash matters — readers below build file names by
# string concatenation (db_folder + "train.rec").
db_folder = os.path.join(test_data_root, "db", "recordio/")
# Caffe LMDB
lmdb_folder = os.path.join(test_data_root, "db", "lmdb")
# image dir with plain jpeg files
image_dir = "../../data/images"
# TFRecord
tfrecord = os.path.join(test_data_root, "db", "tfrecord", "train")
# Index file generated below by the tfrecord2idx script (relative path).
tfrecord_idx = "idx_files/train.idx"
# Name/path of the index-generation executable; assumed to be on PATH.
tfrecord2idx_script = "tfrecord2idx"
N = 8  # number of GPUs
BATCH_SIZE = 128  # batch size per GPU
# NOTE(review): 3 looks like a channel count rather than a spatial size,
# and the value appears unused in this example — confirm before relying on it.
IMAGE_SIZE = 3
Create an index file by calling the tfrecord2idx script.
[2]:
from subprocess import run
import os.path

# Generate the DALI index file for the TFRecord data (required by
# fn.readers.tfrecord). makedirs(..., exist_ok=True) avoids the
# check-then-create race of exists() followed by mkdir().
os.makedirs("idx_files", exist_ok=True)
if not os.path.isfile(tfrecord_idx):
    # check=False keeps the original best-effort behavior of subprocess.call:
    # a non-zero exit status is ignored rather than raised.
    run([tfrecord2idx_script, tfrecord, tfrecord_idx], check=False)
Let us define the common part of the processing graph, used by all pipelines.
[3]:
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
def common_pipeline(jpegs, labels):
    """Shared tail of every reader pipeline.

    Decodes the encoded JPEGs on the GPU ("mixed" device), resizes so the
    shorter side is a random value in [256, 480], then takes a random
    227x227 crop, shifting the data to roughly zero mean. Labels are
    passed through untouched.
    """
    decoded = fn.decoders.image(jpegs, device="mixed")
    shorter_side = fn.random.uniform(range=(256, 480))
    resized = fn.resize(
        decoded,
        resize_shorter=shorter_side,
        interp_type=types.INTERP_LINEAR,
    )
    # Random crop window position, chosen independently per sample.
    pos_x = fn.random.uniform(range=(0.0, 1.0))
    pos_y = fn.random.uniform(range=(0.0, 1.0))
    normalized = fn.crop_mirror_normalize(
        resized,
        crop_pos_x=pos_x,
        crop_pos_y=pos_y,
        dtype=types.FLOAT,
        crop=(227, 227),
        mean=[128.0, 128.0, 128.0],
        std=[1.0, 1.0, 1.0],
    )
    return normalized, labels
MXNet reader pipeline
[4]:
@pipeline_def
def mxnet_reader_pipeline(num_gpus):
    """Pipeline reading MXNet RecordIO data, sharded across num_gpus GPUs.

    Uses this pipeline instance's device_id as the shard index, so each GPU
    reads a disjoint subset of the dataset.
    """
    jpegs, labels = fn.readers.mxnet(
        # os.path.join instead of string concatenation: produces the same
        # paths but no longer silently depends on db_folder's trailing slash.
        path=[os.path.join(db_folder, "train.rec")],
        index_path=[os.path.join(db_folder, "train.idx")],
        random_shuffle=True,
        shard_id=Pipeline.current().device_id,
        num_shards=num_gpus,
        name="Reader",
    )
    return common_pipeline(jpegs, labels)
Caffe reader pipeline
[5]:
@pipeline_def
def caffe_reader_pipeline(num_gpus):
    """Pipeline reading (image, label) pairs from a Caffe LMDB database.

    Each GPU reads the shard matching its device_id, so the dataset is
    split evenly across num_gpus workers.
    """
    shard = Pipeline.current().device_id
    jpegs, labels = fn.readers.caffe(
        path=lmdb_folder,
        name="Reader",
        random_shuffle=True,
        shard_id=shard,
        num_shards=num_gpus,
    )
    return common_pipeline(jpegs, labels)
File reader pipeline
[6]:
@pipeline_def
def file_reader_pipeline(num_gpus):
    """Pipeline reading plain JPEG files from a directory tree.

    Labels are derived by the file reader from the sub-directory layout;
    sharding follows this pipeline instance's device_id.
    """
    shard = Pipeline.current().device_id
    jpegs, labels = fn.readers.file(
        file_root=image_dir,
        name="Reader",
        random_shuffle=True,
        shard_id=shard,
        num_shards=num_gpus,
    )
    return common_pipeline(jpegs, labels)
TFRecord reader pipeline
[7]:
import nvidia.dali.tfrecord as tfrec
@pipeline_def
def tfrecord_reader_pipeline(num_gpus):
    """Pipeline reading TFRecord samples and extracting images + labels.

    The feature spec maps TFRecord feature keys to DALI feature
    descriptions; the reader returns a dict keyed by those names.
    """
    feature_spec = {
        "image/encoded": tfrec.FixedLenFeature((), tfrec.string, ""),
        "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1),
    }
    sample = fn.readers.tfrecord(
        path=tfrecord,
        index_path=tfrecord_idx,
        features=feature_spec,
        random_shuffle=True,
        shard_id=Pipeline.current().device_id,
        num_shards=num_gpus,
        name="Reader",
    )
    return common_pipeline(sample["image/encoded"], sample["image/class/label"])
Let us create pipelines and pass them to PyTorch generic iterator
[8]:
import numpy as np
from nvidia.dali.plugin.pytorch import DALIGenericIterator
# Each entry pairs a pipeline factory with the (min, max) label range
# expected from its dataset.
pipe_types = [
    [mxnet_reader_pipeline, (0, 999)],
    [caffe_reader_pipeline, (0, 999)],
    [file_reader_pipeline, (0, 1)],
    [tfrecord_reader_pipeline, (1, 1000)],
]
for pipe_name, label_range in pipe_types:
    print("RUN: " + pipe_name.__name__)
    # One pipeline per GPU; each shards the dataset by its device_id.
    pipes = [
        pipe_name(
            batch_size=BATCH_SIZE,
            num_threads=2,
            device_id=device_id,
            num_gpus=N,
        )
        for device_id in range(N)
    ]
    # reader_name lets the iterator size epochs from the "Reader" op.
    dali_iter = DALIGenericIterator(
        pipes, ["data", "label"], reader_name="Reader"
    )
    # The batch index was never used, so iterate directly.
    for data in dali_iter:
        # Testing correctness of labels: one dict per GPU pipeline.
        for d in data:
            label = d["label"]
            image = d["data"]
            # labels need to be integers
            assert np.equal(np.mod(label, 1), 0).all()
            # labels need to fall inside this reader's expected label_range
            assert (label >= label_range[0]).all()
            assert (label <= label_range[1]).all()
    print("OK : " + pipe_name.__name__)
RUN: mxnet_reader_pipeline
OK : mxnet_reader_pipeline
RUN: caffe_reader_pipeline
OK : caffe_reader_pipeline
RUN: file_reader_pipeline
OK : file_reader_pipeline
RUN: tfrecord_reader_pipeline
OK : tfrecord_reader_pipeline