Using TensorFlow DALI plugin: DALI tf.data.Dataset with multiple GPUs
Overview
This notebook is a comprehensive example of how to use the DALI tf.data.Dataset with multiple GPUs. It extends the single-GPU version, so we recommend going through the single-GPU example first to get familiar with the DALI dataset and how it can be used to train a neural network.
First, we define some training parameters and create a DALI pipeline that reads MNIST converted to LMDB format. The data is available in the DALI_extra repository. The pipeline can partition the dataset into multiple shards, so that each GPU reads a different part of it.
The DALI_EXTRA_PATH environment variable should point to the location where the data from the DALI_extra repository was downloaded. Make sure that the proper release tag is checked out.
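The idea of sharding can be illustrated without DALI. The following is a minimal sketch, assuming a contiguous, near-even split of the sample indices; shard_range is a hypothetical helper, not part of the DALI API (DALI's readers do the partitioning internally once shard_id and num_shards are set), and the 60,000-sample size is the standard MNIST training set, assumed here for illustration.

```python
# Hypothetical sketch of contiguous sharding: shard i of k gets the index
# range [N*i//k, N*(i+1)//k). Not DALI's implementation, only the idea.

def shard_range(shard_id: int, num_shards: int, num_samples: int) -> range:
    """Return the range of sample indices assigned to shard_id."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    start = num_samples * shard_id // num_shards
    end = num_samples * (shard_id + 1) // num_shards
    return range(start, end)


if __name__ == "__main__":
    NUM_SAMPLES = 60000  # assumption: size of the standard MNIST training set
    NUM_DEVICES = 2      # one shard per GPU, as in this notebook
    for shard_id in range(NUM_DEVICES):
        r = shard_range(shard_id, NUM_DEVICES, NUM_SAMPLES)
        print(shard_id, r.start, r.stop, len(r))
```

With two devices, each shard covers 30,000 samples, and the shards never overlap, so the GPUs do not read the same images within an epoch.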
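Since the notebook reads os.environ['DALI_EXTRA_PATH'] directly, a missing variable or a wrong checkout only shows up as a KeyError or a reader error later on. A small check can fail earlier with a clearer message; the helper dali_extra_subdir below is hypothetical and only a sketch of that idea.

```python
import os


def dali_extra_subdir(relative_path: str) -> str:
    """Resolve relative_path inside DALI_EXTRA_PATH, failing early with a clear message."""
    root = os.environ.get("DALI_EXTRA_PATH")
    if not root:
        raise RuntimeError(
            "DALI_EXTRA_PATH is not set; point it at your DALI_extra checkout")
    path = os.path.join(root, relative_path)
    if not os.path.isdir(path):
        raise FileNotFoundError("expected directory not found: {}".format(path))
    return path


# Usage with the same relative path as the notebook:
# data_path = dali_extra_subdir('db/MNIST/training/')
```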
[1]:
import nvidia.dali as dali
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import os
import nvidia.dali.plugin.tf as dali_tf
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)
[2]:
# Path to MNIST dataset
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
BATCH_SIZE = 64
DROPOUT = 0.2
IMAGE_SIZE = 28
NUM_CLASSES = 10
HIDDEN_SIZE = 128
EPOCHS = 5
ITERATIONS = 100
NUM_DEVICES = 2
[3]:
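@pipeline_def(batch_size=BATCH_SIZE)
def mnist_pipeline(shard_id):
    # Each pipeline instance reads only shard `shard_id` out of NUM_DEVICES shards
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=True, shard_id=shard_id, num_shards=NUM_DEVICES)
    # Decode on the GPU ('mixed': CPU input, GPU output) into single-channel images
    images = fn.decoders.image(jpegs, device='mixed', output_type=types.GRAY)
    # Convert to float and scale pixel values from [0, 255] to [0, 1]
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    # Move the labels to the GPU so that both outputs live on the same device
    return images, labels.gpu()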
Next, we define the output shapes and types expected from the DALI dataset. For more details on these parameters, see the single-GPU example.
[4]:
shapes = (
    (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE),  # images
    (BATCH_SIZE,))                         # labels
dtypes = (
    tf.float32,  # images
    tf.int32)    # labels
Now we are ready to define the model. To distribute training across multiple GPUs, we use tf.distribute.MirroredStrategy, which keeps a copy of the model variables on each GPU and keeps them in sync during training.
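With distribute_datasets_from_function, the dataset returned for each replica is expected to produce per-replica batches, so every training step consumes BATCH_SIZE samples on each GPU. The plain-Python arithmetic below reuses the notebook's constants to show what one Keras epoch of ITERATIONS steps covers; NUM_SAMPLES = 60000 is an assumption (the size of the standard MNIST training set), not something this notebook states.

```python
BATCH_SIZE = 64      # per-replica batch size produced by each DALI pipeline
NUM_DEVICES = 2      # number of replicas (GPUs)
ITERATIONS = 100     # steps_per_epoch passed to model.fit
NUM_SAMPLES = 60000  # assumption: size of the standard MNIST training set

global_batch = BATCH_SIZE * NUM_DEVICES            # samples consumed per step
samples_per_fit_epoch = global_batch * ITERATIONS  # samples per Keras "epoch"
samples_per_shard = NUM_SAMPLES // NUM_DEVICES     # each GPU reads its own shard
steps_to_cover_shard = -(-samples_per_shard // BATCH_SIZE)  # ceiling division

print(global_batch)           # 128
print(samples_per_fit_epoch)  # 12800
print(samples_per_shard)      # 30000
print(steps_to_cover_shard)   # 469
```

Under these assumptions, each Keras epoch of 100 steps sees 12,800 samples, roughly a fifth of the training set, which is enough for a demonstration; a full pass over each shard would take about 469 steps.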
[5]:
# One replica per GPU: '/gpu:0', '/gpu:1', ...
strategy = tf.distribute.MirroredStrategy(
    devices=['/gpu:{}'.format(i) for i in range(NUM_DEVICES)])
# The model must be created and compiled inside the strategy scope
# so that its variables are mirrored across the GPUs
with strategy.scope():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='images'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(HIDDEN_SIZE, activation='relu'),
        tf.keras.layers.Dropout(DROPOUT),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
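The DALI dataset must be distributed as well. For this we use distribute_datasets_from_function. First, we define a function that returns a dataset bound to the device identified by input_context.input_pipeline_id. We also need specific tf.distribute.InputOptions: the dataset is placed directly on the device (experimental_place_dataset_on_device=True), which requires the PER_REPLICA replication mode and cannot be combined with experimental_prefetch_to_device=True.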
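The binding between replicas, devices and shards can be summarized in plain Python. This is a hedged sketch: ReplicaInput and plan_inputs are hypothetical names used only for illustration, and it assumes a single worker in PER_REPLICA mode where input_pipeline_id runs from 0 to NUM_DEVICES - 1 and the strategy's devices are '/gpu:0', '/gpu:1', ... in that order.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ReplicaInput:
    pipeline_id: int  # plays the role of input_context.input_pipeline_id
    device: str       # TF device string the DALIDataset is created under
    shard_id: int     # DALI reader shard read by this replica
    num_shards: int   # total number of shards (one per GPU)


def plan_inputs(num_devices: int) -> List[ReplicaInput]:
    """Sketch of the one-pipeline-per-GPU layout that dataset_fn sets up."""
    return [
        ReplicaInput(pipeline_id=i, device="/gpu:{}".format(i),
                     shard_id=i, num_shards=num_devices)
        for i in range(num_devices)
    ]


for replica in plan_inputs(2):
    print(replica)
```

Using the same index for the device, the DALI device_id and the shard_id keeps each pipeline on the GPU it feeds and ensures that no two replicas read the same shard.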
[6]:
def dataset_fn(input_context):
    # In PER_REPLICA mode there is one input pipeline per replica (GPU)
    device_id = input_context.input_pipeline_id
    # Create the DALI dataset on the GPU that it feeds
    with tf.device("/gpu:{}".format(device_id)):
        return dali_tf.DALIDataset(
            # Each replica runs its own pipeline and reads its own shard
            pipeline=mnist_pipeline(device_id=device_id, shard_id=device_id),
            batch_size=BATCH_SIZE,
            output_shapes=shapes,
            output_dtypes=dtypes,
            device_id=device_id)

input_options = tf.distribute.InputOptions(
    experimental_place_dataset_on_device=True,
    experimental_prefetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)

train_dataset = strategy.distribute_datasets_from_function(dataset_fn, input_options)
With everything in place, we are ready to run the training and evaluate the model.
[7]:
model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=ITERATIONS)
Epoch 1/5
100/100 [==============================] - 4s 8ms/step - loss: 1.2438 - accuracy: 0.6290
Epoch 2/5
100/100 [==============================] - 1s 8ms/step - loss: 0.3991 - accuracy: 0.8876
Epoch 3/5
100/100 [==============================] - 1s 8ms/step - loss: 0.3202 - accuracy: 0.9045
Epoch 4/5
100/100 [==============================] - 1s 9ms/step - loss: 0.2837 - accuracy: 0.9183
Epoch 5/5
100/100 [==============================] - 1s 8ms/step - loss: 0.2441 - accuracy: 0.9303
[7]:
<tensorflow.python.keras.callbacks.History at 0x7f5d09685880>
[8]:
model.evaluate(
    train_dataset,
    steps=ITERATIONS)
100/100 [==============================] - 2s 5ms/step - loss: 0.1963 - accuracy: 0.9438
[8]:
[0.19630344212055206, 0.9437500238418579]