If
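Generates conditional execution of network subgraphs: depending on a boolean condition, the outputs of either the true subgraph or the false subgraph are forwarded. The true and false subgraphs are not passed to the operator explicitly; instead, each is defined by the set of input and output tensors attached to the conditional.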
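As a conceptual illustration only, the selection behavior can be modeled in plain Python. This is not the TensorRT API: `if_conditional` and the `Branch` alias are hypothetical names, each branch is an ordinary Python callable standing in for a subgraph, and in this model only the branch chosen by the condition is called.

```python
from typing import Callable, List, Sequence

# Hypothetical alias: a "subgraph" is modeled as a function from inputs to outputs.
Branch = Callable[[Sequence[float]], List[float]]


def if_conditional(condition: bool,
                   true_branch: Branch, true_inputs: Sequence[float],
                   false_branch: Branch, false_inputs: Sequence[float]) -> List[float]:
    """Return the outputs of the branch selected by a boolean scalar condition."""
    if not isinstance(condition, bool):
        raise TypeError("condition must be a boolean scalar")
    # In this model only the selected branch is called; the other is ignored.
    if condition:
        return list(true_branch(true_inputs))
    return list(false_branch(false_inputs))


def identity(xs: Sequence[float]) -> List[float]:
    return list(xs)


# Mirrors the first example in Examples: condition True selects the true input.
print(if_conditional(True, identity, [5.0], identity, [0.0]))  # [5.0]
```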
Inputs
condition: a tensor of type T1
inputs: tensors of type T2
Outputs
outputs: tensors of type T2
Data Types
T1: bool
T2: bool, int32, float16, float32, bfloat16
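The type constraints above can be expressed as a small lookup. This sketch is a hypothetical validator, not part of TensorRT: types are plain strings rather than `trt.DataType` members, and `check_if_types` is an illustrative name.

```python
from typing import Sequence

# Allowed types, copied from the Data Types list; names are plain strings here.
T1_TYPES = frozenset({"bool"})
T2_TYPES = frozenset({"bool", "int32", "float16", "float32", "bfloat16"})


def check_if_types(condition_type: str, tensor_types: Sequence[str]) -> None:
    """Raise ValueError if the condition or any input/output type is not allowed."""
    if condition_type not in T1_TYPES:
        raise ValueError(f"condition must be of type T1 (bool), got {condition_type!r}")
    for index, tensor_type in enumerate(tensor_types):
        if tensor_type not in T2_TYPES:
            raise ValueError(f"tensor {index} has type {tensor_type!r}, which is not in T2")


check_if_types("bool", ["float32", "float32"])  # passes silently
```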
Shape Information
condition: a scalar (zero-dimensional tensor).
inputs: the number of input tensors and their shapes may differ between the two subgraphs.
outputs: both subgraphs must produce the same number of output tensors. For each pair of corresponding outputs, the shapes must be equal unless the condition is a build-time constant.
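A minimal sketch of the shape rules as a validation function, assuming shapes are given as Python tuples. `check_if_shapes` and its `condition_is_build_time_constant` flag are hypothetical; TensorRT performs its own checks when the network is built. Input shapes are deliberately not checked, since they may differ between subgraphs.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def check_if_shapes(condition_shape: Shape,
                    true_output_shapes: Sequence[Shape],
                    false_output_shapes: Sequence[Shape],
                    condition_is_build_time_constant: bool = False) -> None:
    """Raise ValueError if the conditional's shapes break the rules above."""
    # The condition must be a zero-dimensional tensor.
    if tuple(condition_shape) != ():
        raise ValueError("condition must be a scalar (zero-dimensional) tensor")
    # The output count must always match, even for a constant condition.
    if len(true_output_shapes) != len(false_output_shapes):
        raise ValueError("both subgraphs must produce the same number of outputs")
    # Shape equality is waived when the condition is a build-time constant.
    if condition_is_build_time_constant:
        return
    for i, (t, f) in enumerate(zip(true_output_shapes, false_output_shapes)):
        if tuple(t) != tuple(f):
            raise ValueError(f"output {i}: true shape {tuple(t)} != false shape {tuple(f)}")


check_if_shapes((), [(2, 2)], [(2, 2)])  # valid: passes silently
```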
Examples
If
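import numpy as np
import tensorrt as trt
# Assumes `network` (a trt.INetworkDefinition) and the dicts `inputs`, `outputs`,
# and `expected` are supplied by the surrounding test harness.
condition = network.add_input(name="condition", shape=(), dtype=trt.bool)
true_inp = network.add_input(name="true_input", shape=(1, 1), dtype=trt.float32)
false_inp = network.add_input(name="false_input", shape=(1, 1), dtype=trt.float32)
# Create the conditional and attach the scalar boolean condition.
conditional = network.add_if_conditional()
conditional.set_condition(condition)
# Route each network input into the conditional; both branches are identity subgraphs.
true_cond_inp = conditional.add_input(true_inp)
false_cond_inp = conditional.add_input(false_inp)
# The output takes the true-branch tensor when condition is True, else the false-branch tensor.
output = conditional.add_output(true_cond_inp.get_output(0), false_cond_inp.get_output(0))
network.mark_output(output.get_output(0))
# Input values must match the declared (1, 1) shapes.
inputs[condition.name] = np.array(True)
inputs[true_inp.name] = np.array([[5.0]])
inputs[false_inp.name] = np.array([[0.0]])
outputs[output.get_output(0).name] = output.get_output(0).shape
expected[output.get_output(0).name] = np.array([[5.0]])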
If with ElementWise Subgraphs
import numpy as np
import tensorrt as trt
# Assumes `network`, `inputs`, `outputs`, and `expected` come from the test harness.
condition = network.add_input(name="condition", shape=(), dtype=trt.bool)
in1 = network.add_input(name="input1", shape=(2, 2), dtype=trt.float32)
in2 = network.add_input(name="input2", shape=(1, 2), dtype=trt.float32)
conditional = network.add_if_conditional()
conditional.set_condition(condition)
cond_inp1 = conditional.add_input(in1)
cond_inp2 = conditional.add_input(in2)
# True branch: elementwise product; in2 (1, 2) broadcasts against in1 (2, 2).
true_elemwise = network.add_elementwise(cond_inp1.get_output(0), cond_inp2.get_output(0), op=trt.ElementWiseOperation.PROD)
# False branch: elementwise sum of the same two inputs.
false_elemwise = network.add_elementwise(cond_inp1.get_output(0), cond_inp2.get_output(0), op=trt.ElementWiseOperation.SUM)
output = conditional.add_output(true_elemwise.get_output(0), false_elemwise.get_output(0))
network.mark_output(output.get_output(0))
# With condition = False, the false branch is selected, so the output is in1 + in2.
inputs[condition.name] = np.array(False)
inputs[in1.name] = np.array(
[
[5.0, 7.8],
[-3.2, 4.6],
]
)
inputs[in2.name] = np.array(
[
[1.0, -1.0],
]
)
outputs[output.get_output(0).name] = output.get_output(0).shape
expected[output.get_output(0).name] = np.array([[6.0, 6.8], [-2.2, 3.6]])
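To see where the expected values come from, both branches can be recomputed in plain Python. This sketch is a simplified stand-in for TensorRT's elementwise layer, under stated assumptions: it only handles 2-D nested lists and broadcasting of a size-1 leading dimension, which is enough for the `(2, 2)` and `(1, 2)` inputs above. `broadcast_rows` and `reference_if_elementwise` are hypothetical helpers.

```python
import operator
from typing import Callable, List

Matrix = List[List[float]]


def broadcast_rows(a: Matrix, b: Matrix, op: Callable[[float, float], float]) -> Matrix:
    """Apply `op` elementwise, broadcasting an operand that has a single row."""
    rows = max(len(a), len(b))
    # A leading dimension of size 1 is repeated to match the other operand.
    ra = a * rows if len(a) == 1 else a
    rb = b * rows if len(b) == 1 else b
    if len(ra) != rows or len(rb) != rows:
        raise ValueError("row counts are not broadcast-compatible")
    result = []
    for row_a, row_b in zip(ra, rb):
        if len(row_a) != len(row_b):
            raise ValueError("column counts must match in this sketch")
        result.append([op(x, y) for x, y in zip(row_a, row_b)])
    return result


def reference_if_elementwise(condition: bool, in1: Matrix, in2: Matrix) -> Matrix:
    """True branch: PROD; false branch: SUM, mirroring the example above."""
    op = operator.mul if condition else operator.add
    return broadcast_rows(in1, in2, op)


in1 = [[5.0, 7.8], [-3.2, 4.6]]
in2 = [[1.0, -1.0]]
print(reference_if_elementwise(False, in1, in2))
```

With the example values and condition False, the result matches `[[6.0, 6.8], [-2.2, 3.6]]` up to floating-point rounding.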
C++ API
For more information about the C++ IConditionalLayer operator, refer to the C++ IConditionalLayer documentation.
Python API
For more information about the Python IConditionalLayer operator, refer to the Python IConditionalLayer documentation.