Using the TensorFlow DALI plugin with sparse tensors#
Overview#
Using the DALI data loading and augmentation pipeline with TensorFlow is straightforward.
However, sometimes a batch of data that a user wants to extract from the pipeline cannot be represented as a dense tensor. In such a case, the DALI op uses a TensorFlow SparseTensor. Please keep in mind that SparseTensors are supported only for the CPU-based pipeline.
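To make the idea concrete, here is a minimal pure-Python sketch of how a batch with a varying number of labels per image can be described by the three parts of a SparseTensor: `indices`, `values` and `dense_shape`. The helper name `ragged_to_coo` and the sample data are hypothetical and only illustrate the layout; this is not how DALI builds the tensor internally.

```python
def ragged_to_coo(rows):
    """Describe a list of variable-length rows as (indices, values, dense_shape)."""
    indices = []
    values = []
    for sample_idx, row in enumerate(rows):
        for item_idx, value in enumerate(row):
            # Each stored value is addressed by [sample, position within sample]
            indices.append([sample_idx, item_idx])
            values.append(value)
    # The dense shape is the batch size times the longest row in the batch
    max_len = max((len(row) for row in rows), default=0)
    dense_shape = [len(rows), max_len]
    return indices, values, dense_shape


# Hypothetical batch: one label, two labels, and an image with no objects
labels_per_image = [[17], [12, 12], []]
indices, values, dense_shape = ragged_to_coo(labels_per_image)
print(indices)
print(values)
print(dense_shape)
```

Only the values that exist are stored, so an image without any objects contributes no entries at all, while `dense_shape` records the size the batch would have if it were padded.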
Defining the Data Loading Pipeline#
First, we define a simple pipeline that returns data as sparse tensors. To achieve this, we will use the well-known COCO data set. Each image may have zero or more bounding boxes with labels describing the objects present in it. We want to return images in a normalized form, while labels and bounding boxes will be represented as sparse tensors. To begin, let us define some global parameters.
DALI_EXTRA_PATH environment variable should point to the place where data from DALI extra repository is downloaded. Please make sure that the proper release tag is checked out.
[1]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import os

BATCH_SIZE = 32
# Root of the checked-out DALI extra repository
test_data_root = os.environ["DALI_EXTRA_PATH"]
file_root = os.path.join(test_data_root, "db", "coco", "images")
annotations_file = os.path.join(test_data_root, "db", "coco", "instances.json")
Next, we create a pipeline with the COCO reader. Note that while the images are processed, the other data from COCO is passed through unchanged.
[2]:
@pipeline_def
def coco_pipeline():
    jpegs, bboxes, labels, im_ids = fn.readers.coco(
        file_root=file_root,
        annotations_file=annotations_file,
        ratio=False,
        image_ids=True,
    )
    images = fn.decoders.image(jpegs, device="cpu")
    images = fn.resize(
        images,
        resize_shorter=fn.random.uniform(range=(256.0, 480.0)),
        interp_type=types.INTERP_LINEAR,
    )
    images = fn.crop_mirror_normalize(
        images,
        crop_pos_x=fn.random.uniform(range=(0.0, 1.0)),
        crop_pos_y=fn.random.uniform(range=(0.0, 1.0)),
        dtype=types.FLOAT,
        crop=(224, 224),
        mean=[128.0, 128.0, 128.0],
        std=[1.0, 1.0, 1.0],
    )
    images = fn.cast(images, dtype=types.INT32)
    return images, bboxes, labels, im_ids
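The normalization in this pipeline uses mean=128 and std=1, so each pixel is mapped to (pixel - mean) / std. The following pure-Python sketch shows what that does to a few 8-bit pixel values and how adding the mean back restores them. The helper names `normalize` and `denormalize` are hypothetical and only mirror the arithmetic, not the DALI operator itself.

```python
MEAN = 128.0
STD = 1.0


def normalize(pixel):
    """Apply the same arithmetic as the mean/std normalization: (x - mean) / std."""
    return (pixel - MEAN) / STD


def denormalize(value):
    """Invert the normalization so the value can be displayed as a pixel again."""
    return value * STD + MEAN


for pixel in (0, 128, 255):
    normalized = int(normalize(pixel))
    restored = int(denormalize(normalized))
    print(pixel, normalized, restored)
```

With these parameters the normalized values fall in the range -128 to 127, which is why the images can be cast to a signed integer type and why 128 has to be added back before displaying them.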
Next, we instantiate the pipeline with the right parameters. Because SparseTensors are supported only for the CPU-based pipeline, a single pipeline with device_id=0 is used here.
Instead of calling pipeline.build and running the pipeline ourselves, we pass the pipeline object to the TensorFlow operator.
[3]:
pipe = coco_pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0)
Using DALI TensorFlow Plugin#
Let’s start by importing TensorFlow and the DALI TensorFlow plugin as dali_tf.
[4]:
import tensorflow as tf
import nvidia.dali.plugin.tf as dali_tf
import time
from tensorflow.compat.v1 import Session

# The example below uses a TF1-style graph and session
tf.compat.v1.disable_eager_execution()
We can now use the nvidia.dali.plugin.tf.DALIIterator() function to get the TensorFlow op that will produce the tensors we will use in the TensorFlow graph.
For each DALI pipeline, we use daliop, which returns a tuple of TensorFlow tensors that we store as the image, bounding boxes, labels and image IDs. To enable sparse tensor generation, the sparse argument needs to be set to True for each output element that is going to be represented as a sparse tensor.
[5]:
daliop = dali_tf.DALIIterator()
images = []
bboxes = []
labels = []
image_ids = []
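with tf.device("/cpu"):
    # Outputs follow the pipeline order: images, bboxes, labels, image ids
    image, bbox, label, image_id = daliop(
        pipeline=pipe,
        shapes=[(BATCH_SIZE, 3, 224, 224), (), (), ()],
        dtypes=[tf.int32, tf.float32, tf.int32, tf.int32],
        # Images stay dense; bboxes and labels are returned as sparse tensors
        sparse=[False, True, True],
    )
    images.append(image)
    bboxes.append(bbox)
    labels.append(label)
    image_ids.append(image_id)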
Using the Tensors in a Simple Tensorflow Graph#
We will use the images, bboxes, labels and image_ids tensor lists in our TensorFlow graph definition. We then run a very simple one-op graph session that outputs a batch of data, and print the bounding boxes, labels and image IDs.
[6]:
with Session() as sess:
    # The actual run with our dali_tf tensors
    res_cpu = sess.run([images, bboxes, labels, image_ids])

# Bounding boxes and labels are sparse; image ids are a dense array
print(res_cpu[1])
print(res_cpu[2])
print(res_cpu[3])
[SparseTensorValue(indices=array([[ 0,  0,  0],
       [ 0,  0,  1],
       [ 0,  0,  2],
       [ 0,  0,  3],
       [ 1,  0,  0],
       [ 1,  0,  1],
       [ 1,  0,  2],
       [ 1,  0,  3],
       [ 2,  0,  0],
       [ 2,  0,  1],
       [ 2,  0,  2],
       [ 2,  0,  3],
       [ 3,  0,  0],
       [ 3,  0,  1],
       [ 3,  0,  2],
       [ 3,  0,  3],
       [ 3,  1,  0],
       [ 3,  1,  1],
       [ 3,  1,  2],
       [ 3,  1,  3],
       [ 4,  0,  0],
       [ 4,  0,  1],
       [ 4,  0,  2],
       [ 4,  0,  3],
       [ 5,  0,  0],
       [ 5,  0,  1],
       [ 5,  0,  2],
       [ 5,  0,  3],
       [ 6,  0,  0],
       [ 6,  0,  1],
       [ 6,  0,  2],
       [ 6,  0,  3],
       [ 7,  0,  0],
       [ 7,  0,  1],
       [ 7,  0,  2],
       [ 7,  0,  3],
       [ 8,  0,  0],
       [ 8,  0,  1],
       [ 8,  0,  2],
       [ 8,  0,  3],
       [ 9,  0,  0],
       [ 9,  0,  1],
       [ 9,  0,  2],
       [ 9,  0,  3],
       [ 9,  1,  0],
       [ 9,  1,  1],
       [ 9,  1,  2],
       [ 9,  1,  3],
       [10,  0,  0],
       [10,  0,  1],
       [10,  0,  2],
       [10,  0,  3],
       [10,  1,  0],
       [10,  1,  1],
       [10,  1,  2],
       [10,  1,  3],
       [10,  2,  0],
       [10,  2,  1],
       [10,  2,  2],
       [10,  2,  3],
       [10,  3,  0],
       [10,  3,  1],
       [10,  3,  2],
       [10,  3,  3],
       [10,  4,  0],
       [10,  4,  1],
       [10,  4,  2],
       [10,  4,  3],
       [10,  5,  0],
       [10,  5,  1],
       [10,  5,  2],
       [10,  5,  3],
       [11,  0,  0],
       [11,  0,  1],
       [11,  0,  2],
       [11,  0,  3],
       [12,  0,  0],
       [12,  0,  1],
       [12,  0,  2],
       [12,  0,  3],
       [13,  0,  0],
       [13,  0,  1],
       [13,  0,  2],
       [13,  0,  3],
       [13,  1,  0],
       [13,  1,  1],
       [13,  1,  2],
       [13,  1,  3],
       [14,  0,  0],
       [14,  0,  1],
       [14,  0,  2],
       [14,  0,  3],
       [15,  0,  0],
       [15,  0,  1],
       [15,  0,  2],
       [15,  0,  3],
       [16,  0,  0],
       [16,  0,  1],
       [16,  0,  2],
       [16,  0,  3],
       [16,  1,  0],
       [16,  1,  1],
       [16,  1,  2],
       [16,  1,  3],
       [16,  2,  0],
       [16,  2,  1],
       [16,  2,  2],
       [16,  2,  3],
       [17,  0,  0],
       [17,  0,  1],
       [17,  0,  2],
       [17,  0,  3],
       [18,  0,  0],
       [18,  0,  1],
       [18,  0,  2],
       [18,  0,  3],
       [18,  1,  0],
       [18,  1,  1],
       [18,  1,  2],
       [18,  1,  3],
       [19,  0,  0],
       [19,  0,  1],
       [19,  0,  2],
       [19,  0,  3],
       [20,  0,  0],
       [20,  0,  1],
       [20,  0,  2],
       [20,  0,  3],
       [21,  0,  0],
       [21,  0,  1],
       [21,  0,  2],
       [21,  0,  3],
       [22,  0,  0],
       [22,  0,  1],
       [22,  0,  2],
       [22,  0,  3],
       [23,  0,  0],
       [23,  0,  1],
       [23,  0,  2],
       [23,  0,  3],
       [23,  1,  0],
       [23,  1,  1],
       [23,  1,  2],
       [23,  1,  3],
       [23,  2,  0],
       [23,  2,  1],
       [23,  2,  2],
       [23,  2,  3],
       [24,  0,  0],
       [24,  0,  1],
       [24,  0,  2],
       [24,  0,  3],
       [25,  0,  0],
       [25,  0,  1],
       [25,  0,  2],
       [25,  0,  3],
       [26,  0,  0],
       [26,  0,  1],
       [26,  0,  2],
       [26,  0,  3],
       [27,  0,  0],
       [27,  0,  1],
       [27,  0,  2],
       [27,  0,  3],
       [27,  1,  0],
       [27,  1,  1],
       [27,  1,  2],
       [27,  1,  3],
       [27,  2,  0],
       [27,  2,  1],
       [27,  2,  2],
       [27,  2,  3],
       [28,  0,  0],
       [28,  0,  1],
       [28,  0,  2],
       [28,  0,  3],
       [29,  0,  0],
       [29,  0,  1],
       [29,  0,  2],
       [29,  0,  3],
       [30,  0,  0],
       [30,  0,  1],
       [30,  0,  2],
       [30,  0,  3],
       [31,  0,  0],
       [31,  0,  1],
       [31,  0,  2],
       [31,  0,  3]]), values=array([ 604.,  120.,   78.,  563.,  294.,  411.,  669.,  345.,  206.,
         19.,  887.,  664.,   70.,  239.,  580.,  655.,  604.,  192.,
        624.,  726.,  160.,  152.,  413.,  397.,  521.,   36.,  136.,
        443.,  732.,  390.,  181.,   48.,   69.,  216., 1129.,  437.,
        377.,   24.,  512.,  652.,  316.,   52.,  476.,  428.,  572.,
        442.,   98.,  403.,  172.,  181.,  932.,  466.,  446.,  191.,
        728.,  608.,  347.,  645.,  187.,   83.,  143.,  569.,  204.,
         88.,  110.,  145.,  894.,  363.,  528.,  120.,  448.,  273.,
        253.,  283.,  816.,  518.,   85.,  518.,  639.,  389.,  221.,
        188.,  495.,  220.,  297.,  486.,  413.,  211.,  175.,   44.,
       1103.,  916.,  624.,  241.,  526.,  474.,  219.,  222.,  453.,
        237.,  553.,  157.,  366.,  305.,  727.,  208.,  465.,  255.,
        290.,  269.,  967.,  467.,  614.,   30.,  529.,  787.,  613.,
         23.,  527.,  793.,  331.,  160.,  600.,  539.,   55.,  148.,
        989.,  512.,  405.,   74.,  753.,  496.,   60.,  497.,  905.,
        246.,  432.,  110.,  252.,  540.,  528.,  105.,  643.,  491.,
        566.,   79.,  667.,  439.,  185.,   28.,  903.,  785.,  195.,
        337.,  820.,  459.,   10.,   65.,  978., 1214.,  999.,  312.,
        138.,  171.,  853.,  259.,  167.,  234.,  897.,  285.,  182.,
        299.,  173.,   55.,  767., 1079.,  539.,  448.,  556.,  323.,
          0.,   77., 1036.,  775.,   72.,   54., 1207.,  797.],
      dtype=float32), dense_shape=array([32,  6,  4]))]
[SparseTensorValue(indices=array([[ 0,  0],
       [ 1,  0],
       [ 2,  0],
       [ 3,  0],
       [ 3,  1],
       [ 4,  0],
       [ 5,  0],
       [ 6,  0],
       [ 7,  0],
       [ 8,  0],
       [ 9,  0],
       [ 9,  1],
       [10,  0],
       [10,  1],
       [10,  2],
       [10,  3],
       [10,  4],
       [10,  5],
       [11,  0],
       [12,  0],
       [13,  0],
       [13,  1],
       [14,  0],
       [15,  0],
       [16,  0],
       [16,  1],
       [16,  2],
       [17,  0],
       [18,  0],
       [18,  1],
       [19,  0],
       [20,  0],
       [21,  0],
       [22,  0],
       [23,  0],
       [23,  1],
       [23,  2],
       [24,  0],
       [25,  0],
       [26,  0],
       [27,  0],
       [27,  1],
       [27,  2],
       [28,  0],
       [29,  0],
       [30,  0],
       [31,  0]]), values=array([17,  2, 14, 12, 12,  1, 17,  8,  6,  8, 10, 17,  3,  3,  3,  3,  3,
        3,  2,  4, 13, 14,  9,  1, 12, 12, 12,  6,  8, 10,  8, 14, 13, 16,
        3,  3,  3, 15, 15,  9, 13, 13, 13,  7,  4, 12,  7], dtype=int32), dense_shape=array([32,  6]))]
[array([[ 0],
       [ 1],
       [ 2],
       [ 3],
       [ 4],
       [ 5],
       [ 6],
       [ 7],
       [ 8],
       [ 9],
       [10],
       [11],
       [12],
       [13],
       [14],
       [15],
       [16],
       [17],
       [18],
       [19],
       [20],
       [21],
       [22],
       [23],
       [24],
       [25],
       [26],
       [27],
       [28],
       [29],
       [30],
       [31]], dtype=int32)]
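The bounding-box output above stores one value per `[sample, box, coordinate]` index. As a minimal sketch of how to read it, the snippet below groups such entries back into a list of boxes per image. The helper name `group_boxes` is hypothetical, and the indices and values are copied by hand from the printed output for samples 2 and 3.

```python
from collections import defaultdict


def group_boxes(indices, values, coords_per_box=4):
    """Group sparse [sample, box, coordinate] entries into boxes per sample."""
    coords_by_box = defaultdict(dict)
    for (sample, box, coord), value in zip(indices, values):
        coords_by_box[(sample, box)][coord] = value
    boxes_by_sample = defaultdict(list)
    for sample, box in sorted(coords_by_box):
        coords = coords_by_box[(sample, box)]
        boxes_by_sample[sample].append([coords[c] for c in range(coords_per_box)])
    return dict(boxes_by_sample)


# Entries for samples 2 and 3, copied from the printed output above
indices = [
    [2, 0, 0], [2, 0, 1], [2, 0, 2], [2, 0, 3],
    [3, 0, 0], [3, 0, 1], [3, 0, 2], [3, 0, 3],
    [3, 1, 0], [3, 1, 1], [3, 1, 2], [3, 1, 3],
]
values = [206.0, 19.0, 887.0, 664.0,
          70.0, 239.0, 580.0, 655.0,
          604.0, 192.0, 624.0, 726.0]

boxes = group_boxes(indices, values)
print(boxes[2])
print(boxes[3])
```

In the notebook, the same helper could presumably be fed from the session result, for example with `res_cpu[1][0].indices.tolist()` and `res_cpu[1][0].values.tolist()`, since `SparseTensorValue` exposes `indices`, `values` and `dense_shape` fields.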
Let us check the output images with their augmentations. TensorFlow outputs NumPy arrays, so we can visualize them easily with matplotlib.
We define a show_images helper function that displays a sample of our batch.
The batch layout is NCHW, so we transpose each image to HWC, the layout that matplotlib can show.
[7]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline
def show_images(image_batch, nb_images):
    columns = 4
    # Round up so that a partially filled last row still gets a grid row
    rows = (nb_images + columns - 1) // columns
    plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(nb_images):
        plt.subplot(gs[j])
        plt.axis("off")
        # CHW -> HWC, then add back the mean (128) subtracted in the pipeline
        img = image_batch[0][j].transpose((1, 2, 0)) + 128
        plt.imshow(img.astype("uint8"))
show_images(res_cpu[0], 8)