MoE (Mixture of Experts)

Performs Mixture of Experts (MoE) computation based on the selected experts and their scores, and produces the weighted sum of the selected experts' outputs.
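The computation can be sketched with an unquantized numpy reference (illustrative only; function and variable names are not part of the API, and the layer itself runs fused, quantized kernels):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def moe_reference(hidden_states, selected_experts, scores,
                  fc_gate_w, fc_up_w, fc_down_w):
    """Unquantized reference for the MoE weighted-sum output.

    hidden_states:    [batch, seq, hidden]
    selected_experts: [batch, seq, topK] integer expert indices
    scores:           [batch, seq, topK] per-expert weights
    fc_gate_w, fc_up_w: [numExperts, hidden, inter]
    fc_down_w:        [numExperts, inter, hidden]
    """
    out = np.zeros_like(hidden_states)
    batch, seq, _ = hidden_states.shape
    top_k = selected_experts.shape[-1]
    for b in range(batch):
        for t in range(seq):
            x = hidden_states[b, t]
            for k in range(top_k):
                e = selected_experts[b, t, k]
                gate = silu(x @ fc_gate_w[e])   # gated branch with SiLU activation
                up = x @ fc_up_w[e]             # linear branch
                # Weight each expert's fcDown output by its score and accumulate.
                out[b, t] += scores[b, t, k] * ((gate * up) @ fc_down_w[e])
    return out
```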

Attributes

activationType The activation type for the MoE layer. Supported values are MoEActType::kNONE and MoEActType::kSILU.

quantizationToType The type to quantize to when quantizing the mul output within the MoE layer. Currently only DataType::kFP8 is supported.

quantizationBlockShape The quantization block shape when quantizing the mul output within the MoE layer. Only used when setQuantizationDynamicDblQ is called to configure dynamic quantization. (Not supported yet.)

dynQOutputScaleType The generated scale type when dynamically quantizing the mul output within the MoE layer. Only used when setQuantizationDynamicDblQ is called to configure dynamic quantization. (Not supported yet.)

swigluParamLimit The limit for the SWIGLU parameter. (Not supported yet.)

swigluParamAlpha The alpha for the SWIGLU parameter. (Not supported yet.)

swigluParamBeta The beta for the SWIGLU parameter. (Not supported yet.)

Inputs

hiddenStates: tensor of type T, the hidden states of the layer.

selectedExpertsForTokens: tensor of type M, the top K experts selected for each token.

scoresForSelectedExperts: tensor of type T, the scores for the selected experts per token.

fcGateWeights: tensor of type T, the weights for the fcGate. Call setGatedWeights to provide this input.

fcUpWeights: tensor of type T, the weights for the fcUp. Call setGatedWeights to provide this input.

fcDownWeights: tensor of type T, the weights for the fcDown. Call setGatedWeights to provide this input.

fcGateBias optional: tensor of type T, the bias for the fcGate. Call setGatedBiases to provide this input.

fcUpBias optional: tensor of type T, the bias for the fcUp. Call setGatedBiases to provide this input.

fcDownBias optional: tensor of type T, the bias for the fcDown. Call setGatedBiases to provide this input.

fcDownActivationScale optional: tensor of type T, the scale for the fcDown activation (mul output). Call setQuantizationStatic to provide this input.

fcDownActivationDblQScale optional: tensor of type T, the double-quantization scale for the fcDown activation (mul output). Call setQuantizationDynamicDblQ to provide this input. (Not supported yet.)

Outputs

output: tensor of type T, the weighted-sum output of the selected experts.

Data Types

T: float32

M: int32

Shape Information

hiddenStates: [batchSize, seqLen, hiddenSize]

selectedExpertsForTokens: [batchSize, seqLen, topK]

scoresForSelectedExperts: [batchSize, seqLen, topK]

fcGateWeights: [numExperts, hiddenSize, moeInterSize]

fcUpWeights: [numExperts, hiddenSize, moeInterSize]

fcDownWeights: [numExperts, moeInterSize, hiddenSize]

fcGateBias: [numExperts, moeInterSize]

fcUpBias: [numExperts, moeInterSize]

fcDownBias: [numExperts, hiddenSize]

fcDownActivationScale: [] or [numExperts]

fcDownActivationDblQScale: [] or [numExperts]

output: [batchSize, seqLen, hiddenSize]

DLA Support

Not supported.

Notes

Currently, only the Thor platform is supported. Performance is limited when seqLen > 16.

Currently, only the following configuration is supported:

  1. Weights must be NVFP4 double quantized: each weight tensor must pass through a DQ (with its scale produced by another DQ) and then, optionally, a Transpose before being passed to the MoE layer. The weights must be DataType::kFP4 before the weight DQ, and the scales must be DataType::kFP8 before the scale DQ.

  2. The hiddenStates input must be FP8 quantized: it must pass through a DQ before being passed to the MoE layer and must be DataType::kFP8 before that DQ.

  3. The internal mul output activation must be FP8 quantized: setQuantizationStatic must be used, and quantizationToType must be DataType::kFP8.
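The double quantization in point 1 can be simulated numerically. The sketch below (illustrative names, not TensorRT API) models the two-level scale factorization for one weight tensor: per-block scales map each block into the FP4 (E2M1) range, and a single per-tensor scale normalizes the block scales themselves. For simplicity it does not model FP8 rounding of the stored scales.

```python
import numpy as np

# Representable non-negative FP4 (E2M1) magnitudes.
FP4_POS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def snap_fp4(x):
    # Round each value to the nearest representable FP4 magnitude, keeping sign.
    mag = np.abs(x)[..., None]
    snapped = FP4_POS[np.abs(mag - FP4_POS).argmin(axis=-1)]
    return np.sign(x) * snapped

def double_quantize(w, block_size=16):
    # Per-block scale so each block's max magnitude maps to the FP4 max (6.0).
    blocks = w.reshape(-1, block_size)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 6.0
    q = snap_fp4(blocks / scale)          # FP4 weight values
    dbl_q_scale = scale.max()             # per-tensor scale for the block scales
    scale_q = scale / dbl_q_scale         # normalized block scales (stored as FP8
                                          # in the real scheme; not modeled here)
    return q, scale_q, dbl_q_scale

def double_dequantize(q, scale_q, dbl_q_scale, shape):
    # Mirrors the two DQ nodes: the scale DQ runs first, then the weight DQ.
    scale = scale_q * dbl_q_scale
    return (q * scale).reshape(shape)
```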

Examples

MoE
import numpy as np
import tensorrt as trt

# get_runner, inputs, outputs, and expected are fixtures provided by the
# documentation's test harness; the network construction below is the part
# specific to this example.
get_runner.network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))
network = get_runner.network

batch_size = 1
seq_len = -1 # dynamic
num_experts = 8
top_k = 4
hidden_size = 32
moe_inter_size = 32
block_size = 16

seq_len_min, seq_len_opt, seq_len_max = 1, 64, 128
opt_profile = get_runner.builder.create_optimization_profile()
opt_profile.set_shape("hidden_states", (batch_size, seq_len_min, hidden_size), (batch_size, seq_len_opt, hidden_size), (batch_size, seq_len_max, hidden_size))
opt_profile.set_shape("selected_experts_for_tokens", (batch_size, seq_len_min, top_k), (batch_size, seq_len_opt, top_k), (batch_size, seq_len_max, top_k))
opt_profile.set_shape("scores_for_selected_experts", (batch_size, seq_len_min, top_k), (batch_size, seq_len_opt, top_k), (batch_size, seq_len_max, top_k))
get_runner.config.add_optimization_profile(opt_profile)

hidden_states = network.add_input("hidden_states", dtype=trt.float32, shape=(batch_size, seq_len, hidden_size))
selected_experts_for_tokens = network.add_input("selected_experts_for_tokens", dtype=trt.int32, shape=(batch_size, seq_len, top_k))
scores_for_selected_experts = network.add_input("scores_for_selected_experts", dtype=trt.float32, shape=(batch_size, seq_len, top_k))

fc_gate_weights_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size), dtype=np.float32)
fc_up_weights_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size), dtype=np.float32)
fc_down_weights_data = np.ones(shape=(num_experts, hidden_size, moe_inter_size), dtype=np.float32)
fc_gate_weights = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size), weights=fc_gate_weights_data).get_output(0)
fc_up_weights = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size), weights=fc_up_weights_data).get_output(0)
fc_down_weights = network.add_constant(shape=(num_experts, hidden_size, moe_inter_size), weights=fc_down_weights_data).get_output(0)

fc_gate_weight_scale_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size // block_size), dtype=np.float32)
fc_up_weight_scale_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size // block_size), dtype=np.float32)
fc_down_weight_scale_data = np.ones(shape=(num_experts, hidden_size, moe_inter_size // block_size), dtype=np.float32)
fc_gate_weight_scale = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size // block_size), weights=fc_gate_weight_scale_data).get_output(0)
fc_up_weight_scale = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size // block_size), weights=fc_up_weight_scale_data).get_output(0)
fc_down_weight_scale = network.add_constant(shape=(num_experts, hidden_size, moe_inter_size // block_size), weights=fc_down_weight_scale_data).get_output(0)

fc_gate_weight_dbl_q_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_up_weight_dbl_q_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_down_weight_dbl_q_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_gate_weight_dbl_q_scale = network.add_constant(shape=(num_experts,), weights=fc_gate_weight_dbl_q_scale_data).get_output(0)
fc_up_weight_dbl_q_scale = network.add_constant(shape=(num_experts,), weights=fc_up_weight_dbl_q_scale_data).get_output(0)
fc_down_weight_dbl_q_scale = network.add_constant(shape=(num_experts,), weights=fc_down_weight_dbl_q_scale_data).get_output(0)

fc_gate_weight_scale_quantize = network.add_quantize(fc_gate_weight_scale, fc_gate_weight_dbl_q_scale, trt.fp8)
fc_gate_weight_scale_quantize.axis = 0
fc_gate_weight_scale_quantized = fc_gate_weight_scale_quantize.get_output(0)
fc_up_weight_scale_quantize = network.add_quantize(fc_up_weight_scale, fc_up_weight_dbl_q_scale, trt.fp8)
fc_up_weight_scale_quantize.axis = 0
fc_up_weight_scale_quantized = fc_up_weight_scale_quantize.get_output(0)
fc_down_weight_scale_quantize = network.add_quantize(fc_down_weight_scale, fc_down_weight_dbl_q_scale, trt.fp8)
fc_down_weight_scale_quantize.axis = 0
fc_down_weight_scale_quantized = fc_down_weight_scale_quantize.get_output(0)

fc_gate_weight_scale_dequantize = network.add_dequantize(fc_gate_weight_scale_quantized, fc_gate_weight_dbl_q_scale, trt.float32)
fc_gate_weight_scale_dequantize.axis = 0
fc_gate_weight_scale_dequantized = fc_gate_weight_scale_dequantize.get_output(0)
fc_up_weight_scale_dequantize = network.add_dequantize(fc_up_weight_scale_quantized, fc_up_weight_dbl_q_scale, trt.float32)
fc_up_weight_scale_dequantize.axis = 0
fc_up_weight_scale_dequantized = fc_up_weight_scale_dequantize.get_output(0)
fc_down_weight_scale_dequantize = network.add_dequantize(fc_down_weight_scale_quantized, fc_down_weight_dbl_q_scale, trt.float32)
fc_down_weight_scale_dequantize.axis = 0
fc_down_weight_scale_dequantized = fc_down_weight_scale_dequantize.get_output(0)

fc_gate_weights_quantize = network.add_quantize(fc_gate_weights, fc_gate_weight_scale, trt.fp4)
fc_gate_weights_quantize.axis = 2
fc_gate_weights_quantized = fc_gate_weights_quantize.get_output(0)
fc_up_weights_quantize = network.add_quantize(fc_up_weights, fc_up_weight_scale, trt.fp4)
fc_up_weights_quantize.axis = 2
fc_up_weights_quantized = fc_up_weights_quantize.get_output(0)
fc_down_weights_quantize = network.add_quantize(fc_down_weights, fc_down_weight_scale, trt.fp4)
fc_down_weights_quantize.axis = 2
fc_down_weights_quantized = fc_down_weights_quantize.get_output(0)

fc_gate_weights_dequantize = network.add_dequantize(fc_gate_weights_quantized, fc_gate_weight_scale_dequantized, trt.float32)
fc_gate_weights_dequantize.axis = 2
fc_gate_weights_dequantized = fc_gate_weights_dequantize.get_output(0)
fc_up_weights_dequantize = network.add_dequantize(fc_up_weights_quantized, fc_up_weight_scale_dequantized, trt.float32)
fc_up_weights_dequantize.axis = 2
fc_up_weights_dequantized = fc_up_weights_dequantize.get_output(0)
fc_down_weights_dequantize = network.add_dequantize(fc_down_weights_quantized, fc_down_weight_scale_dequantized, trt.float32)
fc_down_weights_dequantize.axis = 2
fc_down_weights_dequantized = fc_down_weights_dequantize.get_output(0)

fc_gate_weights_shuffle = network.add_shuffle(fc_gate_weights_dequantized)
fc_gate_weights_shuffle.first_transpose = trt.Permutation([0, 2, 1])
fc_gate_weights_transposed = fc_gate_weights_shuffle.get_output(0)
fc_up_weights_shuffle = network.add_shuffle(fc_up_weights_dequantized)
fc_up_weights_shuffle.first_transpose = trt.Permutation([0, 2, 1])
fc_up_weights_transposed = fc_up_weights_shuffle.get_output(0)
fc_down_weights_shuffle = network.add_shuffle(fc_down_weights_dequantized)
fc_down_weights_shuffle.first_transpose = trt.Permutation([0, 2, 1])
fc_down_weights_transposed = fc_down_weights_shuffle.get_output(0)

input_activation_scale_data = np.ones(shape=(), dtype=np.float32)
input_activation_scale = network.add_constant(shape=(), weights=input_activation_scale_data).get_output(0)

hidden_states_quantized = network.add_quantize(hidden_states, input_activation_scale, trt.fp8).get_output(0)
hidden_states_dequantized = network.add_dequantize(hidden_states_quantized, input_activation_scale, trt.float32).get_output(0)

fc_down_activation_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_down_activation_scale = network.add_constant(shape=(num_experts,), weights=fc_down_activation_scale_data).get_output(0)

layer = network.add_moe(hidden_states_dequantized, selected_experts_for_tokens, scores_for_selected_experts)
layer.set_gated_weights(fc_gate_weights_transposed, fc_up_weights_transposed, fc_down_weights_transposed, trt.MoEActType.SILU)
layer.set_quantization_static(fc_down_activation_scale, trt.fp8)
network.mark_output(layer.get_output(0))

seq_len = 2
inputs[hidden_states.name] = np.ones(shape=(batch_size, seq_len, hidden_size), dtype=np.float32)
inputs[selected_experts_for_tokens.name] = np.random.randint(0, num_experts, (batch_size, seq_len, top_k), dtype=np.int32)
inputs[scores_for_selected_experts.name] = np.ones(shape=(batch_size, seq_len, top_k), dtype=np.float32)

outputs[layer.get_output(0).name] = layer.get_output(0).shape

expected[layer.get_output(0).name] = np.ones(shape=(batch_size, seq_len, hidden_size), dtype=np.float32) * 57344.0
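The expected value 57344.0 can be checked by hand, assuming static FP8 (E4M3) quantization of the mul output saturates at 448, the largest finite E4M3 value:

```python
import numpy as np

# All-ones inputs, weights, and scores make every dot product a plain sum.
hidden_size = 32   # fcGate/fcUp reduce over hidden_size
top_k = 4          # four selected experts, each with score 1.0

gate = float(hidden_size)                 # x @ fcGateWeights -> 32 per element
up = float(hidden_size)                   # x @ fcUpWeights   -> 32 per element
silu_gate = gate / (1.0 + np.exp(-gate))  # silu(32) is ~32 (sigmoid(32) ~ 1)
mul = silu_gate * up                      # ~1024, outside the FP8 E4M3 range

FP8_E4M3_MAX = 448.0                      # static FP8 quantization saturates here
mul_q = min(mul, FP8_E4M3_MAX)            # clamps to exactly 448.0

down = mul_q * hidden_size                # fcDown sums 32 equal values -> 14336
out = down * top_k                        # weighted sum over 4 experts -> 57344
```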

C++ API

For more information about the C++ IMoELayer operator, refer to the C++ IMoELayer documentation.

Python API

For more information about the Python IMoELayer operator, refer to the Python IMoELayer documentation.