MXNet with DALI - ResNet 50 example

Overview

This example shows how to use DALI pipelines with Apache MXNet.

ResNet 50 Pipeline

Let us first define a few global constants.

[1]:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.fn as fn

N = 8  # number of GPUs
batch_size = 128  # batch size per GPU

db_folder = "/data/imagenet/train-480-val-256-recordio/"

The Training Pipeline

The training pipeline consists of the following steps:

* Data is first read from MXNet’s RecordIO file. The reader op is given the name "Reader" for later use.
* Then, images are decoded using nvJPEG.
* The RGB images are then randomly cropped and resized to the final size of (224, 224) pixels.
* Finally, the batch is transposed from NHWC layout to NCHW layout, normalized, and randomly mirrored.
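The random crop is done inside `fn.decoders.image_random_crop`, which fuses decoding with cropping. To illustrate what its `random_area`, `random_aspect_ratio` and `num_attempts` arguments control, here is a minimal pure-Python sketch of random-resized-crop window sampling. It is an approximation under stated assumptions (uniform sampling of both the area fraction and the width/height aspect ratio, and a whole-image fallback). It is not DALI's actual implementation, and `sample_crop_window` is a hypothetical helper.

```python
import math
import random


def sample_crop_window(width, height, area_range=(0.1, 1.0),
                       aspect_range=(0.8, 1.25), num_attempts=100, rng=random):
    """Return (x, y, w, h) of a random crop window inside a width x height image.

    Simplified model of random-resized-crop sampling (assumption: uniform
    sampling of area fraction and aspect ratio); not DALI's exact algorithm.
    """
    image_area = width * height
    for _ in range(num_attempts):
        # Fraction of the image area to keep, and the target aspect ratio w / h.
        area = image_area * rng.uniform(*area_range)
        aspect = rng.uniform(*aspect_range)
        w = int(round(math.sqrt(area * aspect)))
        h = int(round(math.sqrt(area / aspect)))
        # Accept the window only if it fits inside the image.
        if 0 < w <= width and 0 < h <= height:
            x = rng.randint(0, width - w)
            y = rng.randint(0, height - h)
            return x, y, w, h
    # No attempt produced a window that fits: fall back to the whole image.
    return 0, 0, width, height
```

For example, `sample_crop_window(500, 375, rng=random.Random(0))` picks one window. In the pipeline, the selected window is then resized to `crop` x `crop` (224 x 224) with `fn.resize`, so the crop is stretched rather than letterboxed.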

DALIClassificationIterator, which we will use to interface with MXNet in this example, requires the outputs of the pipeline to follow the (image, label) structure.

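The validation pipeline is similar to the training pipeline, but it omits the random resized crop and random mirroring steps and does not shuffle the data coming from the reader. Instead of the random crop, it resizes each image so that its shorter side is 256 pixels and then takes a central 224 x 224 crop.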
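In the validation branch, `fn.resize(..., size=size, mode="not_smaller")` scales each image while preserving its aspect ratio so that neither side is smaller than `size`. `fn.crop_mirror_normalize` then cuts a `crop` x `crop` window, which is centered by default. The geometry can be sketched in plain Python as follows. The helper names are hypothetical, and the rounding (Python's `round` for the resize, floor division for the crop anchor) is an assumption that may differ from DALI by a pixel.

```python
def resize_not_smaller(width, height, target):
    """Scale (width, height) so that neither side is smaller than `target`, keeping aspect ratio."""
    scale = max(target / width, target / height)
    return round(width * scale), round(height * scale)


def center_crop_window(width, height, crop_w, crop_h):
    """Top-left corner and size of a centered crop (assumes the crop fits)."""
    x = (width - crop_w) // 2
    y = (height - crop_h) // 2
    return x, y, crop_w, crop_h


# A 500 x 375 image is resized to 341 x 256, then center-cropped to 224 x 224.
resized_w, resized_h = resize_not_smaller(500, 375, 256)
window = center_crop_window(resized_w, resized_h, 224, 224)
```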

[2]:
def create_dali_pipeline(batch_size, num_threads, device_id, db_folder, crop, size,
                         shard_id, num_shards, dali_cpu=False, is_training=True):
    pipeline = Pipeline(batch_size, num_threads, device_id, seed=12 + device_id)
    with pipeline:
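        # the training pipeline reads train.rec/train.idx, the validation pipeline reads val.rec/val.idx
        split = "train" if is_training else "val"
        images, labels = fn.readers.mxnet(path=[db_folder + split + ".rec"],
                                          index_path=[db_folder + split + ".idx"],
                                          random_shuffle=is_training, shard_id=shard_id, num_shards=num_shards,
                                          pad_last_batch=is_training, name="Reader")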
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # ask nvJPEG to preallocate memory for the biggest sample in ImageNet for CPU and GPU to avoid reallocations at runtime
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        # ask the hardware nvJPEG decoder to preallocate memory for the biggest image in the data set to avoid reallocations at runtime
        preallocate_width_hint = 5980 if decoder_device == 'mixed' else 0
        preallocate_height_hint = 6430 if decoder_device == 'mixed' else 0
        if is_training:
            images = fn.decoders.image_random_crop(images,
                                                  device=decoder_device, output_type=types.RGB,
                                                  device_memory_padding=device_memory_padding,
                                                  host_memory_padding=host_memory_padding,
                                                  preallocate_width_hint=preallocate_width_hint,
                                                  preallocate_height_hint=preallocate_height_hint,
                                                  random_aspect_ratio=[0.8, 1.25],
                                                  random_area=[0.1, 1.0],
                                                  num_attempts=100)
            images = fn.resize(images,
                               device=dali_device,
                               resize_x=crop,
                               resize_y=crop,
                               interp_type=types.INTERP_TRIANGULAR)
            mirror = fn.random.coin_flip(probability=0.5)
        else:
            images = fn.decoders.image(images,
                                       device=decoder_device,
                                       output_type=types.RGB)
            images = fn.resize(images,
                               device=dali_device,
                               size=size,
                               mode="not_smaller",
                               interp_type=types.INTERP_TRIANGULAR)
            mirror = False

        images = fn.crop_mirror_normalize(images.gpu(),
                                          dtype=types.FLOAT,
                                          output_layout="CHW",
                                          crop=(crop, crop),
                                          mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                          std=[0.229 * 255,0.224 * 255,0.225 * 255],
                                          mirror=mirror)
        labels = labels.gpu()
        pipeline.set_outputs(images, labels)
    return pipeline
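`fn.crop_mirror_normalize` fuses cropping, optional horizontal flipping, per-channel normalization `(pixel - mean) / std`, and the HWC to CHW layout change into a single GPU operation. The slow pure-Python reference below leaves out the crop and shows only the per-pixel arithmetic, with the same `mean` and `std` values as the pipeline. It is a sketch for understanding, not DALI code. `mirror_normalize_to_chw` is a hypothetical helper, and nested lists stand in for tensors.

```python
# Same per-channel statistics as in create_dali_pipeline (ImageNet mean/std scaled to 0..255).
MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def mirror_normalize_to_chw(image_hwc, mirror=False, mean=MEAN, std=STD):
    """image_hwc: nested list [H][W][C] of pixel values in 0..255.

    Returns a nested list [C][H][W] of normalized floats, optionally flipped horizontally.
    """
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Allocate the output directly in CHW order.
    out = [[[0.0] * width for _ in range(height)] for _ in range(channels)]
    for y in range(height):
        for x in range(width):
            # Horizontal flip: read the pixel from the mirrored column.
            src_x = width - 1 - x if mirror else x
            for c in range(channels):
                out[c][y][x] = (image_hwc[y][src_x][c] - mean[c]) / std[c]
    return out


# A 1 x 2 RGB image: one black pixel followed by one white pixel.
tiny_image = [[[0, 0, 0], [255, 255, 255]]]
chw = mirror_normalize_to_chw(tiny_image)
chw_mirrored = mirror_normalize_to_chw(tiny_image, mirror=True)
```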

[3]:
trainpipes = [create_dali_pipeline(db_folder=db_folder, batch_size=batch_size,
                                   num_threads=2, device_id=i, shard_id=i, num_shards=N, is_training=True,
                                   crop=224, size=256) for i in range(N)]
valpipes = [create_dali_pipeline(db_folder=db_folder, batch_size=batch_size,
                                 num_threads=2, device_id=i, shard_id=i, num_shards=N, is_training=False,
                                 crop=224, size=256) for i in range(N)]

Using the MXNet Plugin

MXNet data iterators need to know the size of the dataset. Since a DALI pipeline may contain multiple readers, potentially with differently sized datasets, we need to specify which reader to query for the epoch size. That is why we named the reader "Reader" in both the training and the validation pipelines.

To get the epoch size from the reader, we need to build one training pipeline and one validation pipeline.

[4]:
trainpipes[0].build()
valpipes[0].build()
[5]:
print("Training pipeline epoch size: {}".format(trainpipes[0].epoch_size("Reader")))
print("Validation pipeline epoch size: {}".format(valpipes[0].epoch_size("Reader")))
Training pipeline epoch size: 1281167
Validation pipeline epoch size: 50000
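With 8 GPUs and a per-GPU batch of 128, one global batch holds 1024 samples, which matches `batch_size=1024` in the training log later in this example. The sketch below estimates how the training and validation samples split across the 8 shards and how many iterations one epoch takes. It assumes that each shard covers a contiguous, near-equal range with boundaries at `size * shard_id // num_shards`. `shard_bounds` and `iterations_per_epoch` are hypothetical helpers, not part of the DALI API.

```python
import math


def shard_bounds(dataset_size, shard_id, num_shards):
    """[start, end) sample range of one shard (assumption: contiguous, near-equal shards)."""
    start = dataset_size * shard_id // num_shards
    end = dataset_size * (shard_id + 1) // num_shards
    return start, end


def iterations_per_epoch(dataset_size, num_shards, batch_size):
    """Batches each GPU must produce so that the largest shard is fully covered."""
    largest = max(end - start for start, end in
                  (shard_bounds(dataset_size, i, num_shards) for i in range(num_shards)))
    return math.ceil(largest / batch_size)


train_iters = iterations_per_epoch(1281167, 8, 128)
val_iters = iterations_per_epoch(50000, 8, 128)
```

Under these assumptions, a training epoch takes 1252 iterations and a validation epoch takes 49. This is consistent with the training log, whose last progress line is at batch 1200.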

Now we can make MXNet iterators out of our pipelines using the DALIClassificationIterator class.

[6]:
from nvidia.dali.plugin.mxnet import DALIClassificationIterator, LastBatchPolicy
dali_train_iter = DALIClassificationIterator(trainpipes, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)
dali_val_iter = DALIClassificationIterator(valpipes, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)
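`LastBatchPolicy.PARTIAL` tells the iterator to return the final, incomplete batch of an epoch containing only the remaining samples, rather than filling it up or dropping it. The simplified model below lists the per-GPU batch sizes of one epoch for a single shard. It uses plain strings as stand-ins for the `LastBatchPolicy` members, which is an assumption made to keep the sketch self-contained. It also ignores the extra bookkeeping DALI performs when shards differ in size.

```python
VALID_POLICIES = ("PARTIAL", "FILL", "DROP")


def batch_sizes(shard_size, batch_size, policy):
    """Simplified model of the per-GPU batch sizes returned during one epoch.

    "PARTIAL": the last batch keeps only the remaining samples.
    "FILL":    the last batch is filled up to batch_size (padded or repeated samples).
    "DROP":    an incomplete last batch is not returned at all.
    """
    if policy not in VALID_POLICIES:
        raise ValueError(f"unknown policy: {policy!r}")
    full, remainder = divmod(shard_size, batch_size)
    sizes = [batch_size] * full
    if remainder and policy == "PARTIAL":
        sizes.append(remainder)
    elif remainder and policy == "FILL":
        sizes.append(batch_size)
    return sizes


# One validation shard: 50000 samples / 8 GPUs = 6250 samples, batch size 128.
partial_sizes = batch_sizes(6250, 128, "PARTIAL")
```

In this model, a validation shard of 6250 samples yields 48 full batches and a final batch of 106 samples under PARTIAL.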

Training with MXNet

Once we have MXNet data iterators from DALIClassificationIterator, we can use them instead of MXNet’s mx.io.ImageRecordIter. Here we show a modified version of the train_imagenet.py example that uses our DALI pipelines.

[7]:
import os.path
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from resnetn.common import find_mxnet, data, fit
import mxnet as mx

gpus_string = ", ".join(str(i) for i in range(N))  # "0, 1, ..., N-1"

s = ['--gpu', gpus_string,
     '--batch-size', str(batch_size * N),
     '--num-epochs', '1',
     '--data-train', '/data/imagenet/train-480-val-256-recordio/train.rec',
     '--data-val', '/data/imagenet/train-480-val-256-recordio/val.rec',
     '--disp-batches', '100',
     '--network', 'resnet-v1',
     '--num-layers', '50',
     '--data-nthreads', '40',
     '--min-random-scale', '0.533',
     '--max-random-shear-ratio', '0',
     '--max-random-rotate-angle', '0',
     '--max-random-h', '0',
     '--max-random-l', '0',
     '--max-random-s', '0',
     '--dtype', 'float16']

# parse args
parser = argparse.ArgumentParser(description="train imagenet-1k",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
fit.add_fit_args(parser)
data.add_data_args(parser)
data.add_data_aug_args(parser)
# use a large aug level
data.set_data_aug_level(parser, 3)
parser.set_defaults(
        # network
        network          = 'resnet',
        num_layers       = 50,
        # data
        num_classes      = 1000,
        num_examples     = 1281167,
        image_shape      = '3,224,224',
        min_random_scale = 1, # if input image has min size k, suggest to use
                              # 256.0/k, e.g. 0.533 for 480
        # train
        num_epochs       = 80,
        lr_step_epochs   = '30,60',
        dtype            = 'float32'
    )
args = parser.parse_args(s)


# load network
from importlib import import_module
net = import_module('resnetn.symbols.'+args.network)
sym = net.get_symbol(1000, 50, "3,224,224", dtype='float16')

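def get_dali_iter(args, kv=None):
    # fit.fit() calls this as data_loader(args, kv) and expects a (train, val) pair;
    # the DALI iterators are already built above, so both arguments are ignored
    return (dali_train_iter, dali_val_iter)

# train
# to use MXNet's own RecordIO iterator instead of DALI:
#fit.fit(args, sym, data.get_rec_iter)
fit.fit(args, sym, get_dali_iter)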
INFO:root:start with arguments Namespace(batch_size=1024, benchmark=0, data_nthreads=40, data_train='/data/imagenet/train-480-val-256-recordio/train.rec', data_train_idx='', data_val='/data/imagenet/train-480-val-256-recordio/val.rec', data_val_idx='', disp_batches=100, dtype='float16', gc_threshold=0.5, gc_type='none', gpus='0, 1, 2, 3, 4, 5, 6, 7', image_shape='3,224,224', initializer='default', kv_store='device', load_epoch=None, loss='', lr=0.1, lr_factor=0.1, lr_step_epochs='30,60', macrobatch_size=0, max_random_aspect_ratio=0.25, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1, max_random_shear_ratio=0.0, min_random_scale=0.533, model_prefix=None, mom=0.9, monitor=0, network='resnet-v1', num_classes=1000, num_epochs=1, num_examples=1281167, num_layers=50, optimizer='sgd', pad_size=0, random_crop=1, random_mirror=1, rgb_mean='123.68,116.779,103.939', test_io=0, top_k=0, warmup_epochs=5, warmup_strategy='linear', wd=0.0001)
INFO:root:Epoch[0] Batch [100]  Speed: 4407.30 samples/sec      accuracy=0.001141
INFO:root:Epoch[0] Batch [200]  Speed: 4444.77 samples/sec      accuracy=0.003184
INFO:root:Epoch[0] Batch [300]  Speed: 4395.88 samples/sec      accuracy=0.006074
INFO:root:Epoch[0] Batch [400]  Speed: 4384.70 samples/sec      accuracy=0.011182
INFO:root:Epoch[0] Batch [500]  Speed: 4389.42 samples/sec      accuracy=0.017441
INFO:root:Epoch[0] Batch [600]  Speed: 4382.10 samples/sec      accuracy=0.026377
INFO:root:Epoch[0] Batch [700]  Speed: 4388.26 samples/sec      accuracy=0.036611
INFO:root:Epoch[0] Batch [800]  Speed: 4383.51 samples/sec      accuracy=0.047139
INFO:root:Epoch[0] Batch [900]  Speed: 4402.73 samples/sec      accuracy=0.057686
INFO:root:Epoch[0] Batch [1000] Speed: 4392.32 samples/sec      accuracy=0.067861
INFO:root:Epoch[0] Batch [1100] Speed: 4384.42 samples/sec      accuracy=0.079248
INFO:root:Epoch[0] Batch [1200] Speed: 4385.37 samples/sec      accuracy=0.090088
INFO:root:Epoch[0] Train-accuracy=0.098537
INFO:root:Epoch[0] Time cost=295.153
WARNING:root:DALI iterator does not support resetting while epoch is not finished. Ignoring...
INFO:root:Epoch[0] Validation-accuracy=0.104393
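As a rough sanity check on the log above, dividing `num_examples` by the mean of the logged throughput values gives an estimated epoch time. The snippet uses only numbers copied from the log and assumes that the logged speeds are representative of the whole epoch.

```python
# Speed values (samples/sec) copied from the Epoch[0] progress lines above.
logged_speeds = [4407.30, 4444.77, 4395.88, 4384.70, 4389.42, 4382.10,
                 4388.26, 4383.51, 4402.73, 4392.32, 4384.42, 4385.37]
num_examples = 1281167  # training set size reported by the reader

mean_speed = sum(logged_speeds) / len(logged_speeds)
estimated_epoch_seconds = num_examples / mean_speed
print(f"mean speed: {mean_speed:.1f} samples/sec, "
      f"estimated epoch time: {estimated_epoch_seconds:.1f} s")
```

The estimate comes out at roughly 291.5 s, slightly below the logged `Time cost=295.153`. The gap presumably covers work outside the measured intervals, such as startup and the tail of the epoch after batch 1200, though that attribution is an assumption.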