{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Manually Constructing a TensorRT Engine\n", "\n", "The Python API provides a path for Python-based frameworks, which might be unsupported by the UFF converter, if they use NumPy compatible layer weights. \n", "\n", "For this example, we will use PyTorch. \n", "\n", "First, we import TensorRT." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorrt as trt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use PyCUDA to transfer data to/from the GPU and NumPy to store data." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pycuda.driver as cuda\n", "import pycuda.autoinit\n", "import numpy as np\n", "from matplotlib.pyplot import imshow # to show test case" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we import PyTorch and its various packages." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "from torch.autograd import Variable" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a Model in PyTorch\n", "\n", "For more detailed information about training models in PyTorch, see http://pytorch.org/tutorials/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we define hyper-parameters, then create a dataloader, define our network, set our optimizer and define our training and testing steps." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 64\n", "TEST_BATCH_SIZE = 1000\n", "EPOCHS = 3\n", "LEARNING_RATE = 0.001\n", "SGD_MOMENTUM = 0.5 \n", "SEED = 1\n", "LOG_INTERVAL = 10" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Enable Cuda\n", "torch.cuda.manual_seed(SEED)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Dataloader\n", "kwargs = {'num_workers': 1, 'pin_memory': True}\n", "train_loader = torch.utils.data.DataLoader(\n", " datasets.MNIST('/tmp/mnist/data', train=True, download=True, \n", " transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", " ])),\n", " batch_size=BATCH_SIZE,\n", " shuffle=True,\n", " **kwargs)\n", "\n", "test_loader = torch.utils.data.DataLoader(\n", " datasets.MNIST('/tmp/mnist/data', train=False, \n", " transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", " ])),\n", " batch_size=TEST_BATCH_SIZE,\n", " shuffle=True,\n", " **kwargs)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Network\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 20, kernel_size=5)\n", " self.conv2 = nn.Conv2d(20, 50, kernel_size=5)\n", " self.conv2_drop = nn.Dropout2d()\n", " self.fc1 = nn.Linear(800, 500)\n", " self.fc2 = nn.Linear(500, 10)\n", "\n", " def forward(self, x):\n", " x = F.max_pool2d(self.conv1(x), kernel_size=2, stride=2)\n", " x = F.max_pool2d(self.conv2(x), kernel_size=2, stride=2)\n", " x = x.view(-1, 800)\n", " x = F.relu(self.fc1(x))\n", " x = self.fc2(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", "model = Net()\n", "model.cuda()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=SGD_MOMENTUM)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def train(epoch):\n", " model.train()\n", " for batch, (data, target) in enumerate(train_loader):\n", " data, target = data.cuda(), target.cuda()\n", " data, target = Variable(data), Variable(target)\n", " optimizer.zero_grad()\n", " output = model(data)\n", " loss = F.nll_loss(output, target)\n", " loss.backward()\n", " optimizer.step()\n", " if batch % LOG_INTERVAL == 0:\n", " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'\n", " .format(epoch, \n", " batch * len(data), \n", " len(train_loader.dataset), \n", " 100. * batch / len(train_loader), \n", " loss.data.item()))\n", "\n", "def test(epoch):\n", " model.eval()\n", " test_loss = 0\n", " correct = 0\n", " for data, target in test_loader:\n", " data, target = data.cuda(), target.cuda()\n", " with torch.no_grad():\n", " data, target = Variable(data), Variable(target)\n", " output = model(data)\n", " test_loss += F.nll_loss(output, target).data.item()\n", " pred = output.data.max(1)[1]\n", " correct += pred.eq(target.data).cpu().sum()\n", " test_loss /= len(test_loader)\n", " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'\n", " .format(test_loss, \n", " correct, \n", " len(test_loader.dataset), \n", " 100. * correct / len(test_loader.dataset)))\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we train the model." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "for e in range(EPOCHS):\n", " train(e + 1)\n", " test(e + 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Converting the Model into a TensorRT Engine\n", "Now that we have a \"trained\" model, we extract the layer wieghts by getting the `state_dict`" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "weights = model.state_dict()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create a builder and a logger for the build process." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)\n", "builder = trt.infer.create_infer_builder(G_LOGGER)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we replicate the network structure above in TensorRT and extract the weights from PyTorch in the form of numpy arrays. The numpy arrays from PyTorch reflect the dimensionality of the layers, so we flatten the arrays" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "network = builder.create_network()\n", "\n", "# Name for the input layer, data type, tuple for dimension \n", "data = network.add_input(\"data\", trt.infer.DataType.FLOAT, (1, 28, 28))\n", "assert(data)\n", "\n", "#-------------\n", "conv1_w = weights['conv1.weight'].cpu().numpy().reshape(-1)\n", "conv1_b = weights['conv1.bias'].cpu().numpy().reshape(-1)\n", "conv1 = network.add_convolution(data, 20, (5,5), conv1_w, conv1_b)\n", "assert(conv1)\n", "conv1.set_stride((1,1))\n", "\n", "#-------------\n", "pool1 = network.add_pooling(conv1.get_output(0), trt.infer.PoolingType.MAX, (2,2))\n", "assert(pool1)\n", "pool1.set_stride((2,2))\n", "\n", "#-------------\n", "conv2_w = weights['conv2.weight'].cpu().numpy().reshape(-1)\n", "conv2_b = weights['conv2.bias'].cpu().numpy().reshape(-1)\n", "conv2 = network.add_convolution(pool1.get_output(0), 50, (5,5), conv2_w, conv2_b)\n", "assert(conv2)\n", "conv2.set_stride((1,1))\n", "\n", "#-------------\n", "pool2 = network.add_pooling(conv2.get_output(0), trt.infer.PoolingType.MAX, (2,2))\n", "assert(pool2)\n", "pool2.set_stride((2,2))\n", "\n", "#-------------\n", "fc1_w = weights['fc1.weight'].cpu().numpy().reshape(-1)\n", "fc1_b = weights['fc1.bias'].cpu().numpy().reshape(-1)\n", "fc1 = network.add_fully_connected(pool2.get_output(0), 500, fc1_w, fc1_b)\n", "assert(fc1)\n", "\n", "#-------------\n", "relu1 = network.add_activation(fc1.get_output(0), trt.infer.ActivationType.RELU)\n", "assert(relu1)\n", "\n", "#-------------\n", "fc2_w = weights['fc2.weight'].cpu().numpy().reshape(-1)\n", "fc2_b = weights['fc2.bias'].cpu().numpy().reshape(-1)\n", "fc2 = network.add_fully_connected(relu1.get_output(0), 10, fc2_w, fc2_b)\n", "assert(fc2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we need to mark our output layer." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "fc2.get_output(0).set_name(\"prob\")\n", "network.mark_output(fc2.get_output(0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We set the rest of the parameters for the network (max batch size and max workspace) and build the engine." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "builder.set_max_batch_size(1)\n", "builder.set_max_workspace_size(1 << 20)\n", "\n", "engine = builder.build_cuda_engine(network)\n", "network.destroy()\n", "builder.destroy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we create the engine runtime and generate a test case from the torch dataloader." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "runtime = trt.infer.create_infer_runtime(G_LOGGER)\n", "img, target = next(iter(test_loader))\n", "img = img.numpy()[0]\n", "target = target.numpy()[0]\n", "%matplotlib inline\n", "img.shape\n", "imshow(img[0])\n", "print(\"Test Case: \" + str(target))\n", "img = img.ravel()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we create an execution context for the engine." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "context = engine.create_execution_context()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we allocate memory on the GPU, as well as on the host to hold results after inference. The size of these allocations is the size of the input/expected output * the batch size." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "output = np.empty(10, dtype = np.float32)\n", "\n", "# Allocate device memory\n", "d_input = cuda.mem_alloc(1 * img.nbytes)\n", "d_output = cuda.mem_alloc(1 * output.nbytes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The engine requires bindings (pointers to GPU memory). PyCUDA lets us do this by casting the results of memory allocations to ints." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "bindings = [int(d_input), int(d_output)] " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a cuda stream to run inference." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "stream = cuda.Stream()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we transfer the data to the GPU, run inference, then transfer the results to the host." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Transfer input data to device\n", "cuda.memcpy_htod_async(d_input, img, stream)\n", "#execute model \n", "context.enqueue(1, bindings, stream.handle, None)\n", "# Transfer predictions back\n", "cuda.memcpy_dtoh_async(output, d_output, stream)\n", "# Synchronize threads\n", "stream.synchronize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use `np.argmax` to get a prediction." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "print(\"Test Case: \" + str(target))\n", "print (\"Prediction: \" + str(np.argmax(output)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also save our engine to a file to use later." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "trt.utils.write_engine_to_file(\"./pyt_mnist.engine\", engine.serialize()) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can load this engine later by using `tensorrt.utils.load_engine`." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "new_engine = trt.utils.load_engine(G_LOGGER, \"./pyt_mnist.engine\") " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we clean up our context, engine and runtime." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "context.destroy()\n", "engine.destroy()\n", "new_engine.destroy()\n", "runtime.destroy()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 2 }