MoE (Mixture of Experts)
Performs Mixture of Experts (MoE) computation on the selected experts and their scores. Each token is processed by its selected experts, and the layer outputs the score-weighted sum of those experts' outputs.
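The sketch below is a minimal pure-Python reference that makes the data flow concrete for a single token. It assumes the gated-MLP expert structure implied by the fcGate/fcUp/fcDown inputs: each selected expert computes down(act(gate(x)) * up(x)), and the expert outputs are summed using the per-token scores as weights. Biases and all quantization are omitted. The helper names (`silu`, `matvec`, `moe_token`) are hypothetical, and this is not TensorRT's implementation.

```python
import math
from typing import Callable, List, Sequence

# Matrices are nested lists laid out as [in_features][out_features].
Matrix = List[List[float]]


def silu(x: float) -> float:
    """SiLU(x) = x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    z = math.exp(x)
    return x * z / (1.0 + z)


def matvec(x: Sequence[float], w: Matrix) -> List[float]:
    """Return x @ w for a vector x of length in_features and w of shape [in][out]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def moe_token(
    x: Sequence[float],
    experts: Sequence[int],
    scores: Sequence[float],
    gate_w: Sequence[Matrix],  # [numExperts][hiddenSize][moeInterSize]
    up_w: Sequence[Matrix],    # [numExperts][hiddenSize][moeInterSize]
    down_w: Sequence[Matrix],  # [numExperts][moeInterSize][hiddenSize]
    act: Callable[[float], float] = silu,
) -> List[float]:
    """Score-weighted sum over the selected experts of down(act(gate(x)) * up(x))."""
    out = [0.0] * len(x)
    for e, s in zip(experts, scores):
        gate = matvec(x, gate_w[e])
        up = matvec(x, up_w[e])
        mul = [act(g) * u for g, u in zip(gate, up)]  # the "mul output" fed to fcDown
        y = matvec(mul, down_w[e])
        out = [o + s * v for o, v in zip(out, y)]
    return out


# Usage: two experts, hiddenSize = moeInterSize = 2, identity activation
# (treating kNONE as "no activation" is an assumption of this sketch).
eye = [[1.0, 0.0], [0.0, 1.0]]
two_eye = [[2.0, 0.0], [0.0, 2.0]]
result = moe_token(
    x=[1.0, 2.0],
    experts=[0, 1],
    scores=[0.25, 0.75],
    gate_w=[eye, two_eye],
    up_w=[eye, eye],
    down_w=[eye, eye],
    act=lambda v: v,
)
print(result)  # [1.75, 7.0]
```

Expert 0 contributes [1, 4] with weight 0.25, and expert 1 contributes [2, 8] with weight 0.75, which gives [1.75, 7.0].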
Attributes
activationType The activation type for the MoE layer. Supported values are MoEActType::kNONE and MoEActType::kSILU.
quantizationToType The data type to which the mul output (the activation fed into fcDown) is quantized within the MoE layer. Currently, only DataType::kFP8 is supported.
quantizationBlockShape The quantization block shape when quantizing the mul output within the MoE layer. Only used when setQuantizationDynamicDblQ is called to configure dynamic quantization. (Not supported yet.)
dynQOutputScaleType The generated scale type when dynamically quantizing the mul output within the MoE layer. Only used when setQuantizationDynamicDblQ is called to configure dynamic quantization. (Not supported yet.)
swigluParamLimit The limit for the SWIGLU parameter. (Not supported yet.)
swigluParamAlpha The alpha for the SWIGLU parameter. (Not supported yet.)
swigluParamBeta The beta for the SWIGLU parameter. (Not supported yet.)
Inputs
hiddenStates: tensor of type T, the hidden states of the layer.
selectedExpertsForTokens: tensor of type M, the indices of the top-K experts selected for each token.
scoresForSelectedExperts: tensor of type T, the scores of the selected experts for each token, used as the weights in the weighted sum of expert outputs.
fcGateWeights: tensor of type T, the fcGate weights. Call setGatedWeights to provide this input.
fcUpWeights: tensor of type T, the fcUp weights. Call setGatedWeights to provide this input.
fcDownWeights: tensor of type T, the fcDown weights. Call setGatedWeights to provide this input.
fcGateBias (optional): tensor of type T, the fcGate bias. Call setGatedBiases to provide this input.
fcUpBias (optional): tensor of type T, the fcUp bias. Call setGatedBiases to provide this input.
fcDownBias (optional): tensor of type T, the fcDown bias. Call setGatedBiases to provide this input.
fcDownActivationScale (optional): tensor of type T, the scale for the fcDown activation (mul output). Call setQuantizationStatic to provide this input.
fcDownActivationDblQScale (optional): tensor of type T, the double-quantization scale for the fcDown activation (mul output). Call setQuantizationDynamicDblQ to provide this input. (Not supported yet.)
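The layer consumes a routing decision (selectedExpertsForTokens and scoresForSelectedExperts) rather than computing one. This sketch shows one common way a router upstream might produce a single token's row of those two inputs: a softmax over router logits, then top-K selection. This routing scheme is an assumption. Some models renormalize the kept scores or route differently. `route_token` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple


def route_token(router_logits: Sequence[float], top_k: int) -> Tuple[List[int], List[float]]:
    """Softmax over all experts, then keep the top_k most probable ones.

    Returns (expert indices, scores): one row of selectedExpertsForTokens and
    scoresForSelectedExperts for a single token.
    """
    m = max(router_logits)  # subtract the max logit for numerical stability
    exps = [math.exp(v - m) for v in router_logits]
    total = sum(exps)
    probs = [v / total for v in exps]
    # Stable sort by descending probability: ties keep the lower expert index first.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    experts = order[:top_k]
    return experts, [probs[i] for i in experts]


experts, scores = route_token([0.5, 2.0, -1.0, 1.0], top_k=2)
print(experts)  # [1, 3]
```

When feeding the layer, the indices go into an int32 tensor (type M) and the scores into a float32 tensor (type T), each of shape [batchSize, seqLen, topK].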
Outputs
output: tensor of type T
Data Types
T: float32
M: int32
Shape Information
hiddenStates: [batchSize, seqLen, hiddenSize]
selectedExpertsForTokens: [batchSize, seqLen, topK]
scoresForSelectedExperts: [batchSize, seqLen, topK]
fcGateWeights: [numExperts, hiddenSize, moeInterSize]
fcUpWeights: [numExperts, hiddenSize, moeInterSize]
fcDownWeights: [numExperts, moeInterSize, hiddenSize]
fcGateBias: [numExperts, moeInterSize]
fcUpBias: [numExperts, moeInterSize]
fcDownBias: [numExperts, hiddenSize]
fcDownActivationScale: [] or [numExperts]
fcDownActivationDblQScale: [] or [numExperts]
output: [batchSize, seqLen, hiddenSize]
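The shape list above fixes how the dimensions relate across inputs. The sketch below checks concrete shapes against it before building a network and returns the output shape. `check_moe_shapes` is a hypothetical helper, and the sizes in the usage (hiddenSize 32, moeInterSize 64) are illustrative only. A dynamic dimension such as seqLen = -1 must be replaced with a concrete value before checking.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]

REQUIRED = (
    "hiddenStates",
    "selectedExpertsForTokens",
    "scoresForSelectedExperts",
    "fcGateWeights",
    "fcUpWeights",
    "fcDownWeights",
)


def check_moe_shapes(shapes: Dict[str, Shape]) -> Shape:
    """Check concrete input shapes against the documented shape list; return the output shape.

    Optional inputs (biases, activation scales) may be omitted from `shapes`.
    Raises ValueError on a missing required input or a mismatched shape.
    """
    missing = [name for name in REQUIRED if name not in shapes]
    if missing:
        raise ValueError(f"missing required inputs: {missing}")
    b, s, h = shapes["hiddenStates"]
    k = shapes["selectedExpertsForTokens"][-1]
    e, _, i = shapes["fcGateWeights"]
    expected = {
        "selectedExpertsForTokens": (b, s, k),
        "scoresForSelectedExperts": (b, s, k),
        "fcGateWeights": (e, h, i),
        "fcUpWeights": (e, h, i),
        "fcDownWeights": (e, i, h),
        "fcGateBias": (e, i),
        "fcUpBias": (e, i),
        "fcDownBias": (e, h),
    }
    for name, want in expected.items():
        if name in shapes and tuple(shapes[name]) != want:
            raise ValueError(f"{name}: expected {want}, got {tuple(shapes[name])}")
    # Activation scales are per-tensor ([]) or per-expert ([numExperts]).
    for name in ("fcDownActivationScale", "fcDownActivationDblQScale"):
        if name in shapes and tuple(shapes[name]) not in ((), (e,)):
            raise ValueError(f"{name}: expected () or ({e},), got {tuple(shapes[name])}")
    return (b, s, h)


example = {
    "hiddenStates": (1, 2, 32),
    "selectedExpertsForTokens": (1, 2, 4),
    "scoresForSelectedExperts": (1, 2, 4),
    "fcGateWeights": (8, 32, 64),
    "fcUpWeights": (8, 32, 64),
    "fcDownWeights": (8, 64, 32),
    "fcDownActivationScale": (8,),
}
print(check_moe_shapes(example))  # (1, 2, 32)
```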
DLA Support
Not supported.
Notes
Currently, only the Thor platform is supported. Performance is also limited when seqLen > 16.
Currently, only the following configuration is supported:
Weights must be NVFP4 double quantized: each weight tensor must pass through a DQ (whose scale is itself the output of another DQ) and then, optionally, a Transpose before it is passed to the MoE layer. The weights must be DataType::kFP4 before the weight DQ, and the scales must be DataType::kFP8 before the scale DQ.
hiddenStates input must be FP8 quantized: hiddenStates must pass through a DQ before it is passed to the MoE layer, and must be DataType::kFP8 before that DQ.
Internal mul output activation must be FP8 quantized: setQuantizationStatic must be used, and quantizationToType must be DataType::kFP8.
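The NVFP4 double-quantization requirement can be read as the arithmetic below. Each block of weights stores FP4 (E2M1) codes plus one FP8 (E4M3) block scale, and that block scale is itself dequantized with a second, per-expert scale. The result is dequantized weight = fp4_code * (fp8_block_scale * dbl_q_scale). This is a pure-Python sketch of the number formats only, with hypothetical helper names. The block size of 16 along the last axis is taken from the example that follows. The rounding details (FP4 ties go to the smaller magnitude, FP8 saturates at ±448) are simplifying assumptions, not TensorRT kernel behavior. In this sketch, the weight quantization divides by the dequantized block scale so that Q and DQ agree. In the example, the weight Q receives the FP32 block scale while the weight DQ receives the FP8-derived one; with all-ones data the two coincide.

```python
import math
from typing import List, Sequence, Tuple

# Non-negative values representable in FP4 E2M1.
FP4_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_fp4_e2m1(x: float) -> float:
    """Nearest FP4 E2M1 value, saturating at +/-6 (ties go to the smaller magnitude here)."""
    mag = min(FP4_E2M1_MAGNITUDES, key=lambda v: abs(abs(x) - v))
    return math.copysign(mag, x)


def round_to_fp8_e4m3(x: float) -> float:
    """Nearest FP8 E4M3 value (3 mantissa bits, ties to even), saturating at +/-448."""
    if math.isnan(x) or x == 0.0:
        return x
    if math.isinf(x):
        return math.copysign(FP8_E4M3_MAX, x)
    _, e = math.frexp(abs(x))         # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)              # below 2**-6, subnormals share the spacing 2**-9
    step = 2.0 ** (exp - 3)           # spacing between neighbors in this binade
    q = round(abs(x) / step) * step   # Python's round() breaks ties to even
    return math.copysign(min(q, FP8_E4M3_MAX), x)


def quantize_nvfp4_block(
    weights: Sequence[float], block_scale: float, dbl_q_scale: float
) -> Tuple[List[float], float]:
    """Double-quantize one block: returns (FP4 codes, FP8 block scale)."""
    fp8_scale = round_to_fp8_e4m3(block_scale / dbl_q_scale)  # scale Q to FP8
    deq_scale = fp8_scale * dbl_q_scale                       # scale DQ
    if deq_scale == 0.0:
        raise ValueError("block scale underflows to zero in FP8")
    codes = [round_to_fp4_e2m1(w / deq_scale) for w in weights]  # weight Q to FP4
    return codes, fp8_scale


def dequantize_nvfp4_block(
    codes: Sequence[float], fp8_scale: float, dbl_q_scale: float
) -> List[float]:
    """Weight DQ with a dequantized scale: w = fp4_code * (fp8_scale * dbl_q_scale)."""
    deq_scale = fp8_scale * dbl_q_scale
    return [c * deq_scale for c in codes]


# Usage: one block of 16 weights, with the block scale chosen so the largest |w| maps to 6.
block = [0.1 * i for i in range(16)]
codes, fp8_scale = quantize_nvfp4_block(block, block_scale=max(block) / 6.0, dbl_q_scale=1.0)
restored = dequantize_nvfp4_block(codes, fp8_scale, dbl_q_scale=1.0)
```

Because the widest gap in the E2M1 grid is 2 (between 4 and 6), the reconstruction error in this block is at most half that gap times the block scale, here 0.25.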
Examples
MoE
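import numpy as np
import tensorrt as trt

# get_runner, inputs, outputs, and expected are assumed to be provided by the example harness (not shown).
get_runner.network = get_runner.builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))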
network = get_runner.network
batch_size = 1
seq_len = -1 # dynamic
num_experts = 8
top_k = 4
hidden_size = 32
moe_inter_size = 32
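block_size = 16  # NVFP4 block size: one weight scale per 16 elements along the last axis
seq_len_min, seq_len_opt, seq_len_max = 1, 64, 128  # optimization-profile range for the dynamic seq_len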
opt_profile = get_runner.builder.create_optimization_profile()
opt_profile.set_shape("hidden_states", (batch_size, seq_len_min, hidden_size), (batch_size, seq_len_opt, hidden_size), (batch_size, seq_len_max, hidden_size))
opt_profile.set_shape("selected_experts_for_tokens", (batch_size, seq_len_min, top_k), (batch_size, seq_len_opt, top_k), (batch_size, seq_len_max, top_k))
opt_profile.set_shape("scores_for_selected_experts", (batch_size, seq_len_min, top_k), (batch_size, seq_len_opt, top_k), (batch_size, seq_len_max, top_k))
get_runner.config.add_optimization_profile(opt_profile)
hidden_states = network.add_input("hidden_states", dtype=trt.float32, shape=(batch_size, seq_len, hidden_size))
selected_experts_for_tokens = network.add_input("selected_experts_for_tokens", dtype=trt.int32, shape=(batch_size, seq_len, top_k))
scores_for_selected_experts = network.add_input("scores_for_selected_experts", dtype=trt.float32, shape=(batch_size, seq_len, top_k))
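# Expert weights, created as [numExperts, out, in]; they are transposed before reaching the MoE layer.
fc_gate_weights_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size), dtype=np.float32)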
fc_up_weights_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size), dtype=np.float32)
fc_down_weights_data = np.ones(shape=(num_experts, hidden_size, moe_inter_size), dtype=np.float32)
fc_gate_weights = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size), weights=fc_gate_weights_data).get_output(0)
fc_up_weights = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size), weights=fc_up_weights_data).get_output(0)
fc_down_weights = network.add_constant(shape=(num_experts, hidden_size, moe_inter_size), weights=fc_down_weights_data).get_output(0)
# FP32 block scales: one scale per block_size weights along the last axis.
fc_gate_weight_scale_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size // block_size), dtype=np.float32)
fc_up_weight_scale_data = np.ones(shape=(num_experts, moe_inter_size, hidden_size // block_size), dtype=np.float32)
fc_down_weight_scale_data = np.ones(shape=(num_experts, hidden_size, moe_inter_size // block_size), dtype=np.float32)
fc_gate_weight_scale = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size // block_size), weights=fc_gate_weight_scale_data).get_output(0)
fc_up_weight_scale = network.add_constant(shape=(num_experts, moe_inter_size, hidden_size // block_size), weights=fc_up_weight_scale_data).get_output(0)
fc_down_weight_scale = network.add_constant(shape=(num_experts, hidden_size, moe_inter_size // block_size), weights=fc_down_weight_scale_data).get_output(0)
# Per-expert double-quantization scales, used to quantize the block scales themselves.
fc_gate_weight_dbl_q_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_up_weight_dbl_q_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_down_weight_dbl_q_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_gate_weight_dbl_q_scale = network.add_constant(shape=(num_experts,), weights=fc_gate_weight_dbl_q_scale_data).get_output(0)
fc_up_weight_dbl_q_scale = network.add_constant(shape=(num_experts,), weights=fc_up_weight_dbl_q_scale_data).get_output(0)
fc_down_weight_dbl_q_scale = network.add_constant(shape=(num_experts,), weights=fc_down_weight_dbl_q_scale_data).get_output(0)
# Scale Q/DQ: quantize the block scales to FP8 per expert (axis 0), then dequantize them.
fc_gate_weight_scale_quantize = network.add_quantize(fc_gate_weight_scale, fc_gate_weight_dbl_q_scale, trt.fp8)
fc_gate_weight_scale_quantize.axis = 0
fc_gate_weight_scale_quantized = fc_gate_weight_scale_quantize.get_output(0)
fc_up_weight_scale_quantize = network.add_quantize(fc_up_weight_scale, fc_up_weight_dbl_q_scale, trt.fp8)
fc_up_weight_scale_quantize.axis = 0
fc_up_weight_scale_quantized = fc_up_weight_scale_quantize.get_output(0)
fc_down_weight_scale_quantize = network.add_quantize(fc_down_weight_scale, fc_down_weight_dbl_q_scale, trt.fp8)
fc_down_weight_scale_quantize.axis = 0
fc_down_weight_scale_quantized = fc_down_weight_scale_quantize.get_output(0)
fc_gate_weight_scale_dequantize = network.add_dequantize(fc_gate_weight_scale_quantized, fc_gate_weight_dbl_q_scale, trt.float32)
fc_gate_weight_scale_dequantize.axis = 0
fc_gate_weight_scale_dequantized = fc_gate_weight_scale_dequantize.get_output(0)
fc_up_weight_scale_dequantize = network.add_dequantize(fc_up_weight_scale_quantized, fc_up_weight_dbl_q_scale, trt.float32)
fc_up_weight_scale_dequantize.axis = 0
fc_up_weight_scale_dequantized = fc_up_weight_scale_dequantize.get_output(0)
fc_down_weight_scale_dequantize = network.add_dequantize(fc_down_weight_scale_quantized, fc_down_weight_dbl_q_scale, trt.float32)
fc_down_weight_scale_dequantize.axis = 0
fc_down_weight_scale_dequantized = fc_down_weight_scale_dequantize.get_output(0)
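# Weight Q/DQ: block-quantize the weights to FP4 along axis 2, then dequantize with the FP8-derived block scales.
fc_gate_weights_quantize = network.add_quantize(fc_gate_weights, fc_gate_weight_scale, trt.fp4)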
fc_gate_weights_quantize.axis = 2
fc_gate_weights_quantized = fc_gate_weights_quantize.get_output(0)
fc_up_weights_quantize = network.add_quantize(fc_up_weights, fc_up_weight_scale, trt.fp4)
fc_up_weights_quantize.axis = 2
fc_up_weights_quantized = fc_up_weights_quantize.get_output(0)
fc_down_weights_quantize = network.add_quantize(fc_down_weights, fc_down_weight_scale, trt.fp4)
fc_down_weights_quantize.axis = 2
fc_down_weights_quantized = fc_down_weights_quantize.get_output(0)
fc_gate_weights_dequantize = network.add_dequantize(fc_gate_weights_quantized, fc_gate_weight_scale_dequantized, trt.float32)
fc_gate_weights_dequantize.axis = 2
fc_gate_weights_dequantized = fc_gate_weights_dequantize.get_output(0)
fc_up_weights_dequantize = network.add_dequantize(fc_up_weights_quantized, fc_up_weight_scale_dequantized, trt.float32)
fc_up_weights_dequantize.axis = 2
fc_up_weights_dequantized = fc_up_weights_dequantize.get_output(0)
fc_down_weights_dequantize = network.add_dequantize(fc_down_weights_quantized, fc_down_weight_scale_dequantized, trt.float32)
fc_down_weights_dequantize.axis = 2
fc_down_weights_dequantized = fc_down_weights_dequantize.get_output(0)
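# Optional Transpose to the documented layouts: [numExperts, hiddenSize, moeInterSize] for fcGate/fcUp, [numExperts, moeInterSize, hiddenSize] for fcDown.
fc_gate_weights_shuffle = network.add_shuffle(fc_gate_weights_dequantized)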
fc_gate_weights_shuffle.first_transpose = trt.Permutation([0, 2, 1])
fc_gate_weights_transposed = fc_gate_weights_shuffle.get_output(0)
fc_up_weights_shuffle = network.add_shuffle(fc_up_weights_dequantized)
fc_up_weights_shuffle.first_transpose = trt.Permutation([0, 2, 1])
fc_up_weights_transposed = fc_up_weights_shuffle.get_output(0)
fc_down_weights_shuffle = network.add_shuffle(fc_down_weights_dequantized)
fc_down_weights_shuffle.first_transpose = trt.Permutation([0, 2, 1])
fc_down_weights_transposed = fc_down_weights_shuffle.get_output(0)
# hiddenStates must be FP8 quantized: per-tensor Q/DQ with a scalar scale.
input_activation_scale_data = np.ones(shape=(), dtype=np.float32)
input_activation_scale = network.add_constant(shape=(), weights=input_activation_scale_data).get_output(0)
hidden_states_quantized = network.add_quantize(hidden_states, input_activation_scale, trt.fp8).get_output(0)
hidden_states_dequantized = network.add_dequantize(hidden_states_quantized, input_activation_scale, trt.float32).get_output(0)
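# Per-expert static scale for the FP8 quantization of the mul output (the fcDown activation).
fc_down_activation_scale_data = np.ones(shape=(num_experts,), dtype=np.float32)
fc_down_activation_scale = network.add_constant(shape=(num_experts,), weights=fc_down_activation_scale_data).get_output(0)
# MoE layer: gated weights with SiLU activation and static FP8 quantization of the mul output.
layer = network.add_moe(hidden_states_dequantized, selected_experts_for_tokens, scores_for_selected_experts)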
layer.set_gated_weights(fc_gate_weights_transposed, fc_up_weights_transposed, fc_down_weights_transposed, trt.MoEActType.SILU)
layer.set_quantization_static(fc_down_activation_scale, trt.fp8)
network.mark_output(layer.get_output(0))
# Runtime inputs: seq_len = 2 lies within the optimization-profile range [1, 128].
seq_len = 2
inputs[hidden_states.name] = np.ones(shape=(batch_size, seq_len, hidden_size), dtype=np.float32)
inputs[selected_experts_for_tokens.name] = np.random.randint(0, num_experts, (batch_size, seq_len, top_k), dtype=np.int32)
inputs[scores_for_selected_experts.name] = np.ones(shape=(batch_size, seq_len, top_k), dtype=np.float32)
outputs[layer.get_output(0).name] = layer.get_output(0).shape
expected[layer.get_output(0).name] = np.ones(shape=(batch_size, seq_len, hidden_size), dtype=np.float32) * 57344.0
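The expected value 57344.0 is not the unquantized result. Each expert's mul output is about 32 × 32 = 1024, which exceeds the FP8 E4M3 maximum of 448. The expected value is consistent with that output being saturated to 448: 4 × 32 × 448 = 57344, whereas without saturation it would be 4 × 32 × 1024 = 131072. The scalar sketch below traces this. It assumes the gated structure down(SiLU(gate(x)) * up(x)) and a saturating FP8 conversion, and models only the positive saturation that matters here. Because every expert has identical all-ones weights, repeated indices from `np.random.randint` do not change the result.

```python
import math

# Values taken from the example above.
top_k = 4
hidden_size = 32
moe_inter_size = 32
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude


def silu(x: float) -> float:
    """SiLU(x) = x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    z = math.exp(x)
    return x * z / (1.0 + z)


def expected_output_element() -> float:
    """Trace one element of the example's output with all inputs, weights, and scales at 1.0."""
    x = 1.0                         # 1.0 survives FP8 Q/DQ with scale 1 unchanged
    gate = hidden_size * x          # dot product with a row of FP4 ones: 32
    up = hidden_size * x            # 32
    mul = silu(gate) * up           # SiLU(32) is ~32, so the mul output is ~1024
    mul_q = min(mul, FP8_E4M3_MAX)  # FP8 quantization (scale 1) saturates at 448
    down = moe_inter_size * mul_q   # fcDown: 32 * 448 = 14336 per expert
    return top_k * 1.0 * down       # 4 selected experts, each with score 1.0


print(expected_output_element())  # 57344.0
```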
C++ API
For more information about the C++ IMoELayer operator, refer to the C++ IMoELayer documentation.
Python API
For more information about the Python IMoELayer operator, refer to the Python IMoELayer documentation.