Working with Conditionals#

NVIDIA TensorRT supports conditional if-then-else flow control. TensorRT conditionals are used to implement conditional execution of network subgraphs.

Defining A Conditional#

Conditional boundary layers define an if-conditional:

  • IConditionLayer represents the predicate and specifies whether the conditional should execute the true-branch (then-branch) or the false-branch (else-branch).

  • IIfConditionalInputLayer specifies an input to one of the two conditional branches.

  • IIfConditionalOutputLayer specifies an output from a conditional.

Each boundary layer inherits from class IIfConditionalBoundaryLayer, which has a method getConditional() for getting its associated IIfConditional. The IIfConditional instance identifies the conditional. All conditional boundary layers with the same IIfConditional belong to that conditional.

A conditional must have exactly one instance of IConditionLayer, zero or more instances of IIfConditionalInputLayer, and at least one instance of IIfConditionalOutputLayer.
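The counting rule above can be sketched as a small check. This is a plain-Python model and not the TensorRT API: each boundary layer is reduced to a string naming its kind, and the function name `check_boundary_layers` is hypothetical.

```python
from collections import Counter


def check_boundary_layers(kinds):
    """Return a list of rule violations for one conditional's boundary layers.

    `kinds` is a list of strings, each "condition", "input", or "output"
    (an assumed encoding used only for this illustration).
    """
    counts = Counter(kinds)
    problems = []
    if counts["condition"] != 1:
        problems.append(
            "expected exactly one IConditionLayer, found %d" % counts["condition"]
        )
    if counts["output"] < 1:
        problems.append("expected at least one IIfConditionalOutputLayer")
    # Any number of "input" entries, including zero, is acceptable.
    return problems


print(check_boundary_layers(["condition", "input", "input", "output"]))  # []
print(check_boundary_layers(["input", "output"]))
```

The second call reports the missing condition layer; a conditional with no input layers at all would still pass the check.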

IIfConditional implements an if-then-else flow-control construct that provides conditional execution of a network subgraph based on a dynamic boolean input. It is defined by a boolean scalar predicate condition and two branch subgraphs: a trueSubgraph, which is executed when the condition evaluates to true, and a falseSubgraph, which is executed when the condition evaluates to false:

If condition is true then:
    output = trueSubgraph(trueInputs);
Else
    output = falseSubgraph(falseInputs);
Emit output

Both the true branch and the false branch must be defined, as with the ternary operator in many programming languages.
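The pseudocode above can be modeled in a few lines of plain Python. This is only a sketch of the semantics, not TensorRT code: the function `if_conditional` and the use of Python callables as stand-ins for the branch subgraphs are assumptions made for illustration.

```python
def if_conditional(condition, true_subgraph, false_subgraph,
                   true_inputs, false_inputs):
    """Model of the if-then-else construct: run one branch, emit its output."""
    if condition:
        output = true_subgraph(*true_inputs)
    else:
        output = false_subgraph(*false_inputs)
    return output


# The two branches may take different inputs, but both must be defined.
then_result = if_conditional(True, lambda a, b: a + b, lambda a: -a, (2, 3), (7,))
else_result = if_conditional(False, lambda a, b: a + b, lambda a: -a, (2, 3), (7,))
print(then_result)  # 5
print(else_result)  # -7
```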

To define an if-conditional, create an IIfConditional instance with INetworkDefinition::addIfConditional, then add the boundary and branch layers.

IIfConditional* simpleIf = network->addIfConditional();

The IIfConditional::setCondition method takes a single argument: the condition tensor. This 0D boolean tensor (scalar) can be computed dynamically by earlier layers in the network. It is used to decide which of the branches to execute. An IConditionLayer has a single input (the condition) and no outputs since it is used internally by the conditional implementation.

// Create a condition predicate that is also a network input.
auto cond = network->addInput("cond", DataType::kBOOL, Dims{0});
IConditionLayer* condition = simpleIf->setCondition(*cond);

TensorRT does not support a subgraph abstraction for implementing conditional branches and instead uses IIfConditionalInputLayer and IIfConditionalOutputLayer to define the boundaries of conditionals.

  • An IIfConditionalInputLayer abstracts a single input to one or both of the branch subgraphs of an IIfConditional. The output of a specific IIfConditionalInputLayer can feed both branches.

    // Create an if-conditional input.
    // x is some arbitrary Network tensor.
    IIfConditionalInputLayer* inputX = simpleIf->addInput(*x);
    

    Inputs to the then-branch and the else-branch do not have to be the same type and shape. Each branch can independently include zero or more inputs.

    IIfConditionalInputLayer is optional and is used to control which layers will be part of the branches (refer to Conditional Execution). If none of a branch’s outputs depend on an IIfConditionalInputLayer instance, that branch is empty. An empty else-branch can be useful when there are no layers to evaluate when the condition is false, and the network evaluation should proceed following the conditional (refer to Conditional Examples).

  • An IIfConditionalOutputLayer abstracts a single output of the if-conditional. It has two inputs: an output from the trueSubgraph (input index 0) and an output from the falseSubgraph (input index 1). The output of an IIfConditionalOutputLayer can be considered a placeholder for the final output that will be determined during runtime.

    IIfConditionalOutputLayer serves a role similar to that of a Φ (Phi) function node in traditional SSA control-flow graphs. Its semantics are: choose either the output of the trueSubgraph or falseSubgraph.

    // trueSubgraph and falseSubgraph represent network subgraphs
    IIfConditionalOutputLayer* outputLayer = simpleIf->addOutput(
        *trueSubgraph->getOutput(0),
        *falseSubgraph->getOutput(0));
    

    All outputs of an IIfConditional must be sourced at an IIfConditionalOutputLayer instance.

    An if-conditional without outputs does not affect the rest of the network. Therefore, it is considered ill-formed. Each branch (subgraph) must also have at least one output. The output of an if-conditional can be marked as the output of the network unless that if-conditional is nested inside another if-conditional or loop.

An if-conditional construct abstract model

Conditional Execution#

Conditional execution of network layers is a network evaluation strategy in which branch layers (the layers belonging to a conditional subgraph) are executed only if the values of the branch outputs are needed. In conditional execution, either the true branch or the false branch is executed and allowed to change the network state.

In contrast, in predicated execution, both the true branch and the false branch are executed, and only one of these is allowed to change the network evaluation state, depending on the value of the condition predicate (that is, only the outputs of one of the subgraphs are fed into the following layers).

Conditional execution is sometimes called lazy evaluation, and predicated execution is sometimes called eager evaluation.
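The difference between the two strategies can be made visible with a short plain-Python sketch. It does not use TensorRT; the branch functions and the `calls` log are hypothetical names introduced only to record which branches actually run.

```python
calls = []  # records which branch bodies were executed


def true_branch(x):
    calls.append("true")
    return x + 1


def false_branch(x):
    calls.append("false")
    return x - 1


def conditional_execution(cond, x):
    # Lazy: only the selected branch runs.
    return true_branch(x) if cond else false_branch(x)


def predicated_execution(cond, x):
    # Eager: both branches run, then one result is selected.
    t = true_branch(x)
    f = false_branch(x)
    return t if cond else f


calls.clear()
lazy_result = conditional_execution(True, 10)
lazy_calls = list(calls)

calls.clear()
eager_result = predicated_execution(True, 10)
eager_calls = list(calls)

print(lazy_result, lazy_calls)    # 11 ['true']
print(eager_result, eager_calls)  # 11 ['true', 'false']
```

Both strategies produce the same value; they differ only in how much work is done to obtain it.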

Instances of IIfConditionalInputLayer can be used to specify which layers are invoked eagerly and which are invoked lazily. This is done by tracing the network layers backward, starting with each conditional output. Layers that are data-dependent on the output of at least one IIfConditionalInputLayer are considered internal to the conditional and are therefore evaluated lazily. In the extreme case that no instances of IIfConditionalInputLayer are added to the conditional, all layers are executed eagerly, similarly to ISelectLayer.
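The backward trace described above can be sketched as a small graph walk. This is a plain-Python illustration under stated assumptions, not how TensorRT is implemented: the network is modeled as a dictionary mapping each node to the nodes that produce its inputs, the node names form a hypothetical three-layer chain, and `internal_layers` is an invented helper.

```python
def internal_layers(branch_outputs, producers, input_layers):
    """Return the layers that are data-dependent on a conditional input layer.

    branch_outputs: nodes feeding the conditional's output layers.
    producers: dict mapping a node to the list of nodes it consumes.
    input_layers: set of nodes standing in for IIfConditionalInputLayer.
    The graph is assumed to be acyclic.
    """
    memo = {}

    def depends_on_input(node):
        if node in input_layers:
            return True
        if node not in memo:
            memo[node] = any(depends_on_input(p) for p in producers.get(node, ()))
        return memo[node]

    internal, stack = set(), list(branch_outputs)
    while stack:
        node = stack.pop()
        if node in internal or node in input_layers:
            continue  # input layers mark the boundary; stop tracing there
        if depends_on_input(node):
            internal.add(node)
            stack.extend(producers.get(node, ()))
    return internal


# Input layer I1 placed before T1: the whole chain is evaluated lazily.
chain = {"I1": ["x"], "T1": ["I1"], "T2": ["T1"], "T3": ["T2"]}
print(sorted(internal_layers(["T3"], chain, {"I1"})))  # ['T1', 'T2', 'T3']

# Input layer I1 placed after T1: T1 is now outside the conditional.
moved = {"T1": ["x"], "I1": ["T1"], "T2": ["I1"], "T3": ["T2"]}
print(sorted(internal_layers(["T3"], moved, {"I1"})))  # ['T2', 'T3']

# No input layers: nothing is internal, so every layer runs eagerly.
bare = {"T1": ["x"], "T2": ["T1"], "T3": ["T2"]}
print(sorted(internal_layers(["T3"], bare, set())))  # []
```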

The three diagrams below depict how the choice of IIfConditionalInputLayer placement controls execution scheduling.

Controlling conditional-execution using IIfConditionalInputLayer placement

In diagram A, the true branch comprises three layers (T1, T2, T3). These layers execute lazily when the condition evaluates to true.

In diagram B, input-layer I1 is placed after layer T1, which moves T1 out of the true branch. Layer T1 executes eagerly before evaluating the if-construct.

In diagram C, input-layer I1 is removed, which moves T3 outside the conditional. T2’s input is reconfigured to create a legal network, and T2 also moves out of the true branch. When the condition evaluates to true, the conditional does not compute anything since the outputs have already been eagerly computed (but it does copy the relevant conditional inputs to its outputs).

Nesting and Loops#

Conditional branches may nest other conditionals and may also nest loops. Loops may nest conditionals. As in loop nesting, TensorRT infers the nesting of the conditionals and loops from the data flow. For example, if conditional B uses a value defined inside loop A, then B is considered to be nested inside of A.

There can be no cross-edges connecting layers of the true branch to layers of the false branch, or vice versa. In other words, the outputs of one branch cannot depend on layers in the other branch.
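A minimal sketch of a cross-edge check follows, written in plain Python rather than against the TensorRT API. The dictionary encoding of the graph, the layer names, and the helper `find_cross_edges` are all assumptions made for illustration.

```python
def find_cross_edges(producers, true_layers, false_layers):
    """Return (source, consumer) pairs that connect one branch to the other.

    producers: dict mapping a consumer node to the list of nodes it reads from.
    true_layers / false_layers: sets of nodes belonging to each branch.
    """
    edges = []
    for consumer, sources in producers.items():
        for source in sources:
            true_to_false = source in true_layers and consumer in false_layers
            false_to_true = source in false_layers and consumer in true_layers
            if true_to_false or false_to_true:
                edges.append((source, consumer))
    return sorted(edges)


# F2 (false branch) reads from T1 (true branch): an illegal cross-edge.
bad = {"T1": ["I1"], "F1": ["I1"], "F2": ["F1", "T1"]}
print(find_cross_edges(bad, {"T1"}, {"F1", "F2"}))  # [('T1', 'F2')]

# Both branches read only from the shared input layer: no cross-edges.
good = {"T1": ["I1"], "F1": ["I1"], "F2": ["F1"]}
print(find_cross_edges(good, {"T1"}, {"F1", "F2"}))  # []
```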

Refer to the Conditional Examples section for an example of how nesting can be specified.

Limitations#

The true and false subgraph branches must have the same number of output tensors. Each output tensor of the true branch must have the same type and shape as the corresponding output tensor of the false branch.

Note that this is more constrained than the ONNX specification, which requires that the true/false subgraphs have the same number of outputs and use the same output types but allows for different output shapes.
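The constraint can be expressed as a pairwise comparison of branch output signatures. The sketch below is plain Python and does not call TensorRT; describing each output as a `(dtype, shape)` tuple and the helper name `check_branch_outputs` are assumptions made for illustration.

```python
def check_branch_outputs(true_outputs, false_outputs):
    """Return a list of mismatches between the two branches' outputs.

    Each output is described as a (dtype, shape) pair, for example
    ("float32", (5,)).
    """
    if len(true_outputs) != len(false_outputs):
        return ["output count differs: %d vs %d"
                % (len(true_outputs), len(false_outputs))]
    problems = []
    pairs = zip(true_outputs, false_outputs)
    for i, ((t_type, t_shape), (f_type, f_shape)) in enumerate(pairs):
        if t_type != f_type:
            problems.append("output %d: type %s vs %s" % (i, t_type, f_type))
        if tuple(t_shape) != tuple(f_shape):
            # Per the note above, ONNX alone would permit this case.
            problems.append("output %d: shape %s vs %s"
                            % (i, tuple(t_shape), tuple(f_shape)))
    return problems


print(check_branch_outputs([("float32", (5,))], [("float32", (5,))]))  # []
print(check_branch_outputs([("float32", (5,))], [("float32", (1, 5))]))
```

The second call reports a shape mismatch even though the element types agree.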

Conditional Examples#

Simple If-Conditional#

The following example implements a simple conditional that adds two tensors when the condition is true and subtracts them otherwise.

condition = true
If condition is true:
    output = x + y
Else:
    output = x - y

#include <cstdint>
#include "NvInfer.h"

using namespace nvinfer1;

ITensor* addCondition(INetworkDefinition& n, bool predicate)
{
    // The condition value is a constant int32 input that is cast to boolean
    // because TensorRT doesn't support boolean constant layers.
    static const Dims scalarDims = Dims{0, {}};
    // Static storage keeps the weight values alive after this function returns.
    static constexpr int32_t zero{0};
    static constexpr int32_t one{1};

    int32_t const* val = predicate ? &one : &zero;

    ITensor* cond =
        n.addConstant(scalarDims, {DataType::kINT32, val, 1})->getOutput(0);

    // An identity layer with a boolean output type performs the cast.
    auto* cast = n.addIdentity(*cond);
    cast->setOutputType(0, DataType::kBOOL);
    cast->getOutput(0)->setType(DataType::kBOOL);

    return cast->getOutput(0);
}

// The statements below belong inside a function body.
// gLogger is an ILogger instance defined elsewhere.
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition& n = *builder->createNetworkV2(0U);
auto x = n.addInput("x", DataType::kFLOAT, Dims{1, {5}});
auto y = n.addInput("y", DataType::kFLOAT, Dims{1, {5}});
ITensor* cond = addCondition(n, true);

auto* simpleIf = n.addIfConditional();
simpleIf->setCondition(*cond);

// Add input layers to demarcate entry into true/false branches.
x = simpleIf->addInput(*x)->getOutput(0);
y = simpleIf->addInput(*y)->getOutput(0);

auto* trueSubgraph = n.addElementWise(*x, *y, ElementWiseOperation::kSUM)->getOutput(0);
auto* falseSubgraph = n.addElementWise(*x, *y, ElementWiseOperation::kSUB)->getOutput(0);

auto* output = simpleIf->addOutput(*trueSubgraph, *falseSubgraph)->getOutput(0);
n.markOutput(*output);
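
When checking the output of a network like this one, it can help to have a host-side reference for the expected values. The following plain-Python sketch computes what the conditional should produce for two length-5 inputs; the function name `simple_if` and the sample values are illustrative assumptions and are not part of the TensorRT example.

```python
def simple_if(condition, x, y):
    """Host-side reference: x + y when the condition holds, else x - y."""
    if condition:
        return [a + b for a, b in zip(x, y)]
    return [a - b for a, b in zip(x, y)]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [0.5, 0.5, 0.5, 0.5, 0.5]

print(simple_if(True, x, y))   # [1.5, 2.5, 3.5, 4.5, 5.5]
print(simple_if(False, x, y))  # [0.5, 1.5, 2.5, 3.5, 4.5]
```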

Exporting from PyTorch#

The following example shows how to export scripted PyTorch code to ONNX. The function sum_even contains an if-conditional nested in a loop.

import torch.onnx
import torch
import tensorrt as trt
import numpy as np

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

@torch.jit.script
def sum_even(items):
    s = torch.zeros(1, dtype=torch.float)
    for c in items:
        if c % 2 == 0:
            s += c
    return s

class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, items):
        return sum_even(items)

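def build_engine(model_file):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    with open(model_file, 'rb') as model:
        if not parser.parse(model.read()):
            # Report parser errors instead of failing on a bare assert.
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("Failed to parse " + model_file)

    # build_engine is not available in recent TensorRT releases;
    # build_serialized_network returns the serialized engine instead.
    return builder.build_serialized_network(network, config)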

def export_to_onnx():
    items = torch.zeros(4, dtype=torch.float)
    example = ExampleModel()
    # The enable_onnx_checker argument is omitted because recent PyTorch releases no longer accept it.
    torch.onnx.export(example, (items,), "example.onnx", verbose=False, opset_version=13, do_constant_folding=True)

export_to_onnx()
build_engine("example.onnx")
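
To validate the exported model, the engine output can be compared against a reference that needs neither PyTorch nor TensorRT. The sketch below mirrors the loop-with-nested-conditional structure of sum_even in plain Python; the name `sum_even_reference` and the sample input are assumptions introduced for illustration.

```python
def sum_even_reference(items):
    """Sum the even-valued entries, mirroring the scripted sum_even function."""
    s = 0.0
    for c in items:          # corresponds to the exported loop
        if c % 2 == 0:       # corresponds to the nested if-conditional
            s += c
    return s


print(sum_even_reference([1.0, 2.0, 3.0, 4.0]))  # 6.0
print(sum_even_reference([0.0, 0.0, 0.0, 0.0]))  # 0.0
```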