Using the TensorFlow DALI plugin with various readers

Overview

This example shows how different DALI readers can be used to feed data into TensorFlow, demonstrating how flexible DALI is.

The following readers are used in this example:

  • readers.mxnet

  • readers.caffe

  • readers.file

  • readers.tfrecord

For details on how to use them, please see the other examples.

Let us start by defining some global constants.

The `DALI_EXTRA_PATH` environment variable should point to the location where the data from the DALI_extra repository has been downloaded. Please make sure that the proper release tag is checked out.

[1]:
import os.path

# Root of the DALI_extra checkout; indexing os.environ raises KeyError when
# the variable is not set.
test_data_root = os.environ["DALI_EXTRA_PATH"]

# Per-reader dataset locations.
# MXNet RecordIO -- note the trailing "/" is part of the value.
db_folder = os.path.join(test_data_root, "db", "recordio/")
# Caffe LMDB
lmdb_folder = os.path.join(test_data_root, "db", "lmdb")
# Directory of plain JPEG files, relative to the current working directory.
image_dir = "../../data/images"

# TFRecord data file, the index file generated for it, and the tool that
# generates that index.
tfrecord = os.path.join(test_data_root, "db", "tfrecord", "train")
tfrecord_idx = "idx_files/train.idx"
tfrecord2idx_script = "tfrecord2idx"

# Run configuration:
#   N          - number of GPUs (one pipeline per GPU)
#   BATCH_SIZE - batch size per GPU
#   ITERATIONS - number of batches fetched per pipeline test
#   IMAGE_SIZE - image edge length constant
N, BATCH_SIZE, ITERATIONS, IMAGE_SIZE = 8, 128, 32, 3

Create the index file by calling the `tfrecord2idx` script.

[2]:
from subprocess import call
import os.path

# Make sure the directory that will hold the index file exists. Deriving it
# from tfrecord_idx keeps the two in sync; exist_ok makes this a no-op when
# the directory is already there.
os.makedirs(os.path.dirname(tfrecord_idx) or ".", exist_ok=True)

# Generate the index only once; an existing file is reused as-is.
if not os.path.isfile(tfrecord_idx):
    # `call` reports failure only through its return value, so check it here
    # instead of letting a failed indexing run pass silently.
    ret = call([tfrecord2idx_script, tfrecord, tfrecord_idx])
    if ret != 0:
        raise RuntimeError(
            f"{tfrecord2idx_script} exited with code {ret} "
            f"while indexing {tfrecord}"
        )

Let us define:

  • the common part of the processing graph, used by all pipelines

[3]:
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types


def common_pipeline(jpegs, labels):
    """Shared augmentation graph applied after every reader.

    Decodes the encoded images on the "mixed" device, resizes them so that
    the shorter side is a random value in [256, 480] (linear interpolation),
    then takes a 227x227 crop at a random position, subtracts 128 from each
    channel (std is 1.0) and emits the result as FLOAT.

    Args:
        jpegs: encoded image data produced by a DALI reader.
        labels: labels produced by the same reader; passed through unchanged.

    Returns:
        Tuple ``(images, labels)``.
    """
    # Operators are created in the same order as before (decoder, resize-size
    # sampler, resize, x/y anchor samplers, crop) so the graph is identical.
    decoded = fn.decoders.image(jpegs, device="mixed")

    shorter_side = fn.random.uniform(range=(256, 480))
    resized = fn.resize(
        decoded,
        resize_shorter=shorter_side,
        interp_type=types.INTERP_LINEAR,
    )

    anchor_x = fn.random.uniform(range=(0.0, 1.0))
    anchor_y = fn.random.uniform(range=(0.0, 1.0))
    cropped = fn.crop_mirror_normalize(
        resized,
        crop_pos_x=anchor_x,
        crop_pos_y=anchor_y,
        dtype=types.FLOAT,
        crop=(227, 227),
        mean=[128.0, 128.0, 128.0],
        std=[1.0, 1.0, 1.0],
    )
    return cropped, labels
  • MXNet reader pipeline

[4]:
@pipeline_def
def mxnet_reader_pipeline(num_gpus):
    """Read encoded images and labels from the MXNet RecordIO set in ``db_folder``.

    The data is split into ``num_gpus`` shards; this pipeline reads the shard
    whose index equals its own device id, with random shuffling.

    Args:
        num_gpus: total number of shards (one per GPU pipeline).

    Returns:
        The ``(images, labels)`` pair produced by ``common_pipeline``.
    """
    # os.path.join works whether or not db_folder ends with a separator,
    # unlike plain string concatenation.
    jpegs, labels = fn.readers.mxnet(
        path=[os.path.join(db_folder, "train.rec")],
        index_path=[os.path.join(db_folder, "train.idx")],
        random_shuffle=True,
        shard_id=Pipeline.current().device_id,
        num_shards=num_gpus,
        name="Reader",
    )

    return common_pipeline(jpegs, labels)
  • Caffe reader pipeline

[5]:
@pipeline_def
def caffe_reader_pipeline(num_gpus):
    """Read encoded images and labels from the Caffe LMDB at ``lmdb_folder``.

    The database is split into ``num_gpus`` shards; this pipeline reads the
    shard matching its device id, shuffled, and feeds it to
    ``common_pipeline``.
    """
    this_shard = Pipeline.current().device_id
    encoded, label_ids = fn.readers.caffe(
        path=lmdb_folder,
        name="Reader",
        num_shards=num_gpus,
        shard_id=this_shard,
        random_shuffle=True,
    )
    return common_pipeline(encoded, label_ids)
  • File reader pipeline

[6]:
@pipeline_def
def file_reader_pipeline(num_gpus):
    """Read plain image files (and their labels) from ``image_dir``.

    The file list is split into ``num_gpus`` shards; this pipeline reads the
    shard matching its device id, shuffled, and feeds it to
    ``common_pipeline``.
    """
    reader_kwargs = dict(
        file_root=image_dir,
        random_shuffle=True,
        shard_id=Pipeline.current().device_id,
        num_shards=num_gpus,
        name="Reader",
    )
    encoded, label_ids = fn.readers.file(**reader_kwargs)
    return common_pipeline(encoded, label_ids)
  • TFRecord reader pipeline

[7]:
import nvidia.dali.tfrecord as tfrec


@pipeline_def
def tfrecord_reader_pipeline(num_gpus):
    """Read images and labels from the TFRecord file ``tfrecord``.

    Uses the index at ``tfrecord_idx``. Two features are parsed:

    * the encoded image as a scalar string (default ``""``);
    * the label as a one-element int64 (default ``-1``).

    Records are split into ``num_gpus`` shards. This pipeline reads the shard
    matching its device id, shuffled, and passes the result through
    ``common_pipeline``.
    """
    image_key, label_key = "image/encoded", "image/class/label"
    feature_spec = {
        image_key: tfrec.FixedLenFeature((), tfrec.string, ""),
        label_key: tfrec.FixedLenFeature([1], tfrec.int64, -1),
    }

    records = fn.readers.tfrecord(
        path=tfrecord,
        index_path=tfrecord_idx,
        features=feature_spec,
        random_shuffle=True,
        shard_id=Pipeline.current().device_id,
        num_shards=num_gpus,
        name="Reader",
    )
    return common_pipeline(records[image_key], records[label_key])

Now let us create a function that builds the pipelines on demand:

[8]:
import tensorflow as tf
import nvidia.dali.plugin.tf as dali_tf

from tensorflow.compat.v1 import GPUOptions
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import Session
from tensorflow.compat.v1 import placeholder

# The rest of the notebook uses the TF1-style API (compat.v1 Session and
# Session.run), which needs ops to be built into a graph rather than executed
# eagerly, so eager mode is switched off before any ops are created.
tf.compat.v1.disable_eager_execution()


def get_batch_test_dali(batch_size, pipe_type):
    """Build one DALI pipeline per GPU and expose each through a DALIIterator op.

    Args:
        batch_size: per-GPU batch size. It is used both to build the
            pipelines and as the leading dimension of the declared image
            shape.
        pipe_type: a ``[pipeline_fn, label_dtype, label_range]`` entry. Only
            the pipeline function and the TF dtype of its labels are used
            here.

    Returns:
        ``[images, labels]``: two lists holding one tensor per GPU (``N`` in
        total), ready to be fetched with ``Session.run``.
    """
    pipe_fn, label_type, _ = pipe_type
    pipes = [
        pipe_fn(batch_size=batch_size, num_threads=2, device_id=device_id, num_gpus=N)
        for device_id in range(N)
    ]

    daliop = dali_tf.DALIIterator()
    images = []
    labels = []
    for d, pipe in enumerate(pipes):
        with tf.device(f"/gpu:{d}"):
            # common_pipeline ends with crop_mirror_normalize(dtype=types.FLOAT)
            # and a 227x227 crop, so the image output is declared as float32
            # (previously tf.int32, which did not match the pipeline output).
            image, label = daliop(
                pipeline=pipe,
                shapes=[(batch_size, 3, 227, 227), ()],
                dtypes=[tf.float32, label_type],
                device_id=d,
            )
            images.append(image)
            labels.append(label)

    return [images, labels]

Finally, let us test that all pipelines are built correctly and run within a TF session.

[9]:
import numpy as np

# Each entry is [pipeline function, TF dtype of its labels, inclusive
# (min, max) range the labels are expected to fall in].
pipe_types = [
    [mxnet_reader_pipeline, tf.float32, (0, 999)],
    [caffe_reader_pipeline, tf.int32, (0, 999)],
    [file_reader_pipeline, tf.int32, (0, 1)],
    [tfrecord_reader_pipeline, tf.int64, (1, 1000)],
]

# The session configuration does not depend on the pipeline, so build it once.
gpu_options = GPUOptions(per_process_gpu_memory_fraction=0.8)
config = ConfigProto(gpu_options=gpu_options)

for pipe_type in pipe_types:
    pipe_fn, _, (label_min, label_max) = pipe_type
    print("RUN: " + pipe_fn.__name__)

    # Build every test in a fresh graph so the DALI ops of earlier pipelines
    # do not keep accumulating in the global default graph.
    with tf.Graph().as_default():
        test_batch = get_batch_test_dali(BATCH_SIZE, pipe_type)

        with Session(config=config) as sess:
            for _ in range(ITERATIONS):
                _imgs, labels = sess.run(test_batch)
                # Check label correctness on every GPU's output.
                for label in labels:
                    # Labels must hold integral values (mxnet's are float32).
                    assert np.equal(np.mod(label, 1), 0).all()
                    # Labels must lie in the expected inclusive range.
                    assert (label >= label_min).all()
                    assert (label <= label_max).all()

    print("OK : " + pipe_fn.__name__)
RUN: mxnet_reader_pipeline
OK : mxnet_reader_pipeline
RUN: caffe_reader_pipeline
OK : caffe_reader_pipeline
RUN: file_reader_pipeline
OK : file_reader_pipeline
RUN: tfrecord_reader_pipeline
OK : tfrecord_reader_pipeline
[ ]: