OneHot

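Computes a one-hot encoding from the given indices, depth, and axis. The output holds onValue at every position selected by the indices input and offValue at all other locations. In the example below, a negative index selects position index + depth, and an index outside [-depth, depth-1] selects no position, so its slice is filled entirely with offValue.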
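To make these semantics concrete, here is a minimal NumPy sketch of a reference OneHot. It is not TensorRT's implementation, the function name `one_hot_reference` is hypothetical, and its handling of negative and out-of-range indices is an assumption inferred from the example at the end of this section rather than from a stated specification.

```python
import numpy as np


def one_hot_reference(indices, values, depth, axis=-1):
    """Hypothetical NumPy reference for OneHot (not TensorRT code).

    Assumption (inferred from the documented example): negative indices wrap
    around by depth, and indices outside [-depth, depth - 1] select nothing,
    so their slice holds only offValue.
    """
    indices = np.asarray(indices, dtype=np.int32)
    values = np.asarray(values)
    off_value, on_value = values[0], values[1]  # values = [offValue, onValue]

    # Wrap negative indices; out-of-range indices never equal any position.
    wrapped = np.where(indices < 0, indices + depth, indices)

    # Compare each index with positions 0..depth-1 along a new last dimension.
    hot = wrapped[..., np.newaxis] == np.arange(depth)

    # Move the new dimension to `axis`. np.moveaxis accepts negative
    # destinations and rejects anything outside [-(rank + 1), rank],
    # which matches the operator's valid axis range.
    hot = np.moveaxis(hot, -1, axis)

    return np.where(hot, on_value, off_value).astype(values.dtype)


idx = np.array([0, 1, 2, 3, 4, -1, -2, -3, -4, -5], dtype=np.int32)
vals = np.array([0.0, 3.14], dtype=np.float32)
print(one_hot_reference(idx, vals, depth=4, axis=-1))
```

Building the encoding along the last dimension and then moving it keeps the index comparison independent of axis; only the final transpose depends on where the one-hot dimension goes.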

Attributes

axis: a scalar specifying the output dimension at which the one-hot dimension is inserted. Valid range: [-rank(indices)-1, rank(indices)]; a negative value counts from the end of the output, which has rank rank(indices)+1.

Inputs

indices: tensor of type int32 containing the hot indices

values: tensor of type T containing offValue and onValue

depth: tensor of type int32 specifying the depth of the encoding, that is, the size of the added one-hot dimension

Outputs

output: tensor of type T

Data Types

T: int8, int32, int64, float16, float32, bfloat16

Shape Information

indices is a tensor of shape \([A_0,...,A_n]\)

values is a two-element (rank=1) tensor that consists of [offValue, onValue]

depth is a rank-0 shape tensor (a scalar) whose value is \(d\)

output is a tensor of rank rank(indices)+1 with shape \([A_0,...,A_{axis-1},d,A_{axis},...,A_n]\), where a negative axis is first converted to axis + rank(indices) + 1
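The shape rule can be checked with a small Python helper. `one_hot_output_shape` is a hypothetical name, not part of the TensorRT API; it normalizes a negative axis by adding rank(indices)+1 and then inserts d into the indices shape, shifting \(A_{axis}\) and later dimensions right by one.

```python
def one_hot_output_shape(indices_shape, depth, axis):
    """Hypothetical helper: OneHot output shape = indices shape with depth inserted at axis."""
    rank = len(indices_shape)
    if not -rank - 1 <= axis <= rank:
        raise ValueError(f"axis {axis} is outside [{-rank - 1}, {rank}]")
    if axis < 0:
        axis += rank + 1  # negative axis counts from the end of the rank + 1 output
    shape = list(indices_shape)
    shape.insert(axis, depth)  # A_axis and later dimensions shift right by one
    return tuple(shape)


print(one_hot_output_shape((10,), 4, -1))      # (10, 4)
print(one_hot_output_shape((2, 3, 5), 4, 1))   # (2, 4, 3, 5)
print(one_hot_output_shape((2, 3, 5), 4, -4))  # (4, 2, 3, 5)
```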

Volume Limits

indices and output can have up to \(2^{31}-1\) elements.
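Because the output multiplies the element count of indices by d, it can exceed the limit even when indices does not. A short Python check, using a hypothetical helper name and illustrative shapes:

```python
import math

MAX_ELEMENTS = 2**31 - 1  # stated limit for both indices and output


def within_volume_limit(shape) -> bool:
    """Hypothetical helper: True if a tensor of this shape has at most 2^31 - 1 elements."""
    return math.prod(shape) <= MAX_ELEMENTS


indices_shape = (65536, 1024)           # 2^26 elements
output_shape = indices_shape + (64,)    # depth 64 appended (axis = -1): 2^32 elements
print(within_volume_limit(indices_shape))  # True
print(within_volume_limit(output_shape))   # False
```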

Examples

OneHot
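import numpy as np
import tensorrt as trt

# network, get_runner, inputs, outputs, and expected are provided by the example's test harness.
inInd = network.add_input("indices", dtype=trt.int32, shape=(10,))
inVals = network.add_input("values", dtype=trt.float32, shape=(2,))
inDepth = network.add_input("depth", dtype=trt.int32, shape=())  # scalar shape tensor
axis = -1  # append the one-hot dimension as the last output dimension
depthVal = 4

# depth is a shape-tensor input, so the optimization profile supplies its
# min/opt/max values; all three are fixed to depthVal here.
opt_profile = get_runner.builder.create_optimization_profile()
opt_profile.set_shape_input("depth", [depthVal], [depthVal], [depthVal])
get_runner.config.add_optimization_profile(opt_profile)

layer = network.add_one_hot(inInd, inVals, inDepth, axis)
network.mark_output(layer.get_output(0))

# Negative indices wrap around; 4 and -5 fall outside [-depth, depth-1].
inputs[inInd.name] = np.array([0, 1, 2, 3, 4, -1, -2, -3, -4, -5], dtype=np.int32)
inputs[inVals.name] = np.array([0.0, 3.14], dtype=np.float32)  # [offValue, onValue]
inputs[inDepth.name] = np.array([depthVal], dtype=np.int32)

# Expected output has shape (10, 4): one row per index, onValue at the hot position.
outputs[layer.get_output(0).name] = layer.get_output(0).shape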
expected[layer.get_output(0).name] = np.array(
    [
        [3.14, 0.0, 0.0, 0.0],
        [0.0, 3.14, 0.0, 0.0],
        [0.0, 0.0, 3.14, 0.0],
        [0.0, 0.0, 0.0, 3.14],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 3.14],
        [0.0, 0.0, 3.14, 0.0],
        [0.0, 3.14, 0.0, 0.0],
        [3.14, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)

C++ API

For more information about the C++ IOneHotLayer operator, refer to the C++ IOneHotLayer documentation.

Python API

For more information about the Python IOneHotLayer operator, refer to the Python IOneHotLayer documentation.