Tensor joining

This notebook demonstrates two methods of joining tensors: stacking and concatenation.

Both operations take multiple input tensors and join them into a single output. They differ in how the inputs are joined: concatenation joins the tensors along an existing axis, so the number of dimensions stays the same, whereas stacking inserts a new axis, so the output has one more dimension than the inputs. Stacking can be used, for example, to combine separate coordinates into vectors, or to combine color planes into color images. Concatenation can be used, among other applications, for joining tiles into a larger image or appending lists.
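As a quick illustration of the difference, here is a minimal sketch that uses plain NumPy rather than DALI. The arrays `x`, `y` and `z` are hypothetical stand-ins for separate coordinate planes; they are not part of the original notebook.

```python
import numpy as np

# Three separate planes of shape (2, 3); the names are illustrative only.
x = np.zeros((2, 3))
y = np.ones((2, 3))
z = np.full((2, 3), 2.0)

# Stacking inserts a new axis: every position now holds an (x, y, z) vector.
vectors = np.stack([x, y, z], axis=-1)
print(vectors.shape)       # (2, 3, 3)

# Concatenation joins along an existing axis: the planes sit side by side.
row_of_tiles = np.concatenate([x, y, z], axis=1)
print(row_of_tiles.shape)  # (2, 9)
```

The stacked result gains a dimension, while the concatenated result keeps two dimensions and only grows along the joined axis.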

Concatenation

In this section we’ll demonstrate concatenation along different axes. In general, concatenation only needs the inputs to match on every axis other than the one being joined. Because we’ll be concatenating the same tensors along each axis in turn, here the tensors must have identical shapes.

[1]:
import nvidia.dali as dali
import nvidia.dali.fn as fn
import numpy as np

np.random.seed(1234)

arr = np.array([
    [[1,2,3,4],
     [5,6,7,8],
     [9,10,11,12]],
    [[13,14,15,16],
     [17,18,19,20],
     [21,22,23,24]]
])

src1 = dali.types.Constant(arr)
src2 = dali.types.Constant(arr + 100)
src3 = dali.types.Constant(arr + 200)

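# A batch size of 1 is enough here: every sample is the same constant tensor.
pipe_cat = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)
with pipe_cat:
    # Join the three inputs along each of the existing axes in turn.
    cat_outer = fn.cat(src1, src2, src3, axis=0)
    cat_middle = fn.cat(src1, src2, src3, axis=1)
    cat_inner = fn.cat(src1, src2, src3, axis=2)
    pipe_cat.set_outputs(cat_outer, cat_middle, cat_inner)

pipe_cat.build()
o = pipe_cat.run()  # one output batch per value passed to set_outputs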
[2]:
print("Concatenation along outer axis:")
print(o[0].at(0))
print("Shape: ", o[0].at(0).shape)
Concatenation along outer axis:
[[[  1   2   3   4]
  [  5   6   7   8]
  [  9  10  11  12]]

 [[ 13  14  15  16]
  [ 17  18  19  20]
  [ 21  22  23  24]]

 [[101 102 103 104]
  [105 106 107 108]
  [109 110 111 112]]

 [[113 114 115 116]
  [117 118 119 120]
  [121 122 123 124]]

 [[201 202 203 204]
  [205 206 207 208]
  [209 210 211 212]]

 [[213 214 215 216]
  [217 218 219 220]
  [221 222 223 224]]]
Shape:  (6, 3, 4)
[3]:
print("Concatenation along middle axis:")
print(o[1].at(0))
print("Shape: ", o[1].at(0).shape)
Concatenation along middle axis:
[[[  1   2   3   4]
  [  5   6   7   8]
  [  9  10  11  12]
  [101 102 103 104]
  [105 106 107 108]
  [109 110 111 112]
  [201 202 203 204]
  [205 206 207 208]
  [209 210 211 212]]

 [[ 13  14  15  16]
  [ 17  18  19  20]
  [ 21  22  23  24]
  [113 114 115 116]
  [117 118 119 120]
  [121 122 123 124]
  [213 214 215 216]
  [217 218 219 220]
  [221 222 223 224]]]
Shape:  (2, 9, 4)
[4]:
print("Concatenation along inner axis:")
print(o[2].at(0))
print("Shape: ", o[2].at(0).shape)
Concatenation along inner axis:
[[[  1   2   3   4 101 102 103 104 201 202 203 204]
  [  5   6   7   8 105 106 107 108 205 206 207 208]
  [  9  10  11  12 109 110 111 112 209 210 211 212]]

 [[ 13  14  15  16 113 114 115 116 213 214 215 216]
  [ 17  18  19  20 117 118 119 120 217 218 219 220]
  [ 21  22  23  24 121 122 123 124 221 222 223 224]]]
Shape:  (2, 3, 12)
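
The three shapes printed above follow a simple rule: the extent of the joined axis is the sum of the input extents along it, and every other extent is unchanged. The sketch below checks that rule with NumPy instead of DALI. The helper `cat_shape` is a hypothetical name introduced only for this illustration.

```python
import numpy as np

def cat_shape(shapes, axis):
    """Expected output shape of a concatenation (hypothetical helper)."""
    first = shapes[0]
    for shape in shapes[1:]:
        # Every extent except the concatenated one has to agree.
        same_rank = len(shape) == len(first)
        if not same_rank or any(
            a != b for i, (a, b) in enumerate(zip(first, shape)) if i != axis
        ):
            raise ValueError(f"incompatible shapes: {first} and {shape}")
    out = list(first)
    out[axis] = sum(shape[axis] for shape in shapes)
    return tuple(out)

# Same values as the `arr` constant used in the pipeline above.
arr = np.arange(1, 25).reshape(2, 3, 4)
inputs = [arr, arr + 100, arr + 200]

for axis in range(arr.ndim):
    expected = cat_shape([a.shape for a in inputs], axis)
    assert np.concatenate(inputs, axis=axis).shape == expected
    print(axis, expected)
# 0 (6, 3, 4)
# 1 (2, 9, 4)
# 2 (2, 3, 12)
```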

Stacking

When stacking, a new axis is inserted, and its extent equals the number of inputs. The new axis can be placed at any position, including after the innermost axis, in which case the values from the input tensors are interleaved.
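
To make the interleaving concrete, here is a small NumPy sketch (not DALI) with three hypothetical one-dimensional inputs. Flattening the result shows the order in which the values are laid out in memory.

```python
import numpy as np

# Illustrative 1-D inputs; they are not part of the original notebook.
a = np.array([1, 2, 3])
b = np.array([101, 102, 103])
c = np.array([201, 202, 203])

# New axis after the innermost one: values from a, b and c alternate.
inner = np.stack([a, b, c], axis=1)
print(inner.shape)             # (3, 3)
print(inner.ravel().tolist())  # [1, 101, 201, 2, 102, 202, 3, 103, 203]

# New outermost axis: each input stays contiguous.
outer = np.stack([a, b, c], axis=0)
print(outer.ravel().tolist())  # [1, 2, 3, 101, 102, 103, 201, 202, 203]
```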

Let’s apply stacking to the same inputs which were used for concatenation.

[5]:
pipe_stack = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)
with pipe_stack:
    # The inputs are 3-D, so the new axis can go at any position from 0 to 3.
    st_outermost = fn.stack(src1, src2, src3, axis=0)
    st_1 = fn.stack(src1, src2, src3, axis=1)
    st_2 = fn.stack(src1, src2, src3, axis=2)
    st_new_inner = fn.stack(src1, src2, src3, axis=3)
    pipe_stack.set_outputs(st_outermost, st_1, st_2, st_new_inner)

pipe_stack.build()
o = pipe_stack.run()  # one output batch per value passed to set_outputs
[6]:
print("Stacking - insert outermost axis:")
print(o[0].at(0))
print("Shape: ", o[0].at(0).shape)
Stacking - insert outermost axis:
[[[[  1   2   3   4]
   [  5   6   7   8]
   [  9  10  11  12]]

  [[ 13  14  15  16]
   [ 17  18  19  20]
   [ 21  22  23  24]]]


 [[[101 102 103 104]
   [105 106 107 108]
   [109 110 111 112]]

  [[113 114 115 116]
   [117 118 119 120]
   [121 122 123 124]]]


 [[[201 202 203 204]
   [205 206 207 208]
   [209 210 211 212]]

  [[213 214 215 216]
   [217 218 219 220]
   [221 222 223 224]]]]
Shape:  (3, 2, 3, 4)
[7]:
print("Stacking - new axis before 1:")
print(o[1].at(0))
print("Shape: ", o[1].at(0).shape)
Stacking - new axis before 1:
[[[[  1   2   3   4]
   [  5   6   7   8]
   [  9  10  11  12]]

  [[101 102 103 104]
   [105 106 107 108]
   [109 110 111 112]]

  [[201 202 203 204]
   [205 206 207 208]
   [209 210 211 212]]]


 [[[ 13  14  15  16]
   [ 17  18  19  20]
   [ 21  22  23  24]]

  [[113 114 115 116]
   [117 118 119 120]
   [121 122 123 124]]

  [[213 214 215 216]
   [217 218 219 220]
   [221 222 223 224]]]]
Shape:  (2, 3, 3, 4)
[8]:
print("Stacking - new axis before 2:")
print(o[2].at(0))
print("Shape: ", o[2].at(0).shape)
Stacking - new axis before 2:
[[[[  1   2   3   4]
   [101 102 103 104]
   [201 202 203 204]]

  [[  5   6   7   8]
   [105 106 107 108]
   [205 206 207 208]]

  [[  9  10  11  12]
   [109 110 111 112]
   [209 210 211 212]]]


 [[[ 13  14  15  16]
   [113 114 115 116]
   [213 214 215 216]]

  [[ 17  18  19  20]
   [117 118 119 120]
   [217 218 219 220]]

  [[ 21  22  23  24]
   [121 122 123 124]
   [221 222 223 224]]]]
Shape:  (2, 3, 3, 4)
[9]:
print("Stacking - new innermost axis:")
print(o[3].at(0))
print("Shape: ", o[3].at(0).shape)
Stacking - new innermost axis:
[[[[  1 101 201]
   [  2 102 202]
   [  3 103 203]
   [  4 104 204]]

  [[  5 105 205]
   [  6 106 206]
   [  7 107 207]
   [  8 108 208]]

  [[  9 109 209]
   [ 10 110 210]
   [ 11 111 211]
   [ 12 112 212]]]


 [[[ 13 113 213]
   [ 14 114 214]
   [ 15 115 215]
   [ 16 116 216]]

  [[ 17 117 217]
   [ 18 118 218]
   [ 19 119 219]
   [ 20 120 220]]

  [[ 21 121 221]
   [ 22 122 222]
   [ 23 123 223]
   [ 24 124 224]]]]
Shape:  (2, 3, 4, 3)
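
One way to relate the two operations is to view stacking as adding a unit-length axis to every input and then concatenating along it. The following NumPy sketch (an illustration only, not how DALI implements `fn.stack`) checks this equivalence for every axis position used above.

```python
import numpy as np

# Same values as the `arr` constant used in the pipelines above.
arr = np.arange(1, 25).reshape(2, 3, 4)
inputs = [arr, arr + 100, arr + 200]

for axis in range(arr.ndim + 1):
    stacked = np.stack(inputs, axis=axis)
    # Give each input a new axis of extent 1, then join along that axis.
    expanded = [np.expand_dims(a, axis) for a in inputs]
    via_cat = np.concatenate(expanded, axis=axis)
    assert np.array_equal(stacked, via_cat)
    print(axis, stacked.shape)
# 0 (3, 2, 3, 4)
# 1 (2, 3, 3, 4)
# 2 (2, 3, 3, 4)
# 3 (2, 3, 4, 3)
```

The printed shapes match the shapes of the four stacking outputs shown in this section.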