Using TensorFlow DALI plugin: DALI tf.data.Dataset with multiple GPUs
Overview
This notebook is a comprehensive example of how to use the DALI tf.data.Dataset with multiple GPUs. We recommend working through the single GPU example first. It introduces the DALI dataset and shows how to use it to train a neural network with a custom model and training loop. This example extends the single GPU version.
First, we define some training parameters and create a DALI pipeline that reads MNIST converted to the LMDB format. You can find this dataset in the DALI_extra repository. The pipeline can partition the dataset into multiple shards, so that each GPU reads a different part of it.
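To illustrate what partitioning into shards means, here is a minimal pure-Python sketch that computes which range of samples each shard would read. It assumes each shard receives a contiguous, near-equal slice of the dataset. `shard_bounds` is a hypothetical helper written for this illustration, not part of the DALI API, and the dataset size of 1000 is an arbitrary example value.

```python
def shard_bounds(shard_id, num_shards, dataset_size):
    """Return the half-open [start, end) range of sample indices for one shard.

    Hypothetical helper: assumes contiguous, near-equal splits of the dataset.
    """
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in the range [0, num_shards)")
    start = dataset_size * shard_id // num_shards
    end = dataset_size * (shard_id + 1) // num_shards
    return start, end


# Two shards (one per GPU) over an example dataset of 1000 samples.
for shard_id in range(2):
    print(shard_id, shard_bounds(shard_id, num_shards=2, dataset_size=1000))
# 0 (0, 500)
# 1 (500, 1000)
```

The slices are disjoint and together cover the whole dataset, which is why passing `shard_id=i` and `num_shards=NUM_DEVICES` gives every GPU different data.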
The DALI_EXTRA_PATH environment variable should point to the location where the data from the DALI_extra repository was downloaded. Make sure that the proper release tag is checked out.
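If the variable is missing, `os.environ['DALI_EXTRA_PATH']` raises a KeyError, and a wrong path only shows up as an error later on. The sketch below checks the setup up front. `dali_extra_mnist_path` is a hypothetical helper; the `db/MNIST/training/` subdirectory is the one this notebook uses for `data_path`.

```python
import os


def dali_extra_mnist_path(environ=os.environ, subdir="db/MNIST/training/"):
    """Return the MNIST LMDB directory inside DALI_extra, failing early if it is missing.

    Hypothetical helper; `environ` is a parameter so the check can be tested.
    """
    root = environ.get("DALI_EXTRA_PATH")
    if not root:
        raise RuntimeError("DALI_EXTRA_PATH is not set; point it at your DALI_extra checkout")
    path = os.path.join(root, subdir)
    if not os.path.isdir(path):
        raise FileNotFoundError(
            "{} does not exist; check DALI_EXTRA_PATH and the checked-out release tag".format(path))
    return path
```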
[1]:
import nvidia.dali as dali
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import os
import nvidia.dali.plugin.tf as dali_tf
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)
tf.disable_eager_execution()
tf.reset_default_graph()
[2]:
# Path to MNIST dataset
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
BATCH_SIZE = 64
DROPOUT = 0.2
IMAGE_SIZE = 28
NUM_CLASSES = 10
HIDDEN_SIZE = 128
EPOCHS = 5
ITERATIONS = 100
NUM_DEVICES = 2
[3]:
class MnistPipeline(Pipeline):
    def __init__(
        self, batch_size, device_id=0, shard_id=0, num_shards=1, num_threads=4, seed=0):
        super(MnistPipeline, self).__init__(
            batch_size, num_threads, device_id, seed)
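        # Read this pipeline's shard (shard_id out of num_shards) of the MNIST LMDB.
        self.reader = ops.Caffe2Reader(
            path=data_path, random_shuffle=True, shard_id=shard_id, num_shards=num_shards)
        # Decode to grayscale; 'mixed' takes CPU input and produces GPU output.
        self.decode = ops.ImageDecoder(
            device='mixed',
            output_type=types.GRAY)
        # Normalize pixels as (x - mean) / std with std=255 and output CHW layout.
        self.cmn = ops.CropMirrorNormalize(
            device='gpu',
            dtype=types.FLOAT,
            std=[255.],
            output_layout="CHW")
    def define_graph(self):
        inputs, labels = self.reader(name="Reader")
        images = self.decode(inputs)
        # Move the labels to the GPU so that both outputs are on the GPU.
        labels = labels.gpu()
        images = self.cmn(images)
        return (images, labels)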
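Conceptually, the normalization step above divides each pixel by 255 and moves the channel dimension first. The pure-Python sketch below reproduces that per-sample transform for a tiny grayscale image, assuming `mean` keeps its default of 0 and no cropping or mirroring is requested. `normalize_hwc_to_chw` is a hypothetical name, not a DALI function.

```python
def normalize_hwc_to_chw(image_hwc, mean=0.0, std=255.0):
    """Normalize an image given as nested lists [H][W][C] and return it as [C][H][W]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Outer loop over channels produces the channel-first (CHW) layout.
    return [
        [[(image_hwc[y][x][c] - mean) / std for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 2x2 single-channel (grayscale) image in HWC layout.
image = [[[0], [255]],
         [[51], [102]]]
print(normalize_hwc_to_chw(image))  # [[[0.0, 1.0], [0.2, 0.4]]]
```

Under the same assumption, each 28x28 MNIST image becomes a 1x28x28 float array with values in [0, 1].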
Next, we create the parameters needed by the DALI dataset. For details on what they mean, see the single GPU example.
[4]:
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.autotune = False
shapes = (
    (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE),
    (BATCH_SIZE))
dtypes = (
    tf.float32,
    tf.int32)
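A detail worth knowing about the label shape: `(BATCH_SIZE)` is just the integer 64, because parentheses alone do not create a tuple. `tf.TensorShape(64)` is equivalent to `tf.TensorShape([64])`, so it still describes a rank-1 shape, but the trailing-comma form `(BATCH_SIZE,)` states that intent explicitly. The snippet below only demonstrates the Python side of this and does not need TensorFlow.

```python
BATCH_SIZE = 64

label_shape_as_written = (BATCH_SIZE)   # parentheses only group: this is the int 64
label_shape_as_tuple = (BATCH_SIZE,)    # trailing comma: a one-element tuple

print(type(label_shape_as_written).__name__)  # int
print(type(label_shape_as_tuple).__name__)    # tuple
```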
Because this training uses more than one GPU, we use the function below to average the gradients across the devices.
[5]:
# This function is copied from: https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py#L102
def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            # Add 0 dimension to the gradients to represent the tower.
            expanded_g = tf.expand_dims(g, 0)
            # Append on a 'tower' dimension which we will average over below.
            grads.append(expanded_g)
        # Average over the 'tower' dimension.
        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)
        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
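To see the data flow of `average_gradients` without TensorFlow, here is a pure-Python sketch of the same idea: `zip(*tower_grads)` groups the (gradient, variable) pairs of each variable across towers, the gradients are averaged element-wise, and the first tower's variable is kept. Plain lists stand in for gradient tensors and strings stand in for the shared variables; `average_gradients_py` and the variable names are hypothetical.

```python
def average_gradients_py(tower_grads):
    """tower_grads: one list of (gradient, variable) pairs per tower."""
    average_grads = []
    # Each grad_and_vars holds the pairs for one variable, one per tower.
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars]
        # Element-wise mean over the tower dimension.
        mean_grad = [sum(values) / len(values) for values in zip(*grads)]
        # The variable is shared, so the first tower's reference is enough.
        average_grads.append((mean_grad, grad_and_vars[0][1]))
    return average_grads


tower0 = [([1.0, 2.0], "dense/kernel"), ([0.5], "dense/bias")]
tower1 = [([3.0, 4.0], "dense/kernel"), ([1.5], "dense/bias")]
print(average_gradients_py([tower0, tower1]))
# [([2.0, 3.0], 'dense/kernel'), ([1.0], 'dense/bias')]
```

Like the TensorFlow version, this assumes every tower lists its gradients for the same variables in the same order, which holds here because all towers build the same network in one shared variable scope.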
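Now we are ready to define the model. Note that one instance of the DALI dataset is created per GPU. Each instance reads only the part of the MNIST dataset assigned to it by the shard_id parameter of the wrapped pipeline. The model variables are created by the first tower and reused by the others (reuse=(i != 0)), so all GPUs train the same weights.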
[6]:
iterator_initializers = []
with tf.device('/cpu:0'):
    tower_grads = []
    for i in range(NUM_DEVICES):
        with tf.device('/gpu:{}'.format(i)):
            daliset = dali_tf.DALIDataset(
                pipeline=MnistPipeline(
                    BATCH_SIZE, device_id=i, shard_id=i, num_shards=NUM_DEVICES),
                batch_size=BATCH_SIZE,
                output_shapes=shapes,
                output_dtypes=dtypes,
                device_id=i).with_options(options)
            iterator = tf.data.make_initializable_iterator(daliset)
            iterator_initializers.append(iterator.initializer)
            images, labels = iterator.get_next()
            labels = tf.reshape(
                tf.one_hot(labels, NUM_CLASSES),
                [BATCH_SIZE, NUM_CLASSES])
            with tf.variable_scope('mnist_net', reuse=(i != 0)):
                images = tf.layers.flatten(images)
                images = tf.layers.dense(images, HIDDEN_SIZE, activation=tf.nn.relu)
                images = tf.layers.dropout(images, rate=DROPOUT, training=True)
                images = tf.layers.dense(images, NUM_CLASSES, activation=tf.nn.softmax)
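            logits_train = images
            # Note: the last dense layer already applies softmax, so logits_train holds
            # probabilities and softmax is effectively applied twice in this loss. The
            # model still trains, but dropping activation=tf.nn.softmax above would give
            # the standard formulation with raw logits.
            loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                logits=logits_train, labels=labels))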
            optimizer = tf.train.AdamOptimizer()
            grads = optimizer.compute_gradients(loss_op)
            if i == 0:
                correct_pred = tf.equal(
                    tf.argmax(logits_train, 1), tf.argmax(labels, 1))
                accuracy = tf.reduce_mean(
                    tf.cast(correct_pred, tf.float32))
            tower_grads.append(grads)
    tower_grads = average_gradients(tower_grads)
    train_step = optimizer.apply_gradients(tower_grads)
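Each run of `train_step` pulls one batch from every tower's iterator and applies the averaged gradients once. Because all towers use the same BATCH_SIZE and each loss is a batch mean, averaging the per-tower gradients gives the gradient of the mean loss over the combined batch. Each step therefore behaves like a single step with a global batch of BATCH_SIZE * NUM_DEVICES samples. The arithmetic below simply reuses the parameters defined earlier in this notebook.

```python
# Parameters copied from the notebook above.
BATCH_SIZE = 64
NUM_DEVICES = 2
EPOCHS = 5
ITERATIONS = 100

global_batch = BATCH_SIZE * NUM_DEVICES      # samples consumed per train_step
total_steps = EPOCHS * ITERATIONS            # iterations of the training loop
samples_per_gpu = BATCH_SIZE * total_steps   # samples each GPU reads for train_step
total_samples = global_batch * total_steps   # samples seen across all GPUs

print(global_batch, total_steps, samples_per_gpu, total_samples)  # 128 500 32000 64000
```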
Everything is now ready to run the training.
[7]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Initialize the DALI dataset iterators of all GPUs.
    sess.run(iterator_initializers)
    for i in range(EPOCHS * ITERATIONS):
        sess.run(train_step)
        if i % ITERATIONS == 0:
            train_accuracy = sess.run(accuracy)
            print("Step %d, accuracy: %g" % (i, train_accuracy))
    # Average the accuracy over ITERATIONS batches read by the first GPU.
    final_accuracy = 0
    for _ in range(ITERATIONS):
        final_accuracy += sess.run(accuracy)
    final_accuracy = final_accuracy / ITERATIONS
    print('Final accuracy: ', final_accuracy)
Step 0, accuracy: 0.171875
Step 100, accuracy: 0.9375
Step 200, accuracy: 0.859375
Step 300, accuracy: 0.921875
Step 400, accuracy: 0.90625
Final accuracy:  0.92015625