WarpAffine¶
In this notebook you’ll learn how to use the WarpAffine operator.
Introduction¶
Warp operators¶
All warp operators work by calculating the output pixels by sampling the source image at transformed coordinates:
\[Out(x, y) = In(x_{src}, y_{src})\]
This way each output pixel is calculated exactly once.
If the source coordinates do not point exactly to pixel centers, the output value is either interpolated from the neighboring pixels or taken from the nearest pixel, depending on the interpolation method specified in the interp_type argument.
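The sampling scheme described above can be sketched in plain NumPy. This is only an illustration, not DALI's implementation: the helper names `warp`, `sample_nearest` and `sample_linear` are hypothetical, the image is single-channel, integer coordinates are assumed to denote pixel centers, and out-of-range coordinates are simply clamped.

```python
import numpy as np

def sample_nearest(img, xs, ys):
    """Take the value of the pixel closest to each (x, y) source coordinate."""
    h, w = img.shape[:2]
    xi = np.clip(np.rint(xs).astype(int), 0, w - 1)
    yi = np.clip(np.rint(ys).astype(int), 0, h - 1)
    return img[yi, xi]

def sample_linear(img, xs, ys):
    """Bilinearly blend the four pixels surrounding each source coordinate."""
    h, w = img.shape[:2]
    xs = np.clip(xs, 0, w - 1)
    ys = np.clip(ys, 0, h - 1)
    x0 = np.floor(xs).astype(int)
    y0 = np.floor(ys).astype(int)
    x1 = np.minimum(x0 + 1, w - 1)
    y1 = np.minimum(y0 + 1, h - 1)
    fx = xs - x0
    fy = ys - y0
    top = img[y0, x0] * (1 - fx) + img[y0, x1] * fx
    bottom = img[y1, x0] * (1 - fx) + img[y1, x1] * fx
    return top * (1 - fy) + bottom * fy

def warp(img, out_shape, coord_fn, sampler):
    """Visit every output pixel once and sample the source at mapped coordinates."""
    ys, xs = np.mgrid[0:out_shape[0], 0:out_shape[1]].astype(float)
    xs_src, ys_src = coord_fn(xs, ys)
    return sampler(img, xs_src, ys_src)

img = np.array([[0.0, 10.0],
                [20.0, 30.0]])
# Destination-to-source mapping that halves the coordinates (a 2x upscale).
half = lambda x, y: (x * 0.5, y * 0.5)
out_linear = warp(img, (4, 4), half, sample_linear)
out_nearest = warp(img, (4, 4), half, sample_nearest)
```

With linear sampling the output contains blended values that do not occur in the source, while nearest-neighbor sampling only ever copies existing source values.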
Affine transform¶
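The source sample coordinates \(x_{src}, y_{src}\) are calculated according to the formula:
\[\begin{vmatrix} x_{src} \\ y_{src} \end{vmatrix} = \begin{vmatrix} m_{00} & m_{01} & t_x \\ m_{10} & m_{11} & t_y \end{vmatrix} \begin{vmatrix} x \\ y \\ 1 \end{vmatrix}\]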
Where \(x, y\) are coordinates of the destination pixel and the matrix represents the inverse (destination to source) affine transform. The \(\begin{vmatrix} m_{00} & m_{01} \\ m_{10} & m_{11} \end{vmatrix}\) block represents a combined rotate/scale/shear transform and \(t_x, t_y\) is a translation vector.
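Because the matrix maps destination to source, a transform that is naturally described in the forward direction (source to destination) has to be inverted first. A minimal NumPy sketch of that step follows; `invert_affine` and `map_dst_to_src` are hypothetical helper names introduced here for illustration and are not part of DALI.

```python
import numpy as np

def invert_affine(m_fwd):
    """Invert a 2x3 affine matrix by promoting it to 3x3 homogeneous form."""
    m3 = np.vstack([m_fwd, [0.0, 0.0, 1.0]])
    return np.linalg.inv(m3)[:2, :]

def map_dst_to_src(m_inv, x, y):
    """Apply the 2x3 destination-to-source matrix to one destination pixel."""
    x_src, y_src = m_inv @ np.array([x, y, 1.0])
    return x_src, y_src

# Forward transform: scale the source by 2 and shift it by (10, 20).
m_fwd = np.array([[2.0, 0.0, 10.0],
                  [0.0, 2.0, 20.0]])
m_inv = invert_affine(m_fwd)

# Destination pixel (30, 60) is sampled from source position (10, 20).
x_src, y_src = map_dst_to_src(m_inv, 30.0, 60.0)
```

Here the inverse works out to a scale of 0.5 with translation (-5, -10), which is the kind of matrix the operator expects.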
Usage example¶
First, let’s import the necessary modules and define the location of the dataset.
The DALI_EXTRA_PATH environment variable should point to the location where data from the DALI extra repository is downloaded. Please make sure that the proper release tag is checked out.
[1]:
from __future__ import division
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import numpy as np
import matplotlib.pyplot as plt
import math
import os.path
test_data_root = os.environ['DALI_EXTRA_PATH']
db_folder = os.path.join(test_data_root, 'db', 'lmdb')
The functions below define affine transform matrices for a batch of images. Each image receives its own transform. The transform matrices are expected to be a tensor list of shape \(batch\_size \times 2 \times 3\).
[2]:
def random_transform(index):
    dst_cx, dst_cy = (200,200)
    src_cx, src_cy = (200,200)
    # This function uses homogeneous coordinates - hence, 3x3 matrix
    # translate output coordinates to center defined by (dst_cx, dst_cy)
    t1 = np.array([[1, 0, -dst_cx],
                   [0, 1, -dst_cy],
                   [0, 0, 1]])
    def u():
        return np.random.uniform(-0.5, 0.5)
    # apply a randomized affine transform - random per-axis scaling + some random distortion
    m = np.array([
        [1 + u(),     u(),  0],
        [    u(), 1 + u(),  0],
        [      0,       0,  1]])
    # translate input coordinates to center (src_cx, src_cy)
    t2 = np.array([[1, 0, src_cx],
                   [0, 1, src_cy],
                   [0, 0, 1]])
    # combine the transforms
    m = (np.matmul(t2, np.matmul(m, t1)))
    # remove the last row; it's not used by affine transform
    return m[0:2,0:3]
def gen_transforms(batch_size, single_transform_fn):
    out = np.zeros([batch_size, 2, 3])
    for i in range(batch_size):
        out[i,:,:] = single_transform_fn(i)
    return out.astype(np.float32)
np.random.seed(seed = 123)
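The same translate, transform, translate composition can be used for deterministic transforms. The sketch below builds a batch of rotations about the image center; `rotation_about_center` is a hypothetical helper written for this illustration, and the angles and centers are arbitrary choices.

```python
import numpy as np

def rotation_about_center(angle_deg, dst_center, src_center):
    """Build a 2x3 destination-to-source rotation that maps dst_center to src_center."""
    a = np.deg2rad(angle_deg)
    # Move the destination center to the origin.
    t1 = np.array([[1.0, 0.0, -dst_center[0]],
                   [0.0, 1.0, -dst_center[1]],
                   [0.0, 0.0, 1.0]])
    # Rotate around the origin.
    rot = np.array([[np.cos(a), -np.sin(a), 0.0],
                    [np.sin(a),  np.cos(a), 0.0],
                    [0.0,        0.0,       1.0]])
    # Move the origin to the source center.
    t2 = np.array([[1.0, 0.0, src_center[0]],
                   [0.0, 1.0, src_center[1]],
                   [0.0, 0.0, 1.0]])
    # Drop the last row of the homogeneous matrix.
    return (t2 @ rot @ t1)[:2, :]

# One matrix per sample: rotations by 0, 15, 30 and 45 degrees.
batch = np.stack([
    rotation_about_center(15.0 * i, (200, 200), (200, 200)) for i in range(4)
]).astype(np.float32)

# The destination center should map back onto the source center.
center_src = batch[1] @ np.array([200.0, 200.0, 1.0], dtype=np.float32)
```

The resulting array has the batch_size x 2 x 3 layout and float32 type used by `gen_transforms` above, and every matrix leaves the center point in place.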
Now, let’s define the pipeline. It applies the same transform to each image in three variants, with slightly different options.
The first variant executes on the GPU and uses a fixed output size and linear interpolation. It does not specify any fill value, in which case out-of-bounds source coordinates are clamped to the valid range.
The second executes on the CPU with nearest neighbor interpolation and uses the fill_value argument, which replaces out-of-bounds source pixels with that value.
The last one executes on the GPU and does not specify a new size, which keeps the original image size.
[3]:
class ExamplePipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, pipelined = True, exec_async = True):
        super(ExamplePipeline, self).__init__(
            batch_size, num_threads, device_id,
            seed = 12, exec_pipelined=pipelined, exec_async=exec_async)
        # The reader reads raw files from some storage - in this case, a Caffe LMDB container
        self.input = ops.CaffeReader(path = db_folder, random_shuffle = True)
        # The decoder takes tensors containing raw files and outputs images
        # as 3D tensors with HWC layout
        self.decode = ops.ImageDecoder(device = "cpu", output_type = types.RGB)
        # This example uses ExternalSource to provide warp matrices
        self.transform_source = ops.ExternalSource()
        self.iter = 0
        self.warp_gpu = ops.WarpAffine(
            device = "gpu",
            size = (400,400),                 # specify the output size
          # fill_value                        # not specifying `fill_value`
                                              #   results in source coordinate clamping
            interp_type = types.INTERP_LINEAR # use linear interpolation
        )
        self.warp_cpu = ops.WarpAffine(
            device = "cpu",
            fill_value = 200,
            size = (400,400),                 # specify the output size
            interp_type = types.INTERP_NN     # use nearest neighbor interpolation
        )
        self.warp_keep_size = ops.WarpAffine(
            device = "gpu",
          # size                              # keep original canvas size
            interp_type = types.INTERP_LINEAR # use linear interpolation
        )
    # Then, we can tie the operators together to form a graph
    def define_graph(self):
        self.transform = self.transform_source()
        self.jpegs, self.labels = self.input()
        images = self.decode(self.jpegs)
        outputs = [images.gpu()]
        # pass the transform parameters through GPU memory
        outputs += [self.warp_gpu(images.gpu(), self.transform.gpu())]
        # pass the transform through a named input
        outputs += [self.warp_cpu(images, matrix = self.transform).gpu()]
        outputs += [self.warp_keep_size(images.gpu(), self.transform.gpu())]
        return [self.labels, self.transform] + outputs
    # Since we're using ExternalSource, we need to feed the externally provided data to the pipeline
    def iter_setup(self):
        # Generate the transforms for the batch and feed them to the ExternalSource
        self.feed_input(self.transform, gen_transforms(self.batch_size, random_transform))
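The difference between the two border behaviours used in the pipeline (coordinate clamping versus a fill value) can be illustrated with a small NumPy sketch. It is a simplified nearest-neighbor model, not DALI's code; `sample_nn` is a hypothetical helper and its `fill_value` parameter only mimics the operator argument of the same name.

```python
import numpy as np

def sample_nn(img, xs, ys, fill_value=None):
    """Nearest-neighbor lookup with either border clamping or a constant fill."""
    h, w = img.shape[:2]
    xi = np.rint(xs).astype(int)
    yi = np.rint(ys).astype(int)
    # Remember which source coordinates fall inside the image.
    inside = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
    # Clamp indices so that the lookup itself is always valid.
    out = img[np.clip(yi, 0, h - 1), np.clip(xi, 0, w - 1)].copy()
    if fill_value is not None:
        # Overwrite everything sampled from outside the image.
        out[~inside] = fill_value
    return out

img = np.array([[10, 20],
                [30, 40]], dtype=np.uint8)
xs = np.array([-1.0, 0.0, 1.0, 2.0])  # first and last fall outside the image
ys = np.zeros(4)

clamped = sample_nn(img, xs, ys)                 # edge pixels are repeated
filled = sample_nn(img, xs, ys, fill_value=200)  # outside pixels become 200
```

With clamping the border pixels are smeared outwards, whereas with a fill value the area outside the source image gets a uniform color.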
The pipeline class is now ready to use - we need to construct and build it before we run it.
[4]:
batch_size = 32
pipe = ExamplePipeline(batch_size=batch_size, num_threads=2, device_id = 0)
pipe.build()
Finally, we can call run on our pipeline to obtain the first batch of preprocessed images.
[5]:
pipe_out = pipe.run()
Example output¶
Now that we’ve processed the first batch of images, let’s see the results:
[6]:
n = 0  # change this value to see other images from the batch;
       # it must be in 0..batch_size-1 range
from synsets import imagenet_synsets
import matplotlib.gridspec as gridspec
len_outputs = len(pipe_out) - 2
captions = ["original",
            "warp GPU (linear, border clamp)",
            "warp CPU (nearest, fill)", "warp GPU (keep canvas size)"]
fig = plt.figure(figsize = (16,12))
plt.suptitle(imagenet_synsets[pipe_out[0].at(n)[0]], fontsize=16)
columns = 2
rows = int(math.ceil(len_outputs / columns))
gs = gridspec.GridSpec(rows, columns)
print("Affine transform matrix:")
print(pipe_out[1].at(n))
for i in range(len_outputs):
    plt.subplot(gs[i])
    plt.axis("off")
    plt.title(captions[i])
    pipe_out_cpu = pipe_out[2 + i].as_cpu()
    # the decoder outputs images in HWC layout, which is what imshow expects
    img_hwc = pipe_out_cpu.at(n)
    plt.imshow((img_hwc)/255.0)
Affine transform matrix:
[[ 1.1964692  -0.21386066  3.4782958 ]
 [-0.27314854  1.0513147  44.366756  ]]
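As a quick sanity check, the printed matrix can be applied by hand to the destination center. Since `random_transform` uses (200, 200) for both the destination and the source center, the result should come out at (200, 200) up to rounding of the printed values. The variable names below are chosen for this sketch only.

```python
import numpy as np

# Matrix copied from the output printed above.
m = np.array([[ 1.1964692,  -0.21386066,  3.4782958],
              [-0.27314854,  1.0513147,  44.366756 ]])

# Destination center in homogeneous coordinates.
dst_center = np.array([200.0, 200.0, 1.0])
src_center = m @ dst_center

# The destination origin maps to the translation column of the matrix.
src_origin = m @ np.array([0.0, 0.0, 1.0])

print(src_center)
print(src_origin)
```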
