Using TensorFlow DALI plugin: DALI and tf.data

Overview

DALI offers integration with the tf.data API. Using this approach, you can easily connect a DALI pipeline to various TensorFlow APIs and use it as a data source for your model. This tutorial shows how to do it using the well-known MNIST dataset converted to LMDB format. You can find it in DALI_extra, the DALI test data repository.

We start by creating a DALI pipeline that reads, decodes, and normalizes MNIST images and reads the corresponding labels.

The DALI_EXTRA_PATH environment variable should point to the location where the data from the DALI_extra repository is downloaded. Please make sure that the proper release tag is checked out.
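Because the next cell reads os.environ["DALI_EXTRA_PATH"] directly, a missing variable only surfaces as a bare KeyError. A minimal sketch of an earlier, clearer check follows. resolve_mnist_path is a hypothetical helper, not part of DALI, and it cannot tell whether the right release tag is checked out.

```python
import os


def resolve_mnist_path(environ=os.environ):
    """Hypothetical helper: return the MNIST LMDB directory inside DALI_extra.

    Only checks that DALI_EXTRA_PATH is set and points to an existing directory;
    it does not verify which DALI_extra release tag is checked out.
    """
    root = environ.get("DALI_EXTRA_PATH")
    if not root:
        raise RuntimeError("Set DALI_EXTRA_PATH to your local clone of DALI_extra")
    if not os.path.isdir(root):
        raise RuntimeError(f"DALI_EXTRA_PATH={root!r} is not a directory")
    # Same relative path as used for data_path in the next cell
    return os.path.join(root, "db/MNIST/training/")
```

In the notebook you would call resolve_mnist_path() once and use its result as data_path.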

[1]:
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import os

BATCH_SIZE = 64
DROPOUT = 0.2
IMAGE_SIZE = 28
NUM_CLASSES = 10
HIDDEN_SIZE = 128
EPOCHS = 5
ITERATIONS_PER_EPOCH = 100


# Path to MNIST dataset
data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/")


@pipeline_def(device_id=0, batch_size=BATCH_SIZE)
def mnist_pipeline(device):
    jpegs, labels = fn.readers.caffe2(path=data_path, random_shuffle=True)
    images = fn.decoders.image(
        jpegs, device="mixed" if device == "gpu" else "cpu", output_type=types.GRAY
    )
    images = fn.crop_mirror_normalize(
        images, device=device, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )

    if device == "gpu":
        labels = labels.gpu()

    return images, labels
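The normalization step in the pipeline above can be pictured in NumPy. This sketch assumes that crop_mirror_normalize computes (x - mean) / std with its default mean of 0, and that output_layout="CHW" moves the single grayscale channel to the front. normalize_like_cmn is a hypothetical name; the real operator runs inside DALI, not in NumPy.

```python
import numpy as np


def normalize_like_cmn(images_hwc, mean=0.0, std=255.0):
    """NumPy sketch of the normalization configured in mnist_pipeline (assumed semantics)."""
    out = (images_hwc.astype(np.float32) - mean) / std  # [0, 255] -> [0.0, 1.0]
    return np.transpose(out, (2, 0, 1))                 # HWC -> CHW


# A fake 28x28 grayscale image in HWC layout with one channel; one pixel is white.
img = np.zeros((28, 28, 1), dtype=np.uint8)
img[0, 0, 0] = 255
out = normalize_like_cmn(img)
print(out.shape, out.dtype, out.min(), out.max())
```

With std=[255.0] every pixel value ends up in the range [0, 1], which is why no separate scaling step is needed before the model.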

The next step is to wrap an instance of the pipeline created by mnist_pipeline in a DALIDataset object from the DALI TensorFlow plugin. This class is compatible with tf.data.Dataset. The other parameters are the shapes and types of the pipeline outputs. Here we return images and labels, which means there are two outputs: one of type tf.float32 for images and one of type tf.int32 for labels.
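One detail of the shape literals is easy to miss: in Python, (BATCH_SIZE) is just the integer 64, while (BATCH_SIZE,) is a one-element tuple. The sketch below pairs each declared shape with its dtype in output order. It assumes, as we believe TensorFlow does when building a shape, that a bare integer is read as a single dimension; as_shape_tuple is a hypothetical helper and the dtype strings stand in for tf.float32 and tf.int32.

```python
BATCH_SIZE = 64
IMAGE_SIZE = 28

# Parentheses alone do not create a tuple: only the trailing comma does.
label_shape_int = (BATCH_SIZE)     # the int 64
label_shape_tuple = (BATCH_SIZE,)  # the 1-tuple (64,)


def as_shape_tuple(spec):
    """Hypothetical helper: treat a bare int as a 1-D shape, otherwise build a tuple."""
    return (spec,) if isinstance(spec, int) else tuple(spec)


shapes = ((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), label_shape_int)
dtypes = ("float32", "int32")  # string stand-ins for tf.float32 and tf.int32

# One (shape, dtype) pair per pipeline output, in the order mnist_pipeline returns them.
output_specs = [(as_shape_tuple(s), d) for s, d in zip(shapes, dtypes)]
for name, (shape, dtype) in zip(("images", "labels"), output_specs):
    print(name, shape, dtype)
```

Writing (BATCH_SIZE,) explicitly avoids relying on that int-to-shape conversion.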

[2]:
import nvidia.dali.plugin.tf as dali_tf
import tensorflow as tf
import tensorflow.compat.v1 as tf_v1
import logging

tf.get_logger().setLevel(logging.ERROR)

# Create pipeline
pipeline = mnist_pipeline(device="cpu")

# Define shapes and types of the outputs
# (BATCH_SIZE,) with a trailing comma is a one-element tuple; (BATCH_SIZE) would be a plain int
shapes = ((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), (BATCH_SIZE,))
dtypes = (tf.float32, tf.int32)

# Create dataset
with tf.device("/cpu:0"):
    mnist_set = dali_tf.DALIDataset(
        pipeline=pipeline,
        batch_size=BATCH_SIZE,
        output_shapes=shapes,
        output_dtypes=dtypes,
        device_id=0,
    )

We are ready to start the training. The following sections show how to do it with the different APIs available in TensorFlow.

Keras

First, we pass mnist_set to a model created with tf.keras and use the model.fit method to train it.
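As a quick sanity check on the size of the Keras model in the next cell, its trainable parameters can be counted by hand: Flatten and Dropout add none, and each Dense layer holds a weight matrix plus a bias vector. The total should match what model.summary() reports; dense_params is a hypothetical helper name.

```python
IMAGE_SIZE = 28
HIDDEN_SIZE = 128
NUM_CLASSES = 10


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


flat = IMAGE_SIZE * IMAGE_SIZE                   # Flatten: 28 x 28 -> 784 features, no parameters
hidden = dense_params(flat, HIDDEN_SIZE)         # 784 * 128 + 128 = 100480
output = dense_params(HIDDEN_SIZE, NUM_CLASSES)  # 128 * 10 + 10 = 1290 (Dropout adds none)
total = hidden + output                          # 101770
print(total)
```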

[3]:
# Create the model
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name="images"),
        tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)),
        tf.keras.layers.Dense(HIDDEN_SIZE, activation="relu"),
        tf.keras.layers.Dropout(DROPOUT),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Train using DALI dataset
model.fit(mnist_set, epochs=EPOCHS, steps_per_epoch=ITERATIONS_PER_EPOCH)
Epoch 1/5
100/100 [==============================] - 1s 3ms/step - loss: 1.3511 - accuracy: 0.5834
Epoch 2/5
100/100 [==============================] - 0s 3ms/step - loss: 0.4534 - accuracy: 0.8690
Epoch 3/5
100/100 [==============================] - 0s 4ms/step - loss: 0.3380 - accuracy: 0.9003
Epoch 4/5
100/100 [==============================] - 0s 3ms/step - loss: 0.2927 - accuracy: 0.9218
Epoch 5/5
100/100 [==============================] - 0s 4ms/step - loss: 0.2736 - accuracy: 0.9217
[3]:
<tensorflow.python.keras.callbacks.History at 0x7f678122cbe0>

As you can see, integrating the DALI pipeline with the tf.keras API is straightforward.
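Note that model.fit is given steps_per_epoch explicitly, so an "epoch" here means ITERATIONS_PER_EPOCH batches rather than one pass over the dataset (we assume the DALI dataset keeps producing batches indefinitely because the reader wraps around). The amount of data the training actually consumed follows directly from the constants:

```python
BATCH_SIZE = 64
EPOCHS = 5
ITERATIONS_PER_EPOCH = 100

total_steps = EPOCHS * ITERATIONS_PER_EPOCH  # 500 optimizer updates
samples_seen = total_steps * BATCH_SIZE      # 32000 images drawn from the pipeline
print(total_steps, samples_seen)
```

The same EPOCHS * ITERATIONS_PER_EPOCH product is reused later as the steps argument of the Estimator.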

The code above performed the training on the CPU: both the DALI pipeline and the model ran on the CPU.

We can easily move the whole processing to the GPU. First, we create a pipeline that uses the GPU with ID 0. Next, we place both the DALI dataset and the model on the same GPU.
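The device choices involved can be summarized in one place. This sketch mirrors the branches inside mnist_pipeline (the "mixed" decoder takes encoded data on the CPU and produces decoded images on the GPU) and pairs each DALI device with the TensorFlow device string used in this tutorial. dali_devices is a hypothetical helper, and fixing the GPU index to 0 matches device_id=0 here but is an assumption of the sketch.

```python
def dali_devices(device):
    """Hypothetical helper mirroring the device logic of mnist_pipeline.

    Returns (decoder_device, operator_device, labels_on_gpu, tf_device).
    """
    if device not in ("cpu", "gpu"):
        raise ValueError(f"unknown device {device!r}")
    decoder = "mixed" if device == "gpu" else "cpu"  # mixed: CPU input, GPU output
    tf_device = "/gpu:0" if device == "gpu" else "/cpu:0"  # assumes device_id=0
    return decoder, device, device == "gpu", tf_device


print(dali_devices("cpu"))
print(dali_devices("gpu"))
```

Keeping the DALI device and the tf.device scope consistent is what lets the dataset hand GPU outputs straight to the model.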

[4]:
# Define the dataset and the model and place both on the GPU
with tf.device("/gpu:0"):
    mnist_set = dali_tf.DALIDataset(
        pipeline=mnist_pipeline(device="gpu"),
        batch_size=BATCH_SIZE,
        output_shapes=shapes,
        output_dtypes=dtypes,
        device_id=0,
    )
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name="images"),
            tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)),
            tf.keras.layers.Dense(HIDDEN_SIZE, activation="relu"),
            tf.keras.layers.Dropout(DROPOUT),
            tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
        ]
    )
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

We move the training to the GPU as well. This allows TensorFlow to pick up the GPU instance of the DALI dataset.

[5]:
# Train on the GPU
with tf.device("/gpu:0"):
    model.fit(mnist_set, epochs=EPOCHS, steps_per_epoch=ITERATIONS_PER_EPOCH)
Epoch 1/5
100/100 [==============================] - 1s 4ms/step - loss: 1.3734 - accuracy: 0.5844
Epoch 2/5
100/100 [==============================] - 0s 4ms/step - loss: 0.4566 - accuracy: 0.8690
Epoch 3/5
100/100 [==============================] - 0s 4ms/step - loss: 0.3375 - accuracy: 0.8991
Epoch 4/5
100/100 [==============================] - 0s 4ms/step - loss: 0.3017 - accuracy: 0.9156
Epoch 5/5
100/100 [==============================] - 0s 4ms/step - loss: 0.2925 - accuracy: 0.9167

It is important to note that there is no intermediate CPU buffer between DALI and TensorFlow in the execution above. DALI GPU outputs are copied directly to the TensorFlow GPU tensors used by the model.

In this particular toy example, the performance of the GPU variant is lower than that of the CPU one. The MNIST images are small, and the nvJPEG decoder used in the GPU DALI pipeline is not well suited to decoding such small images. We use it here to show how to integrate it properly in a real-life case.

Estimators

Another popular TensorFlow API is the tf.estimator API. This section shows how to use a DALI dataset as a data source for a model based on this API.

First, we create the model.

[6]:
# Define the feature columns for Estimator
feature_columns = [tf.feature_column.numeric_column("images", shape=[IMAGE_SIZE, IMAGE_SIZE])]

# And the run config
run_config = tf.estimator.RunConfig(
    model_dir="/tmp/tensorflow-checkpoints", device_fn=lambda op: "/gpu:0"
)

# Finally create the model based on `DNNClassifier`
model = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[HIDDEN_SIZE],
    n_classes=NUM_CLASSES,
    dropout=DROPOUT,
    config=run_config,
    optimizer="Adam",
)

In the tf.estimator API, data is passed to the model through a function that returns the dataset. We define this function to return a DALI dataset placed on the GPU.
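The .map call in the input function only changes the structure of each element: a DNNClassifier expects (features, labels) where features is a dict whose keys match the feature column names, here "images". The NumPy sketch below applies the same mapping to fake batches; fake_batches is a hypothetical stand-in for the DALI dataset, not a DALI API.

```python
import numpy as np

BATCH_SIZE = 64
IMAGE_SIZE = 28
FEATURE_KEY = "images"  # must match the key given to tf.feature_column.numeric_column


def to_estimator_format(features, labels):
    """Same mapping as the lambda passed to mnist_set.map: wrap images in a dict."""
    return {FEATURE_KEY: features}, labels


def fake_batches(n):
    """Hypothetical stand-in for the DALI dataset: yields (images, labels) batches."""
    rng = np.random.default_rng(0)
    for _ in range(n):
        images = rng.random((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), dtype=np.float32)
        labels = rng.integers(0, 10, size=BATCH_SIZE, dtype=np.int32)
        yield images, labels


batches = [to_estimator_format(x, y) for x, y in fake_batches(2)]
features, labels = batches[0]
print(sorted(features), features[FEATURE_KEY].shape, labels.shape)
```

If the dict key and the feature column name disagree, the Estimator cannot find its input, so the two are kept in sync through a single constant here.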

[7]:
def train_data_fn():
    with tf.device("/gpu:0"):
        mnist_set = dali_tf.DALIDataset(
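            # Skip the runtime check that the DALI and TF devices match; in some
            # contexts, such as an Estimator input_fn, that check may be inaccurate
            fail_on_device_mismatch=False,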
            pipeline=mnist_pipeline(device="gpu"),
            batch_size=BATCH_SIZE,
            output_shapes=shapes,
            output_dtypes=dtypes,
            device_id=0,
        )
        mnist_set = mnist_set.map(lambda features, labels: ({"images": features}, labels))

    return mnist_set

With everything set up, we are ready to run the training.

[8]:
# Running the training on the GPU
model.train(input_fn=train_data_fn, steps=EPOCHS * ITERATIONS_PER_EPOCH)
[8]:
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7f677012beb0>
[9]:
def test_data_fn():
    with tf.device("/cpu:0"):
        mnist_set = dali_tf.DALIDataset(
            fail_on_device_mismatch=False,
            pipeline=mnist_pipeline(device="cpu"),
            batch_size=BATCH_SIZE,
            output_shapes=shapes,
            output_dtypes=dtypes,
            device_id=0,
        )
        mnist_set = mnist_set.map(lambda features, labels: ({"images": features}, labels))

    return mnist_set


model.evaluate(input_fn=test_data_fn, steps=ITERATIONS_PER_EPOCH)
[9]:
{'accuracy': 0.9915625,
 'average_loss': 0.030411616,
 'loss': 0.030411616,
 'global_step': 5500}
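The reported global_step of 5500 is larger than the 500 steps requested (EPOCHS * ITERATIONS_PER_EPOCH). A likely explanation is that model_dir=/tmp/tensorflow-checkpoints already held checkpoints from earlier runs: Estimator.train restores the latest checkpoint, and its steps argument counts additional steps, while max_steps is an absolute limit. The sketch below models that accounting under those assumptions; the starting value 5000 is hypothetical. Also note that test_data_fn builds the same mnist_pipeline, which reads data_path (db/MNIST/training/), so the accuracy above is measured on training data.

```python
def global_step_after_train(start_step, steps=None, max_steps=None):
    """Model of Estimator.train step accounting (assumed from the documented semantics).

    steps: train this many additional steps on top of the restored checkpoint.
    max_steps: stop once the global step reaches this absolute value.
    """
    if steps is not None and max_steps is not None:
        raise ValueError("pass either steps or max_steps, not both")
    if steps is not None:
        return start_step + steps
    if max_steps is not None:
        return max(start_step, max_steps)
    raise ValueError("this sketch only models finite training")


TRAIN_STEPS = 5 * 100  # EPOCHS * ITERATIONS_PER_EPOCH

fresh = global_step_after_train(0, steps=TRAIN_STEPS)          # empty model_dir
resumed = global_step_after_train(5000, steps=TRAIN_STEPS)     # hypothetical checkpoint at 5000
capped = global_step_after_train(5000, max_steps=TRAIN_STEPS)  # max_steps already reached
print(fresh, resumed, capped)
```

Clearing model_dir (or using a fresh directory) before training makes the reported global_step equal to the requested number of steps.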

Custom Models and Training Loops

The last part of this tutorial focuses on integrating a DALI dataset with custom models and training loops. The complete example below shows, from start to finish, how to use a DALI dataset with a native TensorFlow model and run training using tf.Session.

The first step is to define the model and the dataset and place both on the GPU.

[10]:
tf.compat.v1.disable_eager_execution()
tf_v1.reset_default_graph()

with tf.device("/gpu:0"):
    mnist_set = dali_tf.DALIDataset(
        pipeline=mnist_pipeline(device="gpu"),
        batch_size=BATCH_SIZE,
        output_shapes=shapes,
        output_dtypes=dtypes,
        device_id=0,
    )

    iterator = tf_v1.data.make_initializable_iterator(mnist_set)
    images, labels = iterator.get_next()

    labels = tf_v1.reshape(tf_v1.one_hot(labels, NUM_CLASSES), [BATCH_SIZE, NUM_CLASSES])

    with tf_v1.variable_scope("mnist_net", reuse=False):
        images = tf_v1.layers.flatten(images)
        images = tf_v1.layers.dense(images, HIDDEN_SIZE, activation=tf_v1.nn.relu)
        images = tf_v1.layers.dropout(images, rate=DROPOUT, training=True)
        # No softmax here: softmax_cross_entropy_with_logits below expects raw logits
        images = tf_v1.layers.dense(images, NUM_CLASSES)

    logits_train = images
    loss_op = tf_v1.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits_train, labels=labels)
    )
    train_step = tf_v1.train.AdamOptimizer().minimize(loss_op)

    correct_pred = tf_v1.equal(tf_v1.argmax(logits_train, 1), tf_v1.argmax(labels, 1))
    accuracy = tf_v1.reduce_mean(tf_v1.cast(correct_pred, tf_v1.float32))

With tf.Session, we can run this model and train it on the GPU.
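Two details of the loop in the next cell are worth keeping in mind: every sess.run(accuracy) call pulls a fresh batch from the iterator, and dropout is built with training=True, so the printed accuracies include dropout noise. The final accuracy is the mean of per-batch accuracies; because every batch holds exactly BATCH_SIZE samples, that equals the accuracy over all samples. The NumPy sketch below, with hypothetical correctness flags, shows that this equivalence holds only for equal batch sizes.

```python
import numpy as np

BATCH_SIZE = 64
rng = np.random.default_rng(42)

# Hypothetical per-sample "prediction was correct" flags for 3 batches of equal size.
correct = rng.random((3, BATCH_SIZE)) < 0.9
mean_of_batch_means = np.mean([batch.mean() for batch in correct])
overall = correct.mean()
print(np.isclose(mean_of_batch_means, overall))  # equal batch sizes: the two agree

# With unequal batch sizes the two quantities generally differ.
uneven = [np.array([True]), np.array([False, False, False])]
uneven_mean_of_means = np.mean([b.mean() for b in uneven])  # (1.0 + 0.0) / 2
uneven_overall = np.concatenate(uneven).mean()              # 1 correct out of 4
print(uneven_mean_of_means, uneven_overall)
```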

[11]:
with tf_v1.Session() as sess:
    sess.run(tf_v1.global_variables_initializer())
    sess.run(iterator.initializer)

    for i in range(EPOCHS * ITERATIONS_PER_EPOCH):
        sess.run(train_step)
        if i % ITERATIONS_PER_EPOCH == 0:
            train_accuracy = sess.run(accuracy)
            print("Step %d, accuracy: %g" % (i, train_accuracy))

    final_accuracy = 0
    for _ in range(ITERATIONS_PER_EPOCH):
        final_accuracy = final_accuracy + sess.run(accuracy)
    final_accuracy = final_accuracy / ITERATIONS_PER_EPOCH

    print("Final accuracy: ", final_accuracy)
Step 0, accuracy: 0.140625
Step 100, accuracy: 0.84375
Step 200, accuracy: 0.9375
Step 300, accuracy: 0.875
Step 400, accuracy: 0.90625
Final accuracy:  0.90640625
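tf.nn.softmax_cross_entropy_with_logits applies softmax to its logits argument internally, so that tensor should hold raw, unnormalized scores. Feeding it probabilities, such as the output of a layer with a softmax activation, applies softmax twice. The NumPy sketch below reimplements the loss to show the effect: the argmax (and therefore the accuracy metric) is unchanged, but a confident, correct prediction still receives a large loss, which weakens the training signal.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def softmax_xent_with_logits(logits, one_hot_labels):
    """NumPy version of softmax cross-entropy: softmax the logits, then cross-entropy."""
    return -(one_hot_labels * np.log(softmax(logits))).sum(axis=-1)


logits = np.array([[8.0, 0.0, 0.0]])  # confident, correct prediction for class 0
label = np.array([[1.0, 0.0, 0.0]])

loss_on_logits = softmax_xent_with_logits(logits, label)          # below 0.001
loss_on_probs = softmax_xent_with_logits(softmax(logits), label)  # softmax applied twice
print(loss_on_logits, loss_on_probs)
```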