TensorRT 11.0.0
nvinfer1::IAttention Class Reference (abstract)

Helper for constructing an attention that consumes query, key and value tensors.

#include <NvInfer.h>

Inherits nvinfer1::INoCopy.

Public Member Functions

bool setNormalizationOperation (AttentionNormalizationOp op) noexcept
 Set the normalization operation for the attention.

AttentionNormalizationOp getNormalizationOperation () const noexcept
 Get the normalization operation for the attention.

bool setMask (ITensor &mask) noexcept
 Set the mask to be used for the normalization operation.

ITensor * getMask () noexcept
 Get the optional mask in attention.

TRT_DEPRECATED bool setCausal (bool isCausal) noexcept
 Set whether the attention will run a causal inference. Cannot be used together with setMask().

TRT_DEPRECATED bool getCausal () const noexcept
 Get whether the attention will run a causal inference.

bool setCausalKind (CausalMaskKind kind) noexcept
 Set the causal mask alignment orientation for the attention.

CausalMaskKind getCausalKind () const noexcept
 Get the causal mask alignment orientation for the attention.

bool setDecomposable (bool decomposable) noexcept
 Set whether the attention can be decomposed to use multiple kernels if no fused kernel support is found.

bool getDecomposable () const noexcept
 Get whether the attention can be decomposed to use multiple kernels if no fused kernel support is found.

bool setInput (int32_t index, ITensor &input) noexcept
 Append or replace an input of this layer with a specific tensor.

int32_t getNbInputs () const noexcept
 Get the number of inputs of IAttention. IAttention has three inputs.

ITensor * getInput (int32_t index) const noexcept
 Get the IAttention input corresponding to the given index.

int32_t getNbOutputs () const noexcept
 Get the number of outputs of a layer. IAttention has one output.

ITensor * getOutput (int32_t index) const noexcept
 Get the IAttention output corresponding to the given index. IAttention has only one output.

bool setName (char const *name) noexcept
 Set the name of the attention.

char const * getName () const noexcept
 Return the name of the attention.

bool setNormalizationQuantizeScale (ITensor &tensor) noexcept
 Set the quantization scale for the attention normalization output.

ITensor * getNormalizationQuantizeScale () const noexcept
 Get the quantization scale for the attention normalization output.

bool setNormalizationQuantizeToType (DataType type) noexcept
 Set the datatype the attention normalization is quantized to.

DataType getNormalizationQuantizeToType () const noexcept
 Get the datatype the attention normalization is quantized to.

bool setMetadata (char const *metadata) noexcept
 Set the metadata for IAttention.

char const * getMetadata () const noexcept
 Get the metadata of IAttention.

bool setNbRanks (int32_t nbRanks) noexcept
 Set the number of ranks for multi-device attention execution.

int32_t getNbRanks () const noexcept
 Get the number of ranks for multi-device execution.

TRT_NODISCARD bool setQueryForm (AttentionIOForm form) noexcept
 Set the query form.

TRT_NODISCARD AttentionIOForm getQueryForm () const noexcept
 Get the query form.

TRT_NODISCARD bool setKeyValueForm (AttentionIOForm form) noexcept
 Set the key-value form.

TRT_NODISCARD AttentionIOForm getKeyValueForm () const noexcept
 Get the key-value form.

TRT_NODISCARD bool setQueryLengths (ITensor *lengths) noexcept
 Set the query lengths tensor.

TRT_NODISCARD ITensor * getQueryLengths () const noexcept
 Get the query lengths tensor.

TRT_NODISCARD bool setKeyValueLengths (ITensor *lengths) noexcept
 Set the key-value lengths tensor.

TRT_NODISCARD ITensor * getKeyValueLengths () const noexcept
 Get the key-value lengths tensor.
 

Protected Member Functions

virtual ~IAttention () noexcept=0
 
Protected Member Functions inherited from nvinfer1::INoCopy

 INoCopy ()=default

virtual ~INoCopy ()=default

 INoCopy (INoCopy const &other)=delete

INoCopy & operator= (INoCopy const &other)=delete

 INoCopy (INoCopy &&other)=delete

INoCopy & operator= (INoCopy &&other)=delete
 

Protected Attributes

apiv::VAttention * mImpl
 

Detailed Description

Helper for constructing an attention that consumes query, key and value tensors.

An attention subgraph implicitly includes three main components: two MatrixMultiply layers, known as BMM1 and BMM2, and one normalization operation, which defaults to Softmax. By default, IAttention is not decomposable and TensorRT will try to use a single fused kernel, which may be more efficient than expressing the subgraph without IAttention. Calling setDecomposable(true) allows IAttention to be decomposed into multiple kernels if no fused kernel support is found.

    Query     Key    Value    Mask (optional)    NormalizationQuantizeScale (optional)
      |        |       |        |                        |
      |    Transpose   |        |                        |
      |        |       |        |                        |
      ---BMM1---       |        |                        |
          |            |        |                        |
          *----------------------                        |
          |            |                                 |
    Normalization      |                                 |
          |            |                                 |
          *-----------------------------------------------
          |            |
          -----BMM2-----
                |
              Output
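To make the BMM1, normalization, and BMM2 stages concrete, the following is a minimal CPU sketch of the same dataflow for a single batch element and a single head. It is not TensorRT code: `Matrix` and `referenceAttention` are hypothetical names introduced here, the normalization is assumed to be a row-wise softmax, and no scale factor is applied to the BMM1 output because this reference does not state one.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major matrix; a hypothetical helper type for this sketch only.
struct Matrix
{
    std::size_t rows;
    std::size_t cols;
    std::vector<float> data;

    Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0F) {}
    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Reference attention for one batch element and one head.
// query: [numTokensQuery, dimHead]; key and value: [numTokensKeyValue, dimHead].
// Assumptions: softmax normalization, no mask, no scaling of the BMM1 output,
// and at least one key/value token.
Matrix referenceAttention(Matrix const& query, Matrix const& key, Matrix const& value)
{
    // BMM1: scores = query * transpose(key)
    Matrix scores(query.rows, key.rows);
    for (std::size_t i = 0; i < query.rows; ++i)
    {
        for (std::size_t j = 0; j < key.rows; ++j)
        {
            float acc = 0.0F;
            for (std::size_t d = 0; d < query.cols; ++d)
            {
                acc += query.at(i, d) * key.at(j, d);
            }
            scores.at(i, j) = acc;
        }
    }

    // Normalization: numerically stable softmax over each row of scores.
    for (std::size_t i = 0; i < scores.rows; ++i)
    {
        float rowMax = scores.at(i, 0);
        for (std::size_t j = 1; j < scores.cols; ++j)
        {
            rowMax = std::max(rowMax, scores.at(i, j));
        }
        float sum = 0.0F;
        for (std::size_t j = 0; j < scores.cols; ++j)
        {
            scores.at(i, j) = std::exp(scores.at(i, j) - rowMax);
            sum += scores.at(i, j);
        }
        for (std::size_t j = 0; j < scores.cols; ++j)
        {
            scores.at(i, j) /= sum;
        }
    }

    // BMM2: output = normalized scores * value
    Matrix output(query.rows, value.cols);
    for (std::size_t i = 0; i < scores.rows; ++i)
    {
        for (std::size_t d = 0; d < value.cols; ++d)
        {
            float acc = 0.0F;
            for (std::size_t j = 0; j < scores.cols; ++j)
            {
                acc += scores.at(i, j) * value.at(j, d);
            }
            output.at(i, d) = acc;
        }
    }
    return output;
}
```

The output has the same `[numTokensQuery, dimHead]` shape as the query, matching the output description for this class.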

The attention has the following inputs, in order of input index:

  • Query contains the input query. It is a tensor of type kFLOAT, kHALF or kBF16; its shape depends on the query form.
    • For query form kPADDED_BHND, shape is [batchSize, numHeadsQuery, numTokens, dimHead]
    • For query form kPACKED_NHD, shape is [totalTokens, numHeadsQuery, dimHead]
  • Key contains the input key. It is a tensor of type kFLOAT, kHALF or kBF16; its shape depends on the key-value form.
    • For key-value form kPADDED_BHND, shape is [batchSize, numHeadsKeyValue, numTokens, dimHead]
    • For key-value form kPACKED_NHD, shape is [totalTokens, numHeadsKeyValue, dimHead]
  • Value contains the input value. It is a tensor of type kFLOAT, kHALF or kBF16; its shape depends on the key-value form.
    • For key-value form kPADDED_BHND, shape is [batchSize, numHeadsKeyValue, numTokens, dimHead]
    • For key-value form kPACKED_NHD, shape is [totalTokens, numHeadsKeyValue, dimHead]
  • Mask (optional) contains the mask value. It is a tensor of type kBOOL or of the same data type as the BMM1 output. Shape is [batchSize, numHeadsQuery, numTokensQuery, numTokensKeyValue] with batchSize and numHeadsQuery broadcastable. TensorRT uses stride-based indexing to load the mask data.
    • For a kBOOL mask, a True value indicates that the corresponding position is allowed to attend.
    • For other data types, the mask values will be added to the BMM1 output, known as an add mask.
  • NormalizationQuantizeScale (optional) contains the quantization scale for the attention normalization output. It is a 0-D or 1-D tensor of type kFLOAT, kHALF or kBF16.
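The two mask flavors and the broadcast behavior described above can be illustrated with a small CPU sketch. It is not TensorRT code: `Dims4`, `broadcastStrides`, `applyBoolMask`, and `applyAddMask` are hypothetical names, broadcasting is modeled by giving a size-1 batch or head dimension a stride of 0, and a disallowed position under a kBOOL mask is assumed to become negative infinity before normalization (this reference only says that True means "allowed to attend").

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Extents or strides for [batchSize, numHeadsQuery, numTokensQuery, numTokensKeyValue].
struct Dims4
{
    std::size_t b;
    std::size_t h;
    std::size_t q;
    std::size_t k;
};

// Element strides for a densely stored mask. A broadcast (size 1) batch or
// head dimension gets stride 0 so every index maps to the same mask slice.
Dims4 broadcastStrides(Dims4 const& maskShape)
{
    Dims4 s{};
    s.k = 1;
    s.q = maskShape.k;
    s.h = (maskShape.h == 1) ? 0 : maskShape.q * maskShape.k;
    s.b = (maskShape.b == 1) ? 0 : maskShape.h * maskShape.q * maskShape.k;
    return s;
}

// Visits every score element and combines it with the matching mask element.
template <typename MaskT, typename Op>
void applyMask(std::vector<float>& scores, Dims4 const& scoreShape,
    std::vector<MaskT> const& mask, Dims4 const& maskShape, Op op)
{
    Dims4 const s = broadcastStrides(maskShape);
    std::size_t idx = 0;
    for (std::size_t b = 0; b < scoreShape.b; ++b)
        for (std::size_t h = 0; h < scoreShape.h; ++h)
            for (std::size_t q = 0; q < scoreShape.q; ++q)
                for (std::size_t k = 0; k < scoreShape.k; ++k)
                {
                    std::size_t const m = b * s.b + h * s.h + q * s.q + k * s.k;
                    scores[idx] = op(scores[idx], mask[m]);
                    ++idx;
                }
}

// kBOOL mask: true keeps the score; false is assumed to map to -infinity.
void applyBoolMask(std::vector<float>& scores, Dims4 const& scoreShape,
    std::vector<bool> const& mask, Dims4 const& maskShape)
{
    applyMask(scores, scoreShape, mask, maskShape, [](float score, bool allowed) {
        return allowed ? score : -std::numeric_limits<float>::infinity();
    });
}

// Add mask: the mask value is added to the BMM1 output.
void applyAddMask(std::vector<float>& scores, Dims4 const& scoreShape,
    std::vector<float> const& mask, Dims4 const& maskShape)
{
    applyMask(scores, scoreShape, mask, maskShape,
        [](float score, float bias) { return score + bias; });
}
```

With a mask of shape [1, 1, numTokensQuery, numTokensKeyValue], the same mask slice is reused for every batch element and head.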

The attention has one output:

  • Output has the same shape, form, and data type as the query input.
    • For query form kPADDED_BHND, output shape is [batchSize, numHeadsQuery, numTokens, dimHead]
    • For query form kPACKED_NHD, output shape is [totalTokens, numHeadsQuery, dimHead]
See also
https://docs.nvidia.com/deeplearning/tensorrt/latest/performance/best-practices.html#multi-head-attention-fusion for the complete matrix of fused kernel support.
Warning
Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.

Constructor & Destructor Documentation

◆ ~IAttention()

nvinfer1::IAttention::~IAttention ( )
inline protected pure virtual default noexcept

Member Function Documentation

◆ getCausal()

TRT_DEPRECATED bool nvinfer1::IAttention::getCausal ( ) const
inline noexcept

Get whether the attention will run a causal inference.

See also
setCausal(), getCausalKind()
Deprecated:
Deprecated in TensorRT 10.16. Superseded by getCausalKind.
Returns
True if the attention will run a causal inference, false otherwise. Default is false.

◆ getCausalKind()

CausalMaskKind nvinfer1::IAttention::getCausalKind ( ) const
inline noexcept

Get the causal mask alignment orientation for the attention.

See also
setCausalKind(), CausalMaskKind
Returns
The causal mask alignment orientation. Default is kNONE.

◆ getDecomposable()

bool nvinfer1::IAttention::getDecomposable ( ) const
inline noexcept

Get whether the attention can be decomposed to use multiple kernels if no fused kernel support is found.

Returns
True if the attention can be decomposed to use multiple kernels by the compiler, false otherwise. Default is false.
See also
setDecomposable

◆ getInput()

ITensor * nvinfer1::IAttention::getInput ( int32_t  index) const
inline noexcept

Get the IAttention input corresponding to the given index.

Parameters
index: The index of the input tensor.
Returns
The input tensor, or nullptr if the index is out of range.

◆ getKeyValueForm()

TRT_NODISCARD AttentionIOForm nvinfer1::IAttention::getKeyValueForm ( ) const
inline noexcept

Get the key-value form.

Returns
The key-value form. Default is kPADDED_BHND.
See also
setKeyValueForm()
AttentionIOForm

◆ getKeyValueLengths()

TRT_NODISCARD ITensor * nvinfer1::IAttention::getKeyValueLengths ( ) const
inline noexcept

Get the key-value lengths tensor.

Returns
The key-value lengths tensor, or nullptr if not set.
See also
setKeyValueLengths()

◆ getMask()

ITensor * nvinfer1::IAttention::getMask ( )
inline noexcept

Get the optional mask in attention.

See also
setMask
Returns
The optional mask in attention, nullptr if no mask is set.

◆ getMetadata()

char const * nvinfer1::IAttention::getMetadata ( ) const
inline noexcept

Get the metadata of IAttention.

Returns
The metadata as a null-terminated C-style string. If setMetadata() has not been called, an empty string "" will be returned as a default value.
See also
setMetadata()

◆ getName()

char const * nvinfer1::IAttention::getName ( ) const
inline noexcept

Return the name of the attention.

See also
setName()
Returns
The name of the attention.

◆ getNbInputs()

int32_t nvinfer1::IAttention::getNbInputs ( ) const
inline noexcept

Get the number of inputs of IAttention. IAttention has three inputs.

Returns
The number of inputs of IAttention.

◆ getNbOutputs()

int32_t nvinfer1::IAttention::getNbOutputs ( ) const
inline noexcept

Get the number of outputs of a layer. IAttention has one output.

◆ getNbRanks()

int32_t nvinfer1::IAttention::getNbRanks ( ) const
inline noexcept

Get the number of ranks for multi-device execution.

Returns
The number of ranks configured for multi-device attention. Default is 1.
See also
setNbRanks()

◆ getNormalizationOperation()

AttentionNormalizationOp nvinfer1::IAttention::getNormalizationOperation ( ) const
inline noexcept

Get the normalization operation for the attention.

See also
setNormalizationOperation(), AttentionNormalizationOp
Returns
The normalization operation for the attention. Default is kSOFTMAX.

◆ getNormalizationQuantizeScale()

ITensor * nvinfer1::IAttention::getNormalizationQuantizeScale ( ) const
inline noexcept

Get the quantization scale for the attention normalization output.

Returns
The quantization scale for the attention normalization output or nullptr if no quantization scale is set.

◆ getNormalizationQuantizeToType()

DataType nvinfer1::IAttention::getNormalizationQuantizeToType ( ) const
inline noexcept

Get the datatype the attention normalization is quantized to.

Returns
The datatype the attention normalization is quantized to. The default value is DataType::kFLOAT.
Warning
Must be called only after the quantization type has been set with setNormalizationQuantizeToType().

◆ getOutput()

ITensor * nvinfer1::IAttention::getOutput ( int32_t  index) const
inline noexcept

Get the IAttention output corresponding to the given index. IAttention has only one output.

Parameters
index: The index of the output tensor.
Returns
The indexed output tensor, or nullptr if the index is out of range.

◆ getQueryForm()

TRT_NODISCARD AttentionIOForm nvinfer1::IAttention::getQueryForm ( ) const
inline noexcept

Get the query form.

Returns
The query form. Default is kPADDED_BHND.
See also
setQueryForm()
AttentionIOForm

◆ getQueryLengths()

TRT_NODISCARD ITensor * nvinfer1::IAttention::getQueryLengths ( ) const
inline noexcept

Get the query lengths tensor.

Returns
The query lengths tensor, or nullptr if not set.
See also
setQueryLengths()

◆ setCausal()

TRT_DEPRECATED bool nvinfer1::IAttention::setCausal ( bool  isCausal)
inline noexcept

Set whether the attention will run a causal inference. Cannot be used together with setMask().

Parameters
isCausal: True to enable causal masking with kUPPER_LEFT alignment, false to disable causal masking.
See also
getCausal(), setCausalKind()
Deprecated:
Deprecated in TensorRT 10.16. Superseded by setCausalKind.
Returns
True if the causal inference is set successfully, false otherwise.

◆ setCausalKind()

bool nvinfer1::IAttention::setCausalKind ( CausalMaskKind  kind)
inline noexcept

Set the causal mask alignment orientation for the attention.

When set to kUPPER_LEFT or kLOWER_RIGHT, an implicit causal mask is applied. When set to kNONE, no causal masking is applied.

Cannot be used together with setMask(). Building with both a mask tensor and a causal orientation other than kNONE will fail validation.
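This page does not define what the two alignments mean geometrically. As a sketch only, the following assumes the convention commonly used by fused attention kernels: with upper-left alignment the causal diagonal starts at the top-left corner of the [numTokensQuery, numTokensKeyValue] score matrix, and with lower-right alignment it ends at the bottom-right corner. `CausalAlignment` and `mayAttend` are hypothetical names local to this example, not TensorRT symbols; consult the CausalMaskKind documentation for the authoritative definition.

```cpp
#include <cassert>
#include <cstdint>

// Local stand-in used for illustration only; this is not the TensorRT enum.
enum class CausalAlignment
{
    kNone,
    kUpperLeft,
    kLowerRight
};

// Returns true if query position q may attend to key position k.
// Assumed convention (not confirmed by this reference):
//   upper-left : k <= q
//   lower-right: k <= q + (numTokensKeyValue - numTokensQuery)
bool mayAttend(CausalAlignment alignment, int32_t q, int32_t k,
    int32_t numTokensQuery, int32_t numTokensKeyValue)
{
    switch (alignment)
    {
    case CausalAlignment::kNone: return true;
    case CausalAlignment::kUpperLeft: return k <= q;
    case CausalAlignment::kLowerRight: return k <= q + (numTokensKeyValue - numTokensQuery);
    }
    return true;
}
```

Under this assumption the two alignments coincide when numTokensQuery equals numTokensKeyValue and differ only when the query and key-value sequence lengths differ, for example when decoding against a longer key-value history.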

Parameters
kind: The causal mask alignment to apply.
See also
getCausalKind(), CausalMaskKind
Returns
True if the causal mask kind is set successfully, false otherwise.

◆ setDecomposable()

bool nvinfer1::IAttention::setDecomposable ( bool  decomposable)
inline noexcept

Set whether the attention can be decomposed to use multiple kernels if no fused kernel support is found.

See also
getDecomposable
Returns
True if the decomposable attention is set successfully, false otherwise.

◆ setInput()

bool nvinfer1::IAttention::setInput ( int32_t  index, ITensor &  input )
inline noexcept

Append or replace an input of this layer with a specific tensor.

Parameters
index: The index of the input to modify.
input: The new input tensor.

The indices are as follows:

  • Input 0 is the input query tensor.
  • Input 1 is the input key tensor.
  • Input 2 is the input value tensor.

Returns
True if the input tensor is set successfully, false otherwise.

◆ setKeyValueForm()

TRT_NODISCARD bool nvinfer1::IAttention::setKeyValueForm ( AttentionIOForm  form)
inline noexcept

Set the key-value form.

Default is kPADDED_BHND.

Parameters
form: The key-value form.
Returns
True if the key-value form is set successfully, false otherwise.
See also
getKeyValueForm()
AttentionIOForm

◆ setKeyValueLengths()

TRT_NODISCARD bool nvinfer1::IAttention::setKeyValueLengths ( ITensor *  lengths)
inline noexcept

Set the key-value lengths tensor.

An optional tensor to specify per-batch key-value lengths. The semantics depend on the key-value form:

  • When key-value form is kPADDED_BHND: contains per-batch lengths with shape [batchSize]. Each element must be <= the sequence length dimension of the KV tensor. If not set, the sequence length dimension of the KV tensor is used for all batches.
  • When key-value form is kPACKED_NHD: contains cumulative token counts with shape [batchSize + 1]. The first element must be 0 and the last element equals totalTokens. The totalTokens dimension of the KV tensor must be >= the last element of this tensor. Must be set when key-value form is kPACKED_NHD.
Warning
When key-value form is kPACKED_NHD, providing a first element that is not 0 results in undefined behavior.
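The constraints for the two forms can be checked on the host before the values are fed to the network. The following is a minimal sketch with hypothetical helper names (`validPaddedLengths`, `validPackedLengths`); it operates on plain host vectors rather than ITensor objects. The checks that lengths are non-negative and that cumulative counts never decrease are assumptions that follow from the values being token counts; they are not stated explicitly in this reference.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// kPADDED_BHND: one length per batch element, each at most the sequence
// length dimension of the KV tensor. Non-negativity is an assumption.
bool validPaddedLengths(std::vector<int32_t> const& lengths, std::size_t batchSize, int32_t seqLen)
{
    if (lengths.size() != batchSize)
    {
        return false;
    }
    for (int32_t len : lengths)
    {
        if (len < 0 || len > seqLen)
        {
            return false;
        }
    }
    return true;
}

// kPACKED_NHD: batchSize + 1 cumulative counts, starting at 0, with the last
// element not exceeding the totalTokens dimension of the KV tensor.
// The non-decreasing check is an assumption (token counts cannot be negative).
bool validPackedLengths(std::vector<int32_t> const& cumulative, std::size_t batchSize, int32_t totalTokensDim)
{
    if (cumulative.size() != batchSize + 1)
    {
        return false;
    }
    if (cumulative.front() != 0)
    {
        return false;
    }
    for (std::size_t i = 1; i < cumulative.size(); ++i)
    {
        if (cumulative[i] < cumulative[i - 1])
        {
            return false;
        }
    }
    return cumulative.back() <= totalTokensDim;
}
```

Rejecting a non-zero first element up front avoids the undefined behavior described in the warning above.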
Parameters
lengths: A 1D tensor of type kINT32. If nullptr, clears a previously set key-value lengths tensor.
Returns
True if the key-value lengths tensor is set or cleared successfully, false otherwise.
See also
getKeyValueLengths()

◆ setMask()

bool nvinfer1::IAttention::setMask ( ITensor &  mask)
inline noexcept

Set the mask to be used for the normalization operation.

Parameters
mask: The mask tensor, of type kBOOL or of the same data type as the BMM1 output, with a 4-D shape broadcastable to [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue]. For a kBOOL mask, a True value indicates that the corresponding position is allowed to attend. For other data types, the mask values will be added to the BMM1 output, known as an add mask.
See also
getMask
Returns
True if the mask is set successfully, false otherwise.

◆ setMetadata()

bool nvinfer1::IAttention::setMetadata ( char const *  metadata)
inline noexcept

Set the metadata for IAttention.

The metadata is emitted in the JSON returned by IEngineInspector with ProfilingVerbosity set to kDETAILED.

Parameters
metadata: The per-layer metadata.
Warning
The metadata string must be null-terminated and be at most 4096 bytes including the terminator.
See also
getMetadata()
getLayerInformation()
Returns
True if the metadata is set successfully, false otherwise.

◆ setName()

bool nvinfer1::IAttention::setName ( char const *  name)
inline noexcept

Set the name of the attention.

The name is used in error diagnostics. This method copies the name string.

Warning
The string name must be null-terminated, and be at most 4096 bytes including the terminator.
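A caller can check the size limit before passing a string to setName(). The sketch below is illustrative only: `fitsNameLimit` and `kMaxNameBytes` are hypothetical names, and treating a null pointer as invalid is an assumption of this example rather than documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>

// Limit stated in the warning above: 4096 bytes including the terminator.
constexpr std::size_t kMaxNameBytes = 4096;

// Returns true if the string, plus its null terminator, fits in the limit.
// Rejecting nullptr is an assumption made for this sketch.
bool fitsNameLimit(char const* name)
{
    return name != nullptr && std::strlen(name) + 1 <= kMaxNameBytes;
}
```

A string of 4095 characters is the longest that passes, because the terminator takes the final byte.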
See also
getName()
Returns
True if the name is set successfully, false otherwise.

◆ setNbRanks()

bool nvinfer1::IAttention::setNbRanks ( int32_t  nbRanks)
inline noexcept

Set the number of ranks for multi-device attention execution.

When nbRanks > 1, this hints attention to perform multi-device attention.

Parameters
nbRanks: The number of ranks. Must be >= 1.
Returns
True if successful, false otherwise.
See also
getNbRanks()

◆ setNormalizationOperation()

bool nvinfer1::IAttention::setNormalizationOperation ( AttentionNormalizationOp  op)
inline noexcept

Set the normalization operation for the attention.

See also
getNormalizationOperation(), AttentionNormalizationOp
Returns
True if the normalization operation is set successfully, false otherwise.

◆ setNormalizationQuantizeScale()

bool nvinfer1::IAttention::setNormalizationQuantizeScale ( ITensor &  tensor)
inline noexcept

Set the quantization scale for the attention normalization output.

Parameters
tensor: The quantization scale tensor. Its data type must be DataType::kFLOAT, DataType::kHALF or DataType::kBF16, and it must be 0-D or 1-D.
Returns
True if the quantization scale is set successfully, false otherwise.
Warning
Must be used together with setNormalizationQuantizeToType to set normalization output datatype to DataType::kFP8 or DataType::kINT8.

◆ setNormalizationQuantizeToType()

bool nvinfer1::IAttention::setNormalizationQuantizeToType ( DataType  type)
inline noexcept

Set the datatype the attention normalization is quantized to.

Parameters
type: The datatype the attention normalization is quantized to. Must be DataType::kFP8 or DataType::kINT8.
Returns
True if the quantization to type is set successfully, false otherwise.

◆ setQueryForm()

TRT_NODISCARD bool nvinfer1::IAttention::setQueryForm ( AttentionIOForm  form)
inline noexcept

Set the query form.

Default is kPADDED_BHND.

Parameters
form: The query form.
Returns
True if the query form is set successfully, false otherwise.
See also
getQueryForm()
AttentionIOForm

◆ setQueryLengths()

TRT_NODISCARD bool nvinfer1::IAttention::setQueryLengths ( ITensor *  lengths)
inline noexcept

Set the query lengths tensor.

An optional tensor to specify the cumulative number of tokens per batch element. Must be set when query form is kPACKED_NHD. Ignored when query form is kPADDED_BHND. When set, contains cumulative token counts with shape [batchSize + 1]. The first element must be 0 and the last element equals totalTokens. The number of tokens for batch i is lengths[i + 1] - lengths[i]. The totalTokens dimension of the query tensor must be >= the last element of this tensor.

Warning
Providing a first element that is not 0 results in undefined behavior.
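The cumulative layout can be produced from per-sequence token counts with a running sum. The following host-side sketch uses hypothetical helper names (`toCumulativeLengths`, `tokensInSequence`) and plain vectors; how the resulting values are supplied to the kINT32 lengths tensor is outside its scope.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds the [batchSize + 1] cumulative token counts from per-sequence counts.
// The first element is always 0; the last element is the total token count.
std::vector<int32_t> toCumulativeLengths(std::vector<int32_t> const& tokensPerSequence)
{
    std::vector<int32_t> cumulative;
    cumulative.reserve(tokensPerSequence.size() + 1);
    cumulative.push_back(0);
    for (int32_t n : tokensPerSequence)
    {
        cumulative.push_back(cumulative.back() + n);
    }
    return cumulative;
}

// Recovers the token count of batch element i: lengths[i + 1] - lengths[i].
int32_t tokensInSequence(std::vector<int32_t> const& cumulative, std::size_t i)
{
    return cumulative[i + 1] - cumulative[i];
}
```

For example, sequences of 3, 1, and 4 tokens yield the cumulative counts 0, 3, 4, 8, so the packed query tensor needs a totalTokens dimension of at least 8.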
Parameters
lengths: A 1D tensor of type kINT32 with shape [batchSize + 1]. If nullptr, clears a previously set query lengths tensor.
Returns
True if the query lengths tensor is set or cleared successfully, false otherwise.
See also
getQueryLengths()

Member Data Documentation

◆ mImpl

apiv::VAttention* nvinfer1::IAttention::mImpl
protected

The documentation for this class was generated from the following file:

  Copyright © 2024 NVIDIA Corporation