Using the TensorFlow DALI plugin: DALI and tf.data

Overview

DALI offers integration with the tf.data API. Using this approach, you can easily connect a DALI pipeline to various TensorFlow APIs and use it as a data source for your model. This tutorial shows how to do it using the well-known MNIST dataset converted to LMDB format. You can find it in DALI_extra, the DALI test data repository.

We start by creating a DALI pipeline to read, decode and normalize MNIST images and to read the corresponding labels.

The DALI_EXTRA_PATH environment variable should point to the location where the data from the DALI_extra repository was downloaded. Please make sure that the proper release tag is checked out.

[1]:
import nvidia.dali as dali
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types

import os

# Location of the MNIST training set inside the DALI_extra checkout that the
# DALI_EXTRA_PATH environment variable points to (KeyError if it is unset).
_dali_extra_root = os.environ['DALI_EXTRA_PATH']
data_path = os.path.join(_dali_extra_root, 'db/MNIST/training/')


class MnistPipeline(Pipeline):
    """DALI pipeline that reads, decodes and normalizes MNIST samples.

    Encoded images and labels are read (shuffled) from the Caffe2-format
    database at the module-level ``data_path``. Images are decoded to
    grayscale and passed through ``CropMirrorNormalize``, which produces float
    tensors in CHW layout, divided by ``std=255``.

    Args:
        batch_size: Number of samples per batch.
        device: ``'gpu'`` or ``'cpu'``. With ``'gpu'`` the decoder uses DALI's
            ``'mixed'`` backend and the labels are moved to the GPU. The
            normalization op always runs on ``device``.
        device_id: Forwarded to the ``Pipeline`` constructor.
        num_threads: Forwarded to the ``Pipeline`` constructor.
        seed: Forwarded to the ``Pipeline`` constructor.
    """

    def __init__(self, batch_size, device, device_id=0, num_threads=4, seed=0):
        super(MnistPipeline, self).__init__(
            batch_size, num_threads, device_id, seed)
        self.device = device
        # Compare by value: `is` on strings relies on interning and only
        # works by accident for literals.
        use_gpu = device == 'gpu'
        self.reader = ops.Caffe2Reader(path=data_path, random_shuffle=True)
        self.decode = ops.ImageDecoder(
            device='mixed' if use_gpu else 'cpu',
            output_type=types.GRAY)
        self.cmn = ops.CropMirrorNormalize(
            device=device,
            output_dtype=types.FLOAT,
            image_type=types.GRAY,
            std=[255.],
            output_layout="CHW")

    def define_graph(self):
        """Return the ``(images, labels)`` outputs of the pipeline."""
        inputs, labels = self.reader(name="Reader")
        images = self.decode(inputs)
        if self.device == 'gpu':
            # Keep the labels on the same device as the image output.
            labels = labels.gpu()
        images = self.cmn(images)

        return (images, labels)

Now we define some training parameters:

[2]:
# Hyper-parameters shared by all of the training examples below.
BATCH_SIZE = 64             # samples per DALI batch / training step
IMAGE_SIZE = 28             # images are IMAGE_SIZE x IMAGE_SIZE pixels
NUM_CLASSES = 10            # number of output classes (MNIST digits)
HIDDEN_SIZE = 128           # width of the single hidden dense layer
DROPOUT = 0.2               # dropout rate applied after the hidden layer
EPOCHS = 5
ITERATIONS_PER_EPOCH = 100  # training steps counted as one epoch

The next step is to wrap an instance of MnistPipeline with a DALIDataset object from the DALI TensorFlow plugin. This class is compatible with tf.data.Dataset. Its other parameters are the shapes and types of the pipeline outputs. Here we return images and labels, which means we have two outputs: one of type tf.float32 for the images and one of type tf.int32 for the labels.

[3]:
import nvidia.dali.plugin.tf as dali_tf
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)
# The examples in this notebook build TF1-style graphs (tf.Session is used
# at the end), so eager execution is turned off.
tf.disable_eager_execution()


# Create the CPU variant of the pipeline
mnist_pipeline = MnistPipeline(BATCH_SIZE, device='cpu', device_id=0)

# Per-batch shapes and types of the outputs, in the order the pipeline
# returns them: images first, then labels. Note the trailing comma in the
# label shape: `(BATCH_SIZE)` is just an int, `(BATCH_SIZE,)` is the 1-D
# shape tuple that is meant.
shapes = [
    (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE),
    (BATCH_SIZE,)]
dtypes = [
    tf.float32,
    tf.int32]

# Wrap the pipeline in a tf.data-compatible dataset
mnist_set = dali_tf.DALIDataset(
    pipeline=mnist_pipeline,
    batch_size=BATCH_SIZE,
    shapes=shapes,
    dtypes=dtypes,
    device_id=0)

We are ready to start the training. The following sections show how to do it with the different APIs available in TensorFlow.

Keras

First, we pass mnist_set to a model created with tf.keras and use the model.fit method to train it.

[4]:
# Build the classifier: flatten each IMAGE_SIZE x IMAGE_SIZE image, apply one
# hidden ReLU layer followed by dropout, then a softmax over NUM_CLASSES.
model = tf.keras.models.Sequential(
    layers=[
        tf.keras.layers.Input(
            shape=(IMAGE_SIZE, IMAGE_SIZE), name='images'),
        tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)),
        tf.keras.layers.Dense(HIDDEN_SIZE, activation='relu'),
        tf.keras.layers.Dropout(DROPOUT),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
    ])

# The pipeline yields integer (tf.int32) class labels, hence the sparse
# variant of categorical cross-entropy.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'])

# Train directly from the DALI dataset
model.fit(mnist_set, epochs=EPOCHS, steps_per_epoch=ITERATIONS_PER_EPOCH)
Train on 100 steps
Epoch 1/5
100/100 [==============================] - 1s 6ms/step - loss: 0.8921 - accuracy: 0.7462
Epoch 2/5
100/100 [==============================] - 0s 4ms/step - loss: 0.4115 - accuracy: 0.8847
Epoch 3/5
100/100 [==============================] - 0s 3ms/step - loss: 0.3235 - accuracy: 0.9062
Epoch 4/5
100/100 [==============================] - 0s 4ms/step - loss: 0.2926 - accuracy: 0.9202
Epoch 5/5
100/100 [==============================] - 0s 4ms/step - loss: 0.2617 - accuracy: 0.9245
[4]:
<tensorflow.python.keras.callbacks.History at 0x7fc3a4955518>

As you can see, it was very easy to integrate a DALI pipeline with the tf.keras API.

The code above performed the training using the CPU. Both the DALI pipeline and the model were using the CPU.

We can easily move the whole processing to the GPU. First, we create a pipeline that uses the GPU with ID = 0. Next, we place both the DALI dataset and the model on the same GPU.

[5]:
# Pipeline variant that decodes and normalizes on GPU 0
mnist_pipeline = MnistPipeline(BATCH_SIZE, device='gpu', device_id=0)

# Put both the DALI dataset and the Keras model on that same GPU
with tf.device('/gpu:0'):
    mnist_set = dali_tf.DALIDataset(
        pipeline=mnist_pipeline,
        batch_size=BATCH_SIZE,
        shapes=shapes,
        dtypes=dtypes,
        device_id=0)

    # Same architecture as the CPU example: flatten, hidden ReLU layer with
    # dropout, softmax output.
    model = tf.keras.models.Sequential(
        layers=[
            tf.keras.layers.Input(
                shape=(IMAGE_SIZE, IMAGE_SIZE), name='images'),
            tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)),
            tf.keras.layers.Dense(HIDDEN_SIZE, activation='relu'),
            tf.keras.layers.Dropout(DROPOUT),
            tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
        ])
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'])

We move the training to the GPU as well. This allows TensorFlow to pick up the GPU instance of the DALI dataset.

[6]:
# Fit under the GPU device scope as well, so TensorFlow picks up the GPU
# instance of the DALI dataset.
with tf.device('/gpu:0'):
    model.fit(mnist_set,
              epochs=EPOCHS,
              steps_per_epoch=ITERATIONS_PER_EPOCH)
Train on 100 steps
Epoch 1/5
100/100 [==============================] - 1s 7ms/step - loss: 0.9235 - accuracy: 0.7381
Epoch 2/5
100/100 [==============================] - 1s 6ms/step - loss: 0.4115 - accuracy: 0.8856
Epoch 3/5
100/100 [==============================] - 1s 6ms/step - loss: 0.3243 - accuracy: 0.9050
Epoch 4/5
100/100 [==============================] - 0s 5ms/step - loss: 0.2932 - accuracy: 0.9166
Epoch 5/5
100/100 [==============================] - 1s 7ms/step - loss: 0.2606 - accuracy: 0.9212

It is important to note that there is no intermediate CPU buffer between DALI and TensorFlow in the execution above. DALI GPU outputs are copied straight into the TF GPU tensors used by the model.

In this particular toy example, the performance of the GPU variant is lower than that of the CPU one. The MNIST images are small, and the nvJPEG decoder used in the GPU DALI pipeline to decode them is not well suited to such a case. We use it here to show how to integrate it properly in a real-life case.

Estimators

Another popular TensorFlow API is the tf.estimator API. This section shows how to use a DALI dataset as a data source for a model based on this API.

First we create the model.

[7]:
import tempfile

# Feature columns for the Estimator: a single numeric 'images' feature of
# shape [IMAGE_SIZE, IMAGE_SIZE]. The input function must provide the image
# batch under this same key.
feature_columns = [tf.feature_column.numeric_column(
    "images", shape=[IMAGE_SIZE, IMAGE_SIZE])]

# Run config. Each run gets a fresh checkpoint directory. With a fixed path
# such as '/tmp/tensorflow-checkpoints', the Estimator restores whatever
# checkpoint an earlier run left there and continues from its global_step
# instead of training from scratch.
run_config = tf.estimator.RunConfig(
    model_dir=tempfile.mkdtemp(prefix='dali_mnist_estimator_'),
    # Place every op of the Estimator's graph on GPU 0.
    device_fn=lambda op: '/gpu:0')

# Finally create the model based on `DNNClassifier`: one hidden layer of
# HIDDEN_SIZE units with dropout, optimized with Adam.
model = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[HIDDEN_SIZE],
    n_classes=NUM_CLASSES,
    dropout=DROPOUT,
    config=run_config,
    optimizer='Adam')

In the tf.estimator API, data is passed to the model through a function that returns the dataset. We define this function to return a DALI dataset placed on the GPU.

[8]:
def train_data_fn():
    """Build the Estimator input dataset from a GPU DALI pipeline.

    The DALI pipeline and dataset are created on GPU 0. Each element is
    re-packed as ``({'images': images}, labels)``, so the image batch is
    exposed under the ``'images'`` feature key.
    """
    def to_features(images, labels):
        # Estimators expect a dict of named features plus the labels.
        return {'images': images}, labels

    with tf.device('/gpu:0'):
        dataset = dali_tf.DALIDataset(
            pipeline=MnistPipeline(BATCH_SIZE, device='gpu', device_id=0),
            batch_size=BATCH_SIZE,
            shapes=shapes,
            dtypes=dtypes,
            device_id=0)
        dataset = dataset.map(to_features)

    return dataset

With everything set up we are ready to run the training.

[9]:
# Train the Estimator on the GPU for EPOCHS * ITERATIONS_PER_EPOCH steps
model.train(
    input_fn=train_data_fn,
    steps=ITERATIONS_PER_EPOCH * EPOCHS)
[9]:
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7fc3a6ad8b70>
[10]:
# Evaluate for one epoch's worth of steps. train_data_fn reads the training
# set under data_path, so this reports accuracy on training data.
model.evaluate(
    input_fn=train_data_fn,
    steps=ITERATIONS_PER_EPOCH)
[10]:
{'accuracy': 0.87921876,
 'average_loss': 0.49361104,
 'loss': 31.591106,
 'global_step': 17500}

Custom models and training loops

Finally, the last part of this tutorial focuses on integrating a DALI dataset with custom models and training loops. The complete example below shows, from start to finish, how to use a DALI dataset with a native TensorFlow model and run training using tf.Session.

The first step is to define the model and the dataset and place both on the GPU.

[11]:
tf.reset_default_graph()

# Disable tf.data's default static optimizations and autotuning for the DALI
# dataset. NOTE(review): presumably required/recommended by the DALI plugin
# for this TF version; confirm against the DALI documentation.
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.autotune = False


with tf.device('/gpu:0'):
    mnist_set = dali_tf.DALIDataset(
        pipeline=MnistPipeline(BATCH_SIZE, device='gpu', device_id=0),
        batch_size=BATCH_SIZE,
        shapes=shapes,
        dtypes=dtypes,
        device_id=0).with_options(options)

    iterator = tf.data.make_initializable_iterator(mnist_set)
    images, labels = iterator.get_next()

    # One-hot encode the labels and pin the result to [BATCH_SIZE, NUM_CLASSES].
    labels = tf.reshape(
        tf.one_hot(labels, NUM_CLASSES),
        [BATCH_SIZE, NUM_CLASSES])

    with tf.variable_scope('mnist_net', reuse=False):
        images = tf.layers.flatten(images)
        images = tf.layers.dense(images, HIDDEN_SIZE, activation=tf.nn.relu)
        # Dropout is always active here (training=True), including whenever
        # `accuracy` below is evaluated.
        images = tf.layers.dropout(images, rate=DROPOUT, training=True)
        # Output raw, unscaled logits. softmax_cross_entropy_with_logits
        # applies softmax itself, so a softmax activation here would apply it
        # twice and distort the loss and its gradients.
        images = tf.layers.dense(images, NUM_CLASSES)

    logits_train = images
    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=logits_train, labels=labels))
    train_step = tf.train.AdamOptimizer().minimize(loss_op)

    # argmax over logits equals argmax over softmax probabilities, so no
    # explicit softmax is needed for accuracy.
    correct_pred = tf.equal(
        tf.argmax(logits_train, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

With tf.Session we can run this model and train it on the GPU.

[12]:
with tf.Session() as sess:
    # Initialize the model variables and the DALI dataset iterator.
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)

    for step in range(EPOCHS * ITERATIONS_PER_EPOCH):
        sess.run(train_step)
        if step % ITERATIONS_PER_EPOCH != 0:
            continue
        # Once per epoch, report accuracy on one more batch: every run of
        # `accuracy` pulls a new batch from the iterator.
        train_accuracy = sess.run(accuracy)
        print("Step %d, accuracy: %g" % (step, train_accuracy))

    # Average the accuracy over one epoch's worth of further batches.
    final_accuracy = sum(
        sess.run(accuracy) for _ in range(ITERATIONS_PER_EPOCH))
    final_accuracy = final_accuracy / ITERATIONS_PER_EPOCH

    print('Final accuracy: ', final_accuracy)
Step 0, accuracy: 0.109375
Step 100, accuracy: 0.828125
Step 200, accuracy: 0.828125
Step 300, accuracy: 0.96875
Step 400, accuracy: 0.890625
Final accuracy:  0.90734375