Gluon example with DALI

Overview

This is a modified DCGAN example that uses DALI for reading and augmenting the images of the LFW dataset.

Sample

[1]:
import os.path
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet import autograd
import numpy as np
[2]:
epochs = 10 # Set low by default for tests, set higher when you actually run this code.
batch_size = 64
latent_z_size = 100

use_gpu = True
ctx = mx.gpu() if use_gpu else mx.cpu()

lr = 0.0002
beta1 = 0.5
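If a GPU is expected (use_gpu = True above), an optional quick check that MXNet can actually see one is shown below; mx.context.num_gpus() is a standard MXNet call, not part of the original example.
[ ]:
# Optional sanity check: number of CUDA devices visible to MXNet
print('GPUs available:', mx.context.num_gpus())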
[3]:
lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
data_path = 'lfw_dataset'
if not os.path.exists(data_path):
    os.makedirs(data_path)
    data_file = utils.download(lfw_url)
    with tarfile.open(data_file) as tar:
        tar.extractall(path=data_path)
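Optionally, you can verify the extraction by counting the JPEG files; the glob pattern below assumes the archive's person-per-directory layout and should report 13,233 images for the deep-funneled LFW set.
[ ]:
import glob
# Count extracted images; assumed layout: lfw-deepfunneled/<person>/<image>.jpg
print(len(glob.glob(os.path.join(data_path, 'lfw-deepfunneled', '*', '*.jpg'))))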
[4]:
target_wd = 64
target_ht = 64
[5]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np


@pipeline_def
def gluon_pipe():
    jpegs, labels = fn.readers.file(
        name='Reader', file_root=data_path + '/lfw-deepfunneled/', random_shuffle=True, pad_last_batch=True)
    images = fn.decoders.image(jpegs, device='mixed')
    images = fn.resize(
        images, resize_x=target_wd, resize_y=target_ht, interp_type=types.INTERP_LINEAR)
    images = fn.rotate(images, angle=fn.random.uniform(range=(-10., 10)))
    images = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        crop=(target_wd, target_ht),
        mean=[127.5, 127.5, 127.5],
        std=[127.5, 127.5, 127.5])
    return images

[6]:
pipe = gluon_pipe(batch_size=batch_size, num_threads=4, device_id=0)
pipe.build()
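With the pipeline built, you can query how many images the reader found; this is the number the iterator later uses (through reader_name) to determine the epoch length.
[ ]:
# Number of samples the 'Reader' operator will produce per epoch
print(pipe.epoch_size('Reader'))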
[7]:
pipe_out = pipe.run()
[8]:
pipe_out_cpu = pipe_out[0].as_cpu()
img_chw = pipe_out_cpu.at(20)
%matplotlib inline
plt.imshow((np.transpose(img_chw, (1,2,0))+1.0)/2.0)
[8]:
<matplotlib.image.AxesImage at 0x7f07857e4640>
(Figure: a sample face image produced by the DALI pipeline, mapped back from [-1, 1] to [0, 1] for display.)
[9]:
from nvidia.dali.plugin.mxnet import DALIGenericIterator, LastBatchPolicy
# Recreate the pipeline so the iterator consumes it from a fresh state (do not mix the manual run() API with the iterator API)
pipe = gluon_pipe(batch_size=batch_size, num_threads=4, device_id=0)
pipe.build()
dali_iter = DALIGenericIterator(
    pipe,
    [("data", DALIGenericIterator.DATA_TAG)],
    reader_name="Reader",
    last_batch_policy=LastBatchPolicy.PARTIAL)
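Optionally, you can peek at one batch to confirm the layout the training loop expects. Note that this consumes a batch from the first epoch, so if you run it, recreate the iterator afterwards (as was done for the pipeline above).
[ ]:
peek = next(dali_iter)
# One mx.io.DataBatch per pipeline; expected data shape: (batch_size, 3, 64, 64), NCHW float32
print(peek[0].data[0].shape)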
[10]:
# build the generator
nc = 3
ngf = 64
netG = nn.Sequential()
with netG.name_scope():
    # input is Z, going into a convolution
    netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf*8) x 4 x 4
    netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf*4) x 8 x 8
    netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf*2) x 16 x 16
    netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))
    netG.add(nn.BatchNorm())
    netG.add(nn.Activation('relu'))
    # state size. (ngf) x 32 x 32
    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))
    netG.add(nn.Activation('tanh'))
    # state size. (nc) x 64 x 64

# build the discriminator
ndf = 64
netD = nn.Sequential()
with netD.name_scope():
    # input is (nc) x 64 x 64
    netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf) x 32 x 32
    netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))
    netD.add(nn.BatchNorm())
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf*2) x 16 x 16
    netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))
    netD.add(nn.BatchNorm())
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf*4) x 8 x 8
    netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))
    netD.add(nn.BatchNorm())
    netD.add(nn.LeakyReLU(0.2))
    # state size. (ndf*8) x 4 x 4
    netD.add(nn.Conv2D(1, 4, 1, 0, use_bias=False))
[11]:
# loss
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# initialize the generator and the discriminator
netG.initialize(mx.init.Normal(0.02), ctx=ctx)
netD.initialize(mx.init.Normal(0.02), ctx=ctx)

# trainer for the generator and the discriminator
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
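Before training, an optional shape check confirms that the generator maps a latent vector to a 64 x 64 image and that the discriminator reduces it back to a single value (the shapes follow from the layer stacks above).
[ ]:
z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
print(netG(z).shape)        # expected: (1, 3, 64, 64)
print(netD(netG(z)).shape)  # expected: (1, 1, 1, 1)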
[12]:
from datetime import datetime
import time
import logging

real_label = nd.ones((batch_size,), ctx=ctx)
fake_label = nd.zeros((batch_size,),ctx=ctx)

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()
metric = mx.metric.CustomMetric(facc)

stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)

for epoch in range(epochs):
    tic = time.time()
    btic = time.time()
    iter = 0
    for batches in dali_iter:  # Using DALI iterator
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        data = batches[0].data[0]  # extracting the batch for device 0
        latent_z = mx.nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)

        with autograd.record():
            # train with real image
            output = netD(data).reshape((-1, 1))
            errD_real = loss(output, real_label)
            metric.update([real_label,], [output,])

            # train with fake image
            fake = netG(latent_z)
            output = netD(fake.detach()).reshape((-1, 1))
            errD_fake = loss(output, fake_label)
            errD = errD_real + errD_fake
            errD.backward()
            metric.update([fake_label,], [output,])

        trainerD.step(data.shape[0])

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        with autograd.record():
            fake = netG(latent_z)
            output = netD(fake).reshape((-1, 1))
            errG = loss(output, real_label)
            errG.backward()

        trainerG.step(data.shape[0])

        # Print log information every 100 iterations
        if iter % 100 == 0:
            name, acc = metric.get()
            logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
            logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                     %(nd.mean(errD).asscalar(),
                       nd.mean(errG).asscalar(), acc, iter, epoch))
        iter = iter + 1
        btic = time.time()
    dali_iter.reset()

    name, acc = metric.get()
    metric.reset()
INFO:root:speed: 594.133498594542 samples/s
INFO:root:discriminator loss = 1.548318, generator loss = 6.110404, binary training acc = 0.554688 at iter 0 epoch 0
INFO:root:speed: 1864.6789758123898 samples/s
INFO:root:discriminator loss = 1.875944, generator loss = 10.167295, binary training acc = 0.884360 at iter 100 epoch 0
INFO:root:speed: 1925.027473197318 samples/s
INFO:root:discriminator loss = 1.383538, generator loss = 11.330649, binary training acc = 0.861046 at iter 200 epoch 0
INFO:root:speed: 1916.2053295452113 samples/s
INFO:root:discriminator loss = 0.729032, generator loss = 4.568172, binary training acc = 0.812500 at iter 0 epoch 1
INFO:root:speed: 1879.8790985615642 samples/s
INFO:root:discriminator loss = 0.659974, generator loss = 9.797787, binary training acc = 0.876702 at iter 100 epoch 1
INFO:root:speed: 1831.536308618137 samples/s
INFO:root:discriminator loss = 0.543448, generator loss = 8.295659, binary training acc = 0.876049 at iter 200 epoch 1
INFO:root:speed: 1899.7151935910774 samples/s
INFO:root:discriminator loss = 0.256711, generator loss = 5.445935, binary training acc = 0.976562 at iter 0 epoch 2
INFO:root:speed: 1952.044911464204 samples/s
INFO:root:discriminator loss = 0.231668, generator loss = 5.553924, binary training acc = 0.882890 at iter 100 epoch 2
INFO:root:speed: 1897.901949971012 samples/s
INFO:root:discriminator loss = 0.370034, generator loss = 6.385639, binary training acc = 0.894123 at iter 200 epoch 2
INFO:root:speed: 1894.9009332072114 samples/s
INFO:root:discriminator loss = 0.619546, generator loss = 7.575753, binary training acc = 0.898438 at iter 0 epoch 3
INFO:root:speed: 1860.8140748802484 samples/s
INFO:root:discriminator loss = 0.346177, generator loss = 7.045583, binary training acc = 0.916615 at iter 100 epoch 3
INFO:root:speed: 1883.8369054135612 samples/s
INFO:root:discriminator loss = 0.223433, generator loss = 5.369461, binary training acc = 0.916123 at iter 200 epoch 3
INFO:root:speed: 1907.6398988032633 samples/s
INFO:root:discriminator loss = 1.107040, generator loss = 12.463382, binary training acc = 0.734375 at iter 0 epoch 4
INFO:root:speed: 1817.1910100189548 samples/s
INFO:root:discriminator loss = 0.434032, generator loss = 9.117072, binary training acc = 0.880647 at iter 100 epoch 4
INFO:root:speed: 1825.7934486886495 samples/s
INFO:root:discriminator loss = 0.400421, generator loss = 4.347582, binary training acc = 0.883590 at iter 200 epoch 4
INFO:root:speed: 1824.0937204830084 samples/s
INFO:root:discriminator loss = 0.238368, generator loss = 4.456389, binary training acc = 0.984375 at iter 0 epoch 5
INFO:root:speed: 1924.4478410174424 samples/s
INFO:root:discriminator loss = 0.269836, generator loss = 4.265233, binary training acc = 0.878481 at iter 100 epoch 5
INFO:root:speed: 1889.7384423684784 samples/s
INFO:root:discriminator loss = 0.418730, generator loss = 4.249841, binary training acc = 0.890081 at iter 200 epoch 5
INFO:root:speed: 1938.7360590499716 samples/s
INFO:root:discriminator loss = 0.420917, generator loss = 6.589997, binary training acc = 0.945312 at iter 0 epoch 6
INFO:root:speed: 1878.9843065335779 samples/s
INFO:root:discriminator loss = 0.200992, generator loss = 4.738103, binary training acc = 0.882039 at iter 100 epoch 6
INFO:root:speed: 1911.0078879175328 samples/s
INFO:root:discriminator loss = 0.275492, generator loss = 4.333203, binary training acc = 0.893151 at iter 200 epoch 6
INFO:root:speed: 1927.3771746544605 samples/s
INFO:root:discriminator loss = 1.035967, generator loss = 3.340651, binary training acc = 0.664062 at iter 0 epoch 7
INFO:root:speed: 1953.3233108968527 samples/s
INFO:root:discriminator loss = 0.311858, generator loss = 3.603451, binary training acc = 0.879486 at iter 100 epoch 7
INFO:root:speed: 1895.7573977033574 samples/s
INFO:root:discriminator loss = 0.267880, generator loss = 3.750516, binary training acc = 0.878304 at iter 200 epoch 7
INFO:root:speed: 1954.0342565969063 samples/s
INFO:root:discriminator loss = 0.234087, generator loss = 3.313942, binary training acc = 0.968750 at iter 0 epoch 8
INFO:root:speed: 1883.2950222752306 samples/s
INFO:root:discriminator loss = 0.334708, generator loss = 3.301822, binary training acc = 0.868967 at iter 100 epoch 8
INFO:root:speed: 1883.3743027734706 samples/s
INFO:root:discriminator loss = 0.510033, generator loss = 2.656876, binary training acc = 0.875894 at iter 200 epoch 8
INFO:root:speed: 2009.232385985135 samples/s
INFO:root:discriminator loss = 0.412673, generator loss = 3.922147, binary training acc = 0.929688 at iter 0 epoch 9
INFO:root:speed: 1885.2260778571379 samples/s
INFO:root:discriminator loss = 0.402779, generator loss = 3.599171, binary training acc = 0.877398 at iter 100 epoch 9
INFO:root:speed: 1917.683766850742 samples/s
INFO:root:discriminator loss = 0.349911, generator loss = 4.183607, binary training acc = 0.882618 at iter 200 epoch 9
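To keep the trained weights, a minimal sketch using Gluon's parameter serialization is shown below; the file names are arbitrary.
[ ]:
# Save generator and discriminator weights (file names are just examples)
netG.save_parameters('netG.params')
netD.save_parameters('netD.params')
# Restore later with: netG.load_parameters('netG.params', ctx=ctx)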
[13]:
def visualize(img_arr):
    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')
[14]:
num_image = 8
fig = plt.figure(figsize = (16,8))
for i in range(num_image):
    latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
    img = netG(latent_z)
    plt.subplot(2,4,i+1)
    visualize(img[0])
plt.show()
(Figure: a 2 x 4 grid of face images produced by the trained generator.)