TensorRT 10.16.0
nvinfer1::IMoELayer Class Reference

A MoE layer in a network definition. Mixture of Experts (MoE) is a collection of experts, each specializing in processing different subsets of the input data. The key innovation lies in using a Router that selectively activates only the specific experts needed for a given input, rather than engaging the entire network for every task. More...

#include <NvInfer.h>

Inheritance diagram for nvinfer1::IMoELayer:
nvinfer1::IMoELayer → nvinfer1::ILayer → nvinfer1::INoCopy

Public Member Functions

void setGatedWeights (ITensor &fcGateWeights, ITensor &fcUpWeights, ITensor &fcDownWeights, MoEActType activationType) noexcept
 Set the weights of the experts when each expert is a GLU (gated linear unit). In each GLU, there are 3 linear layers and 1 activation function, so this function requires 3 weight tensors and 1 activation type. More...
 
void setGatedBiases (ITensor &fcGateBiases, ITensor &fcUpBiases, ITensor &fcDownBiases) noexcept
 Set the biases of the experts when each expert is a GLU (gated linear unit). In each GLU, there are 3 linear layers, so this function requires 3 bias tensors. More...
 
void setActivationType (MoEActType activationType) noexcept
 Set the activation type for the MoE layer. More...
 
MoEActType getActivationType () const noexcept
 Get the activation type for the MoE layer. More...
 
void setQuantizationStatic (ITensor &fcDownActivationScale, DataType dataType) noexcept
 Configure static quantization after the mul op.

                 ┌── fcGate ── activation ───┐
                 │                           │
hiddenStates ────┤                           ├── mul ── {Q ── DQ} ── fcDown ── output
                 │                           │
                 └── fcUp ───────────────────┘

When using mul output static quantization, the user must provide: More...
 
void setQuantizationDynamicDblQ (ITensor &fcDownActivationDblQScale, DataType dataType, Dims const &blockShape, DataType dynQOutputScaleType) noexcept
 Configure dynamic quantization (with double quantization) after the mul op.

                 ┌── fcGate ── activation ───┐           ┌──── DQ
                 │                           │           │
hiddenStates ────┤                           ├── mul ── {DynQ ── DQ} ── fcDown ── output
                 │                           │
                 └── fcUp ───────────────────┘

When using mul output dynamic quantization (with double quantization), the user must provide: More...
 
void setQuantizationToType (DataType type) noexcept
 Set the data type the mul output is quantized to. More...
 
DataType getQuantizationToType () const noexcept
 Get the data type the mul in MoE layer is quantized to. More...
 
void setQuantizationBlockShape (Dims const &blockShape) noexcept
 Set the block shape for the quantization of the Mul output. More...
 
Dims getQuantizationBlockShape () const noexcept
 Get the block shape for the quantization of the Mul output. More...
 
void setDynQOutputScaleType (DataType type) noexcept
 Set the dynamic quantization output scale type. More...
 
DataType getDynQOutputScaleType () const noexcept
 Get the dynamic quantization output scale type. More...
 
void setSwigluParams (float limit, float alpha, float beta) noexcept
 Set the SwiGLU parameters. More...
 
void setSwigluParamLimit (float limit) noexcept
 Set the SwiGLU parameter limit. More...
 
float getSwigluParamLimit () const noexcept
 Get the SwiGLU parameter limit. More...
 
void setSwigluParamAlpha (float alpha) noexcept
 Set the SwiGLU parameter alpha. More...
 
float getSwigluParamAlpha () const noexcept
 Get the SwiGLU parameter alpha. More...
 
void setSwigluParamBeta (float beta) noexcept
 Set the SwiGLU parameter beta. More...
 
float getSwigluParamBeta () const noexcept
 Get the SwiGLU parameter beta. More...
 
void setInput (int32_t index, ITensor &tensor) noexcept
 Set the input of the MoE layer. More...
 
void setInput (int32_t index, ITensor &tensor) noexcept
 Replace an input of this layer with a specific tensor. More...
 
- Public Member Functions inherited from nvinfer1::ILayer
LayerType getType () const noexcept
 Return the type of a layer. More...
 
void setName (char const *name) noexcept
 Set the name of a layer. More...
 
char const * getName () const noexcept
 Return the name of a layer. More...
 
int32_t getNbInputs () const noexcept
 Get the number of inputs of a layer. More...
 
ITensor * getInput (int32_t index) const noexcept
 Get the layer input corresponding to the given index. More...
 
int32_t getNbOutputs () const noexcept
 Get the number of outputs of a layer. More...
 
ITensor * getOutput (int32_t index) const noexcept
 Get the layer output corresponding to the given index. More...
 
void setInput (int32_t index, ITensor &tensor) noexcept
 Replace an input of this layer with a specific tensor. More...
 
TRT_DEPRECATED void setPrecision (DataType dataType) noexcept
 Set the preferred or required computational precision of this layer in a weakly-typed network. More...
 
DataType getPrecision () const noexcept
 Get the computational precision of this layer. More...
 
TRT_DEPRECATED bool precisionIsSet () const noexcept
 Whether the computational precision has been set for this layer. More...
 
TRT_DEPRECATED void resetPrecision () noexcept
 Reset the computational precision for this layer. More...
 
TRT_DEPRECATED void setOutputType (int32_t index, DataType dataType) noexcept
 Set the output type of this layer in a weakly-typed network. More...
 
DataType getOutputType (int32_t index) const noexcept
 Get the output type of this layer. More...
 
TRT_DEPRECATED bool outputTypeIsSet (int32_t index) const noexcept
 Whether the output type has been set for this layer. More...
 
TRT_DEPRECATED void resetOutputType (int32_t index) noexcept
 Reset the output type for this layer. More...
 
void setMetadata (char const *metadata) noexcept
 Set the metadata for this layer. More...
 
char const * getMetadata () const noexcept
 Get the metadata of the layer. More...
 
bool setNbRanks (int32_t nbRanks) noexcept
 Set the number of ranks for multi-device execution. More...
 
int32_t getNbRanks () const noexcept
 Get the number of ranks for multi-device execution. More...
 

Protected Member Functions

virtual ~IMoELayer () noexcept=default
 
- Protected Member Functions inherited from nvinfer1::ILayer
virtual ~ILayer () noexcept=default
 
- Protected Member Functions inherited from nvinfer1::INoCopy
 INoCopy ()=default
 
virtual ~INoCopy ()=default
 
 INoCopy (INoCopy const &other)=delete
 
INoCopy & operator= (INoCopy const &other)=delete
 
 INoCopy (INoCopy &&other)=delete
 
INoCopy & operator= (INoCopy &&other)=delete
 

Protected Attributes

apiv::VMoELayer * mImpl
 
- Protected Attributes inherited from nvinfer1::ILayer
apiv::VLayer * mLayer
 

Detailed Description

A MoE layer in a network definition. Mixture of Experts (MoE) is a collection of experts with each expert specializing in processing different subsets of input data. The key innovation lies in using a Router that selectively activates only the specific experts needed for a given input, rather than engaging the entire neural network for every task.

┌──────────────┐  ┌────────────────────────┐  ┌────────────────────────┐
│ hiddenStates │  │selectedExpertsForTokens│  │scoresForSelectedExperts│
└──────┬───────┘  └───────────┬────────────┘  └───────────┬────────────┘
       │                      │                           │
┌──────▼──────────────────────▼───────────────────────────▼──── MOE ─────┐
│                                                                        │
│  ┌──────────────────────────┐         ┌──────────────────────────┐     │
│  │         Expert 0         │         │         Expert i         │     │
│  │                          │         │                          │     │
│  │  ┌────────┐  ┌────────┐  │         │  ┌────────┐  ┌────────┐  │     │
│  │  │ fcGate │  │  fcUp  │  │         │  │ fcGate │  │  fcUp  │  │     │
│  │  └───┬────┘  └───┬────┘  │         │  └───┬────┘  └───┬────┘  │     │
│  │ ┌────▼─────┐     │       │         │ ┌────▼─────┐     │       │     │
│  │ │activation│     │       │  .....  │ │activation│     │       │     │
│  │ └────┬─────┘     │       │         │ └────┬─────┘     │       │     │
│  │      └────┬──────┘       │         │      └────┬──────┘       │     │
│  │       ┌───▼────┐         │         │       ┌───▼────┐         │     │
│  │       │  mul   │         │         │       │  mul   │         │     │
│  │       └───┬────┘         │         │       └───┬────┘         │     │
│  │       ┌───▼────┐         │         │       ┌───▼────┐         │     │
│  │       │ fcDown │         │         │       │ fcDown │         │     │
│  │       └───┬────┘         │         │       └───┬────┘         │     │
│  │       ┌───▼────┐         │         │       ┌───▼────┐         │     │
│  │       │output 0│         │         │       │output i│         │     │
│  │       └───┬────┘         │         │       └───┬────┘         │     │
│  └───────────┼──────────────┘         └───────────┼──────────────┘     │
│              │                                    │                    │
│              └─────────────────┬──────────────────┘                    │
│                        ┌───────▼───────┐                               │
│                        │  weightedSum  │                               │
│                        └───────┬───────┘                               │
└────────────────────────────────┼───────────────────────────────────────┘
                         ┌───────▼───────┐
                         │   moeOutput   │
                         └───────────────┘

Definitions in the MoE layer:

  • fcGate, fcUp, fcDown are three linear layers: fc(x) = x * w + b, where x is the input, w is the weight, b is the bias, and * is matrix multiplication.
  • activation is the activation function.
  • mul is the elementwise multiplication of the fcUp output and the activated fcGate output.
  • weightedSum is the weighted sum of the expert outputs.
  • moeOutput is the output of the MoE layer.

MoE is a collection of experts. Each expert is a GLU (gated linear unit), which consists of fcGate, fcUp, fcDown, activation, and mul.

Definitions and Abbreviations:

  • batchSize: batch size
  • seqLen: sequence length
  • hiddenSize: the size of the hidden states
  • numExperts: the number of experts in the MoE layer
  • moeInterSize: the intermediate size of the MoE layer
  • topK: the number of experts selected for each token

This layer takes several activation inputs:

  1. hiddenStates: the hidden states of the layer, with shape [batchSize, seqLen, hiddenSize]
  2. selectedExpertsForTokens: the top K experts selected for each token, with shape [batchSize, seqLen, topK]
  3. scoresForSelectedExperts: the scales for the selected experts per token, with shape [batchSize, seqLen, topK]

The MoE layer takes the selected experts and the corresponding scales for those experts to compute the output.

The weights in the MoE layer:

  1. fcGateWeights with shape [numExperts, hiddenSize, moeInterSize]: the weight matrix for fcGate
  2. fcUpWeights with shape [numExperts, hiddenSize, moeInterSize]: the weight matrix for fcUp
  3. fcDownWeights with shape [numExperts, moeInterSize, hiddenSize]: the weight matrix for fcDown

Several optional inputs are supported:

  1. fcGateBias: the bias for the fcGate, with shape [numExperts, moeInterSize]
  2. fcUpBias: the bias for the fcUp, with shape [numExperts, moeInterSize]
  3. fcDownBias: the bias for the fcDown, with shape [numExperts, hiddenSize]. All biases are unset by default; you must set either all of the biases or none of them.
  4. activation: the activation type for the MoE layer; currently only kSILU is supported.

MoE computation process description: For each token, the MoE layer computation process is as follows:

  1. Input processing:
    • Receive hiddenStates:
    • Receive selectedExpertsForTokens:
    • Receive scoresForSelectedExperts:
  2. Expert computation for each token:
    • output_i = fcDown(fcUp(hiddenStates) * activation(fcGate(hiddenStates)))
  3. Expert output aggregation: For each token, first select all the experts that need to be activated.
    • Compute each selected expert's output according to the expert IDs in selectedExpertsForTokens.
    • Compute the weighted sum of the expert outputs using the weights in scoresForSelectedExperts.
    • Final output for the token: moeOutput = Σ(score_i * output_i)

The output of the MoE layer has the same shape as the input hiddenStates.
Warning
MoE is only supported on Thor, and performance is limited when seqLen > 16.
Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
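The per-token dataflow above can be sketched as a plain C++ reference implementation. This is illustrative only, not the TensorRT API: the names `Expert`, `expertForward`, and `moeToken` are hypothetical helpers, and the real layer runs fused GPU kernels rather than loops.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// silu(x) = x * sigmoid(x); corresponds to the kSILU activation.
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// fc(x) = x * w + b. w is row-major with shape [in, out]; b may be empty.
inline std::vector<float> fc(std::vector<float> const& x,
                             std::vector<float> const& w,
                             std::vector<float> const& b, std::size_t out) {
    std::size_t const in = x.size();
    std::vector<float> y(out, 0.0f);
    for (std::size_t o = 0; o < out; ++o) {
        float acc = b.empty() ? 0.0f : b[o];
        for (std::size_t i = 0; i < in; ++i) acc += x[i] * w[i * out + o];
        y[o] = acc;
    }
    return y;
}

// Hypothetical container for one expert's GLU weights/biases.
struct Expert {
    std::vector<float> wGate, wUp, wDown;  // [hidden,inter], [hidden,inter], [inter,hidden]
    std::vector<float> bGate, bUp, bDown;  // optional; empty means no bias
};

// output_i = fcDown(fcUp(h) * activation(fcGate(h)))
inline std::vector<float> expertForward(std::vector<float> const& h,
                                        Expert const& e, std::size_t inter) {
    std::vector<float> gate = fc(h, e.wGate, e.bGate, inter);
    std::vector<float> up = fc(h, e.wUp, e.bUp, inter);
    std::vector<float> mul(inter);
    for (std::size_t j = 0; j < inter; ++j) mul[j] = up[j] * silu(gate[j]);
    return fc(mul, e.wDown, e.bDown, h.size());
}

// moeOutput = sum_k score_k * output_{selected_k}; same shape as h.
inline std::vector<float> moeToken(std::vector<float> const& h,
                                   std::vector<Expert> const& experts,
                                   std::vector<int> const& selected,
                                   std::vector<float> const& scores,
                                   std::size_t inter) {
    std::vector<float> out(h.size(), 0.0f);
    for (std::size_t k = 0; k < selected.size(); ++k) {
        std::vector<float> o = expertForward(h, experts[selected[k]], inter);
        for (std::size_t d = 0; d < h.size(); ++d) out[d] += scores[k] * o[d];
    }
    return out;
}
```

Note that only the experts listed in `selected` are evaluated, which is the point of the Router: compute scales with the number of selected experts (topK), not with numExperts.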

Constructor & Destructor Documentation

◆ ~IMoELayer()

virtual nvinfer1::IMoELayer::~IMoELayer ( )
protected virtual default noexcept

Member Function Documentation

◆ getActivationType()

MoEActType nvinfer1::IMoELayer::getActivationType ( ) const
inline noexcept

Get the activation type for the MoE layer.

See also
setActivationType()
Returns
the activation type for the MoE layer.

◆ getDynQOutputScaleType()

DataType nvinfer1::IMoELayer::getDynQOutputScaleType ( ) const
inline noexcept

Get the dynamic quantization output scale type.

See also
setDynQOutputScaleType()
Returns
the dynamic quantization output scale type.

◆ getQuantizationBlockShape()

Dims nvinfer1::IMoELayer::getQuantizationBlockShape ( ) const
inline noexcept

Get the block shape for the quantization of the Mul output.

See also
setQuantizationBlockShape()
Returns
the block shape for the quantization of the Mul output.

◆ getQuantizationToType()

DataType nvinfer1::IMoELayer::getQuantizationToType ( ) const
inline noexcept

Get the data type the mul in MoE layer is quantized to.

See also
setQuantizationToType()
Returns
the data type the mul in MoE layer is quantized to.

◆ getSwigluParamAlpha()

float nvinfer1::IMoELayer::getSwigluParamAlpha ( ) const
inline noexcept

Get the SwiGLU parameter alpha.

See also
setSwigluParamAlpha()
Returns
the SwiGLU parameter alpha.

◆ getSwigluParamBeta()

float nvinfer1::IMoELayer::getSwigluParamBeta ( ) const
inline noexcept

Get the SwiGLU parameter beta.

See also
setSwigluParamBeta()
Returns
the SwiGLU parameter beta.

◆ getSwigluParamLimit()

float nvinfer1::IMoELayer::getSwigluParamLimit ( ) const
inline noexcept

Get the SwiGLU parameter limit.

See also
setSwigluParamLimit()
Returns
the SwiGLU parameter limit.

◆ setActivationType()

void nvinfer1::IMoELayer::setActivationType ( MoEActType  activationType)
inline noexcept

Set the activation type for the MoE layer.

Parameters
activationType: the activation type for the MoE layer.
See also
getActivationType()

◆ setDynQOutputScaleType()

void nvinfer1::IMoELayer::setDynQOutputScaleType ( DataType  type)
inline noexcept

Set the dynamic quantization output scale type.

Parameters
type: the dynamic quantization output scale type.
See also
getDynQOutputScaleType()

◆ setGatedBiases()

void nvinfer1::IMoELayer::setGatedBiases ( ITensor &  fcGateBiases,
ITensor &  fcUpBiases,
ITensor &  fcDownBiases 
)
inline noexcept

Set the biases of the experts when each expert is a GLU (gated linear unit). In each GLU, there are 3 linear layers, so this function requires 3 bias tensors.

Parameters
fcGateBiases: The biases for the gate-projection layer of all experts in MoE. Shape: [numExperts, moeInterSize].
fcUpBiases: The biases for the up-projection layer of all experts in MoE. Shape: [numExperts, moeInterSize].
fcDownBiases: The biases for the down-projection layer of all experts in MoE. Shape: [numExperts, hiddenSize].

◆ setGatedWeights()

void nvinfer1::IMoELayer::setGatedWeights ( ITensor &  fcGateWeights,
ITensor &  fcUpWeights,
ITensor &  fcDownWeights,
MoEActType  activationType 
)
inline noexcept

Set the weights of the experts when each expert is a GLU (gated linear unit). In each GLU, there are 3 linear layers and 1 activation function, so this function requires 3 weight tensors and 1 activation type.

Parameters
fcGateWeights: The weights for the gate-projection layer of all experts in MoE. Shape: [numExperts, hiddenSize, moeInterSize].
fcUpWeights: The weights for the up-projection layer of all experts in MoE. Shape: [numExperts, hiddenSize, moeInterSize].
fcDownWeights: The weights for the down-projection layer of all experts in MoE. Shape: [numExperts, moeInterSize, hiddenSize].
activationType: The activation function to use for the MoE layer. Currently only kSILU is supported.
See also
setActivationType()
getActivationType()

◆ setInput() [1/2]

void nvinfer1::IMoELayer::setInput ( int32_t  index,
ITensor &  tensor 
)
inline noexcept

Set the input of the MoE layer.

Parameters
index: the index of the input to modify.
tensor: the new input tensor.

The indices are as follows:

Input 0: hiddenStates: the input activations, with shape [batchSize, seqLen, hiddenSize]
Input 1: selectedExpertsForTokens: the selected experts for tokens, with shape [batchSize, seqLen, topK]
Input 2: scoresForSelectedExperts: the scores for selected experts, with shape [batchSize, seqLen, topK]

◆ setInput() [2/2]

void nvinfer1::ILayer::setInput ( int32_t  index,
ITensor &  tensor 
)
inline noexcept

Replace an input of this layer with a specific tensor.

Parameters
index: the index of the input to modify.
tensor: the new input tensor.

Except for IFillLayer, ILoopOutputLayer, INMSLayer, IResizeLayer, IShuffleLayer, and ISliceLayer, this method cannot change the number of inputs to a layer. The index argument must be less than the value of getNbInputs().

See comments for overloads of setInput() for layers with special behavior.

◆ setQuantizationBlockShape()

void nvinfer1::IMoELayer::setQuantizationBlockShape ( Dims const &  blockShape)
inline noexcept

Set the block shape for the quantization of the Mul output.

Parameters
blockShape: the block shape for the quantization of the Mul output.

The shape must have rank 3, with the dimensions representing block sizes for the Mul output dimensions (batchSize, seqLen, moeInterSize) respectively. For example, a shape of [1, 1, 16] means block quantization on the last (moeInterSize) axis. -1 means a fully blocked dimension.

See also
getQuantizationBlockShape()
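To illustrate the granularity a block shape such as [1, 1, 16] implies, the standalone sketch below computes one scale per contiguous block along the last axis. The `blockScales` helper and the symmetric amax-based scaling rule are assumptions for illustration, not TensorRT internals; `qMax` stands for the largest representable magnitude of the target type (e.g., 448 for FP8 E4M3).

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One scale per contiguous run of blockSize elements along the last axis of
// the mul output. Dequantization would then be q * scale per block.
std::vector<float> blockScales(std::vector<float> const& mulOut,
                               std::size_t blockSize, float qMax) {
    std::vector<float> scales;
    for (std::size_t start = 0; start < mulOut.size(); start += blockSize) {
        float amax = 0.0f;
        std::size_t const end = std::min(start + blockSize, mulOut.size());
        for (std::size_t i = start; i < end; ++i)
            amax = std::max(amax, std::fabs(mulOut[i]));
        scales.push_back(amax / qMax);  // symmetric amax-based scale
    }
    return scales;
}
```

A smaller block size tracks local dynamic range more tightly at the cost of more scale values; a fully blocked dimension (-1) collapses that axis into a single block.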

◆ setQuantizationDynamicDblQ()

void nvinfer1::IMoELayer::setQuantizationDynamicDblQ ( ITensor fcDownActivationDblQScale,
DataType  dataType,
Dims const &  blockShape,
DataType  dynQOutputScaleType 
)
inline noexcept

Configure dynamic quantization (with double quantization) after the mul op.

                 ┌── fcGate ── activation ───┐           ┌──── DQ
                 │                           │           │
hiddenStates ────┤                           ├── mul ── {DynQ ── DQ} ── fcDown ── output
                 │                           │
                 └── fcUp ───────────────────┘

When using mul output dynamic quantization (with double quantization), the user must provide:

Parameters
fcDownActivationDblQScale: the double quantization scale tensor.
dataType: the type that the activation is quantized to.
blockShape: the block shape used in quantization.
dynQOutputScaleType: the data type of the scale tensor.

In addition, the user should also insert DynQ/DQ/DQ before the hiddenStates input of the MoE layer. The quantization method must be the same as the quantization method configured here.

If setQuantizationStatic is called, then previous calls to this function are overridden. If setQuantizationToType, setQuantizationBlockShape or setDynQOutputScaleType is called, previous parameters set by this function are overridden.

See also
setQuantizationToType()
getQuantizationToType()
setQuantizationBlockShape()
getQuantizationBlockShape()
setDynQOutputScaleType()
getDynQOutputScaleType()

◆ setQuantizationStatic()

void nvinfer1::IMoELayer::setQuantizationStatic ( ITensor &  fcDownActivationScale,
DataType  dataType 
)
inline noexcept

Configure static quantization after the mul op.

                 ┌── fcGate ── activation ───┐
                 │                           │
hiddenStates ────┤                           ├── mul ── {Q ── DQ} ── fcDown ── output
                 │                           │
                 └── fcUp ───────────────────┘

When using mul output static quantization, the user must provide:

Parameters
fcDownActivationScale: the scale tensor.
dataType: the type that the activation is quantized to.

In addition, the user should also insert Q/DQ before the hiddenStates input of the MoE layer. The quantization method must be the same as the quantization method configured here.

If setQuantizationDynamicDblQ is called, then previous calls to this function are overridden. If setQuantizationToType is called, previous parameters set by this function are overridden.

See also
setQuantizationToType()
getQuantizationToType()

◆ setQuantizationToType()

void nvinfer1::IMoELayer::setQuantizationToType ( DataType  type)
inline noexcept

Set the data type the mul output is quantized to.

Parameters
type: the data type the mul output is quantized to. The type must be one of DataType::kFP8, DataType::kFP4.

Default: DataType::kFLOAT which means the MoE layer is not quantized.

See also
getQuantizationToType()

◆ setSwigluParamAlpha()

void nvinfer1::IMoELayer::setSwigluParamAlpha ( float  alpha)
inline noexcept

Set the SwiGLU parameter alpha.

Parameters
alpha: the SwiGLU parameter alpha.

Default: 1.0

See also
getSwigluParamAlpha()

◆ setSwigluParamBeta()

void nvinfer1::IMoELayer::setSwigluParamBeta ( float  beta)
inline noexcept

Set the SwiGLU parameter beta.

Parameters
beta: the SwiGLU parameter beta.

Default: 0.0

See also
getSwigluParamBeta()

◆ setSwigluParamLimit()

void nvinfer1::IMoELayer::setSwigluParamLimit ( float  limit)
inline noexcept

Set the SwiGLU parameter limit.

Parameters
limit: the SwiGLU parameter limit.

Default: +inf

See also
getSwigluParamLimit()

◆ setSwigluParams()

void nvinfer1::IMoELayer::setSwigluParams ( float  limit,
float  alpha,
float  beta 
)
inline noexcept

Set the SwiGLU parameters.

Parameters
limit: the SwiGLU parameter limit.
alpha: the SwiGLU parameter alpha.
beta: the SwiGLU parameter beta.

Default: +inf, 1.0, 0.0

See also
setSwigluParamLimit()
getSwigluParamLimit()
setSwigluParamAlpha()
getSwigluParamAlpha()
setSwigluParamBeta()
getSwigluParamBeta()
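This reference does not spell out how limit, alpha, and beta enter the activation. One common clamped-SwiGLU parameterization that reduces to the plain SiLU-gated product under the defaults (limit = +inf, alpha = 1.0, beta = 0.0) is sketched below; the clamping scheme and the placement of beta are assumptions for illustration, not the confirmed TensorRT formula.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical clamped SwiGLU applied elementwise to the gate/up pair.
// With the defaults (limit = +inf, alpha = 1, beta = 0) this reduces to
// up * gate * sigmoid(gate), i.e., the standard SiLU-gated product.
float swiglu(float gate, float up, float limit, float alpha, float beta) {
    gate = std::min(gate, limit);            // clamp gate from above
    up = std::clamp(up, -limit, limit);      // clamp up symmetrically
    return gate / (1.0f + std::exp(-alpha * gate)) * (up + beta);
}
```

Under this reading, limit bounds the pre-activation values, alpha sharpens or flattens the sigmoid gate, and beta offsets the linear branch.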

Member Data Documentation

◆ mImpl

apiv::VMoELayer* nvinfer1::IMoELayer::mImpl
protected

The documentation for this class was generated from the following file:

  NvInfer.h

  Copyright © 2024 NVIDIA Corporation