RotaryEmbedding

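Applies Rotary Position Embedding (RoPE) to the input tensor. Pairs of elements along the hidden dimension are treated as 2-d vectors and rotated by position-dependent angles. The cosine and sine values of those angles are supplied by cosCache and sinCache.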
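Written out for a single 2-d vector, the rotation looks like this. It is derived from the reference implementation in the Examples section. Let \((x_1, x_2)\) be one vector from the rotated part of the hidden dimension, and let \(\cos\theta\) and \(\sin\theta\) be the matching cosCache and sinCache entries. The reference code computes this rotation as the complex product \((x_1 + i x_2)(\cos\theta + i \sin\theta)\):

```latex
\begin{aligned}
x_1' &= x_1 \cos\theta - x_2 \sin\theta \\
x_2' &= x_1 \sin\theta + x_2 \cos\theta
\end{aligned}
```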

Attributes

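interleaved A boolean that specifies whether the input tensor is in interleaved format, that is, whether each rotated 2-d vector is formed from two adjacent elements of the hidden dimension. If false, element i of the first half of the rotated dimension is paired with element i of the second half.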

rotaryEmbeddingDim The number of leading elements of the hidden dimension that participate in RoPE. The remaining elements are copied to the output unchanged. A special value of 0 means the full hidden dimension participates in RoPE.
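The helpers below sketch how the two attributes choose which hidden-dimension elements are rotated together and which pass through unchanged. They mirror the index selection in the reference implementation in the Examples section. The helper names `rope_pairs` and `passthrough_indices` are hypothetical and are not part of the TensorRT API.

```python
def rope_pairs(head_size, interleaved, rotary_embedding_dim=0):
    """Return (first, second) hidden-dimension indices of each rotated 2-d vector.

    Hypothetical illustration helper; follows the reference implementation's slicing.
    """
    # 0 means the full hidden dimension participates in RoPE.
    r = head_size if rotary_embedding_dim == 0 else rotary_embedding_dim
    if interleaved:
        # Adjacent elements form a vector: (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(r // 2)]
    # Element i of the first half is paired with element i of the second half.
    return [(i, i + r // 2) for i in range(r // 2)]


def passthrough_indices(head_size, rotary_embedding_dim=0):
    """Return indices of the hidden dimension that RoPE leaves unchanged."""
    r = head_size if rotary_embedding_dim == 0 else rotary_embedding_dim
    return list(range(r, head_size))


print(rope_pairs(8, interleaved=True))                            # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rope_pairs(8, interleaved=False))                           # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(rope_pairs(8, interleaved=False, rotary_embedding_dim=4))   # [(0, 2), (1, 3)]
print(passthrough_indices(8, rotary_embedding_dim=4))             # [4, 5, 6, 7]
```

With partial RoPE, the non-interleaved halves are taken from the first rotaryEmbeddingDim elements only, not from the full hidden dimension.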

Inputs

input: tensor of type T

cosCache: tensor of type T, cosine values for use in computing the rotary embedding.

sinCache: tensor of type T, sine values for use in computing the rotary embedding.

positionIds (optional): tensor of type M, position IDs used to select rows of cosCache and sinCache.
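The NumPy sketch below shows what positionIds does, following the reference implementation in the Examples section. Each ID selects one row of a cache, so a \([b, s]\) tensor of IDs turns a \([maxPositionId+1, h/2]\) cache into a per-token \([b, s, h/2]\) cache. The sizes and cache contents are arbitrary illustrative values.

```python
import numpy as np

max_position_id = 9  # illustrative: the cache holds rows for positions 0..9
half_dim = 3         # illustrative: h/2 for a head size h = 6

# Row p of the cache holds the values for position p (dummy contents here).
cos_cache = np.arange((max_position_id + 1) * half_dim, dtype=np.float32).reshape(
    max_position_id + 1, half_dim
)

# positionIds has shape [b, s] and type int64.
position_ids = np.array([[6, 2, 1], [2, 8, 3]], dtype=np.int64)

# Advanced indexing gathers one cache row per (batch, token) pair.
# sinCache would be indexed with the same position_ids in the same way.
cos_per_token = cos_cache[position_ids]

print(cos_per_token.shape)  # (2, 3, 3)
print(cos_per_token[1, 1])  # row 8 of cos_cache: [24. 25. 26.]
```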

Outputs

output: tensor of type T

Data Types

T: float32, float16, bfloat16

M: int64

Shape Information

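input and output are tensors with the same shape \([b, d, s, h]\), where b is the batch size, d is the number of heads, s is the sequence length, and h is the hidden dimension (head size).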

If positionIds is not provided, cosCache and sinCache are tensors with the same shape \([b, s, h/2]\).

If positionIds is provided, cosCache and sinCache have the shape \([maxPositionId+1, h/2]\), where maxPositionId is the largest position ID that positionIds may contain.

If rotaryEmbeddingDim is not 0, the last dimension of cosCache and sinCache should be \(rotaryEmbeddingDim/2\) instead of \(h/2\).

positionIds, if provided, is a tensor with the shape of \([b, s]\).
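The shape rules above can be collected into one check. The function below is a minimal sketch of that idea. `check_rope_shapes` is a hypothetical helper, not a TensorRT API, and it checks shapes only. It cannot verify that every position ID is at most maxPositionId, because that depends on tensor values.

```python
def check_rope_shapes(input_shape, cos_shape, sin_shape,
                      position_ids_shape=None, rotary_embedding_dim=0):
    """Validate RotaryEmbedding shapes and return the output shape.

    Hypothetical illustration helper; raises ValueError on a mismatch.
    """
    input_shape, cos_shape, sin_shape = tuple(input_shape), tuple(cos_shape), tuple(sin_shape)
    if len(input_shape) != 4:
        raise ValueError(f"input must have shape [b, d, s, h], got {input_shape}")
    b, d, s, h = input_shape
    # Last cache dimension is rotaryEmbeddingDim/2, or h/2 when the attribute is 0.
    rot = h if rotary_embedding_dim == 0 else rotary_embedding_dim
    if cos_shape != sin_shape:
        raise ValueError("cosCache and sinCache must have the same shape")
    if position_ids_shape is None:
        expected = (b, s, rot // 2)
        if cos_shape != expected:
            raise ValueError(f"caches must have shape {expected}, got {cos_shape}")
    else:
        if tuple(position_ids_shape) != (b, s):
            raise ValueError(f"positionIds must have shape {(b, s)}, got {tuple(position_ids_shape)}")
        if len(cos_shape) != 2 or cos_shape[1] != rot // 2:
            raise ValueError(f"caches must have shape [maxPositionId+1, {rot // 2}], got {cos_shape}")
    # output has the same shape as input
    return input_shape


# Shapes from the example in this page
print(check_rope_shapes((2, 8, 4, 512), (100, 256), (100, 256), position_ids_shape=(2, 4)))
# Partial RoPE without positionIds
print(check_rope_shapes((2, 8, 4, 512), (2, 4, 32), (2, 4, 32), rotary_embedding_dim=64))
```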

DLA Support

Not supported.

Examples

RotaryEmbedding
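import numpy as np
import tensorrt as trt

# network, inputs, outputs, and expected are assumed to be provided by the surrounding test harness.
input = network.add_input("input", dtype=trt.float32, shape=(2, 8, 4, 512))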
cos_cache = network.add_input("cos_cache", dtype=trt.float32, shape=(100, 256))
sin_cache = network.add_input("sin_cache", dtype=trt.float32, shape=(100, 256))
position_ids = network.add_input("position_ids", dtype=trt.int64, shape=(2, 4))
layer = network.add_rotary_embedding(input=input, cos_cache=cos_cache, sin_cache=sin_cache, interleaved=False, rotary_embedding_dim=0)
layer.set_input(3, position_ids)
network.mark_output(layer.get_output(0))

inputs[input.name] = np.random.rand(2, 8, 4, 512).astype("f")
inputs[cos_cache.name] = np.random.rand(100, 256).astype("f")
inputs[sin_cache.name] = np.random.rand(100, 256).astype("f")
inputs[position_ids.name] = np.array([[6, 2, 1, 7], [2, 8, 3, 6]], dtype=np.int64)  # must match trt.int64

outputs[layer.get_output(0).name] = layer.get_output(0).shape

# This is a reference implementation of the rotary embedding operator.
def compute_rotary_embedding(
    input,
    cos_cache,
    sin_cache,
    position_ids=None,
    interleaved=False,
    rotary_embedding_dim=0,
):
    # Shape of input: (batch_size, num_heads, seq_len, head_size)
    head_size = input.shape[3]

    # Process partial RoPE
    rotary_embedding_dim = head_size if rotary_embedding_dim == 0 else rotary_embedding_dim
    x_rotate, x_not_rotate = np.split(input, [rotary_embedding_dim], axis=-1)

    # Get cached cosine and sine values
    cache = cos_cache + 1j * sin_cache
    if position_ids is not None:
        cache = cache[position_ids] # Shape: (batch_size, seq_len, rotary_embedding_dim/2)
    cache = cache[:, np.newaxis, :, :] # Shape: (batch_size, 1, seq_len, rotary_embedding_dim/2)

    # Get the 2-d vectors to rotate
    if interleaved:
        x1, x2 = x_rotate[..., 0::2], x_rotate[..., 1::2]
    else:
        x1, x2 = np.split(x_rotate, 2, axis=-1)
    x = x1 + 1j * x2

    # Rotate the vectors
    x = x * cache

    # Put the rotated vectors back
    if interleaved:
        x = np.expand_dims(x, axis=-1)
        x = np.concatenate((np.real(x), np.imag(x)), axis=-1)
        x = np.reshape(x, x_rotate.shape)
    else:
        x = np.concatenate((np.real(x), np.imag(x)), axis=-1)

    # Process partial RoPE
    output = np.concatenate((x, x_not_rotate), axis=-1)
    return output

expected[layer.get_output(0).name] = compute_rotary_embedding(inputs[input.name], inputs[cos_cache.name], inputs[sin_cache.name], inputs[position_ids.name], interleaved=False, rotary_embedding_dim=0)
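This operator does not specify how cosCache and sinCache should be filled. The sketch below assumes the common convention from RoFormer: the angle for position p and vector i is \(p \cdot 10000^{-2i/r}\), where r is the rotated dimension. It builds caches in the \([maxPositionId+1, r/2]\) layout used above. As a sanity check, it then applies the reference implementation's non-interleaved path, rewritten with real arithmetic instead of complex numbers, and confirms that every rotated pair keeps its length. The base of 10000 and the function names are assumptions for illustration.

```python
import numpy as np


def build_rope_caches(max_position_id, rotary_dim, base=10000.0):
    """Build cos/sin caches of shape [max_position_id + 1, rotary_dim // 2].

    The base-10000 frequency schedule is an assumed convention, not a requirement
    of the RotaryEmbedding operator.
    """
    # One inverse frequency per rotated 2-d vector: base^(-2i / rotary_dim).
    inv_freq = base ** (-np.arange(0, rotary_dim, 2, dtype=np.float64) / rotary_dim)
    positions = np.arange(max_position_id + 1, dtype=np.float64)
    angles = np.outer(positions, inv_freq)  # [max_position_id + 1, rotary_dim // 2]
    return np.cos(angles).astype(np.float32), np.sin(angles).astype(np.float32)


def rotate_non_interleaved(x, cos, sin):
    """Full-dimension, non-interleaved RoPE written with real arithmetic.

    x: [b, d, s, h]; cos, sin: [b, s, h // 2], already gathered per token.
    """
    x1, x2 = np.split(x, 2, axis=-1)
    cos = cos[:, np.newaxis]  # broadcast over the head dimension d
    sin = sin[:, np.newaxis]
    return np.concatenate((x1 * cos - x2 * sin, x1 * sin + x2 * cos), axis=-1)


cos_cache, sin_cache = build_rope_caches(max_position_id=99, rotary_dim=512)
print(cos_cache.shape)  # (100, 256), the cache shape used in the example above

rng = np.random.default_rng(0)
x = rng.random((2, 8, 4, 512), dtype=np.float32)
position_ids = np.array([[6, 2, 1, 7], [2, 8, 3, 6]], dtype=np.int64)
y = rotate_non_interleaved(x, cos_cache[position_ids], sin_cache[position_ids])

# A rotation preserves the squared length of each (x1[i], x2[i]) pair.
x1, x2 = np.split(x, 2, axis=-1)
y1, y2 = np.split(y, 2, axis=-1)
print(np.allclose(x1**2 + x2**2, y1**2 + y2**2, atol=1e-5))  # True
```

The real-arithmetic form computes the same values as the complex product in compute_rotary_embedding. It may be easier to port to frameworks without complex dtypes.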

C++ API

For more information about the C++ IRotaryEmbeddingLayer operator, refer to the C++ IRotaryEmbeddingLayer documentation.

Python API

For more information about the Python IRotaryEmbeddingLayer operator, refer to the Python IRotaryEmbeddingLayer documentation.