{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Data Loading: TensorFlow TFRecord\n", "\n", "## Overview\n", "\n", "This example shows you how to use the data that is stored in the TensorFlow TFRecord format with DALI.\n", "\n", "## Creating index\n", "\n", "To use data that is stored in the TFRecord format, we need to use the `readers.TFRecord` operator. In addition to the arguments that are common to all readers, such as `random_shuffle`, this operator takes `path`, `index_path` and `features` arguments.\n", "\n", "* `path` is a list of paths to the TFRecord files\n", "* `index_path` is a list that contains the paths to index files, which are used by DALI mainly to properly shard the dataset between multiple workers. The index for a TFRecord file can be obtained from that file by using the `tfrecord2idx` utility that is included with DALI. You need to create the index file only once per TFRecord file.\n", "* `features` is a dictionary of pairs (name, feature), where feature (of type `dali.tfrecord.Feature`) describes the contents of the TFRecord. DALI features closely follow the TensorFlow types `tf.FixedLenFeature` and `tf.VarLenFeature`.\n", "\n", "The `DALI_EXTRA_PATH` environment variable should point to the location where data from [DALI extra repository](https://github.com/NVIDIA/DALI_extra) is downloaded.\n", "\n", "**Important**: Ensure that you check out the correct release tag that corresponds to the installed version of DALI." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from subprocess import call\n", "import os.path\n", "\n", "test_data_root = os.environ['DALI_EXTRA_PATH']\n", "tfrecord = os.path.join(test_data_root, 'db', 'tfrecord', 'train')\n", "batch_size = 16\n", "tfrecord_idx = \"idx_files/train.idx\"\n", "tfrecord2idx_script = \"tfrecord2idx\"\n", "\n", "if not os.path.exists(\"idx_files\"):\n", " os.mkdir(\"idx_files\")\n", "\n", "if not os.path.isfile(tfrecord_idx):\n", " call([tfrecord2idx_script, tfrecord, tfrecord_idx])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining and Running the Pipeline\n", "\n", "1. Define a simple pipeline that takes the images stored in TFRecord format, decodes them, and prepares them for ingestion in DL framework.\n", "\n", " Processing images involves cropping, normalizing, and `HWC` -> `CHW` conversion process.\n", "\n", "The TFRecord file that we used in this example does not have images upscaled to a common size. This results in an error during cropping, when the image is smaller than the crop window. To overcome this issue, use the `Resize` operation before you crop. This step ensures that the shorter side of images being cropped is 256 pixels." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from nvidia.dali.pipeline import Pipeline\n", "import nvidia.dali.fn as fn\n", "import nvidia.dali.types as types\n", "import nvidia.dali.tfrecord as tfrec\n", "import numpy as np\n", "\n", "pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=0)\n", "with pipe:\n", " inputs = fn.readers.tfrecord(\n", " path=tfrecord, \n", " index_path=tfrecord_idx,\n", " features={\n", " \"image/encoded\" : tfrec.FixedLenFeature((), tfrec.string, \"\"),\n", " \"image/class/label\": tfrec.FixedLenFeature([1], tfrec.int64, -1),\n", " \"image/class/text\": tfrec.FixedLenFeature([ ], tfrec.string, \"\"),\n", " \"image/object/bbox/xmin\": tfrec.VarLenFeature(tfrec.float32, 0.0),\n", " \"image/object/bbox/ymin\": tfrec.VarLenFeature(tfrec.float32, 0.0),\n", " \"image/object/bbox/xmax\": tfrec.VarLenFeature(tfrec.float32, 0.0),\n", " \"image/object/bbox/ymax\": tfrec.VarLenFeature(tfrec.float32, 0.0)})\n", " jpegs = inputs[\"image/encoded\"]\n", " images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.RGB)\n", " resized = fn.resize(images, device=\"gpu\", resize_shorter=256.)\n", " output = fn.crop_mirror_normalize(\n", " resized,\n", " dtype=types.FLOAT,\n", " crop=(224, 224),\n", " mean=[0., 0., 0.],\n", " std=[1., 1., 1.])\n", " pipe.set_outputs(output, inputs[\"image/class/text\"])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "2. Build and run our the pipeline:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "pipe.build()\n", "pipe_out = pipe.run()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "3. To visualize the results, use the `matplotlib` library, which expects images in `HWC` format, but the output of the pipeline is in `CHW`.\n", "\n", " **Note**: `CHW` is the preferred format for most Deep Learning frameworks.\n", " \n", "4. For the visualization purposes, transpose the images back to the `HWC` layout." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import matplotlib.gridspec as gridspec\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "def show_images(image_batch, labels):\n", " columns = 4\n", " rows = (batch_size + 1) // (columns)\n", " fig = plt.figure(figsize = (32,(32 // columns) * rows))\n", " gs = gridspec.GridSpec(rows, columns)\n", " for j in range(rows*columns):\n", " plt.subplot(gs[j])\n", " plt.axis(\"off\")\n", " ascii = labels.at(j)\n", " plt.title(\"\".join([chr(item) for item in ascii]))\n", " img_chw = image_batch.at(j)\n", " img_hwc = np.transpose(img_chw, (1,2,0))/255.0\n", " plt.imshow(img_hwc)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\n\n\n\n \n \n \n \n 2020-11-20T21:29:08.168470\n image/svg+xml\n \n \n Matplotlib v3.3.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n