nemo_automodel.components.models.nemotron_v3.cache#

Module Contents#

Classes#

NemotronHybridCache

Hybrid KV cache for the NemotronH architecture (attention + Mamba2 layers).

API#

class nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache(
config,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
)#

Hybrid KV cache for the NemotronH architecture (attention + Mamba2 layers).

Attention layers accumulate key/value tensors (growing sequence dimension). Mamba2 layers maintain fixed-size conv_state and ssm_state tensors. MLP/MoE layers have no caching.

Modeled after FalconHybridMambaAttentionDynamicCache from transformers.
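The per-layer-type storage described above can be sketched as follows. This is a minimal illustration, not the actual implementation: the class name `HybridCacheSketch`, the `layer_types` argument, and the state dimensions are assumptions for the example; the real class takes a model `config` instead.

```python
import torch

class HybridCacheSketch:
    """Illustrative hybrid cache: per-layer storage keyed by layer type.

    Assumed shapes (not taken from the real implementation):
      - attention K/V: (batch, heads, seq, head_dim), grows along seq
      - conv_state:    (batch, conv_dim, conv_kernel), fixed size
      - ssm_state:     (batch, conv_dim, ssm_dim), fixed size
    """

    def __init__(self, layer_types, batch_size, conv_dim=4, conv_kernel=3,
                 ssm_dim=8, dtype=torch.float32, device="cpu"):
        self.layer_types = layer_types       # e.g. ["attention", "mamba", "mlp"]
        self.key_cache = {}                  # attention layers start empty
        self.value_cache = {}
        self.conv_states = {}                # Mamba states are preallocated
        self.ssm_states = {}
        for i, layer_type in enumerate(layer_types):
            if layer_type == "mamba":
                self.conv_states[i] = torch.zeros(
                    batch_size, conv_dim, conv_kernel, dtype=dtype, device=device)
                self.ssm_states[i] = torch.zeros(
                    batch_size, conv_dim, ssm_dim, dtype=dtype, device=device)
            # MLP/MoE layers cache nothing
```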

Initialization

is_compileable#

False

update(
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: dict[str, Any] | None = None,
) tuple[torch.Tensor, torch.Tensor]#

Attention KV cache: append new K/V and return accumulated tensors.
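The append-and-return behavior can be sketched with a standalone function (a simplification of the method above; the dict-based cache and `(batch, heads, seq, head_dim)` layout are assumptions for the example):

```python
import torch

def update(key_cache, value_cache, key_states, value_states, layer_idx):
    # First call for this layer: store the prefill K/V as-is.
    if layer_idx not in key_cache:
        key_cache[layer_idx] = key_states
        value_cache[layer_idx] = value_states
    else:
        # Subsequent decode steps: append along the sequence dimension (dim=-2).
        key_cache[layer_idx] = torch.cat(
            [key_cache[layer_idx], key_states], dim=-2)
        value_cache[layer_idx] = torch.cat(
            [value_cache[layer_idx], value_states], dim=-2)
    # Return the accumulated tensors for attention over the full history.
    return key_cache[layer_idx], value_cache[layer_idx]
```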

update_conv_state(
layer_idx: int,
new_conv_state: torch.Tensor,
cache_position: torch.LongTensor,
) torch.Tensor#

Update Mamba conv state: full overwrite (prefill) or roll+update (decode).
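The two paths (full overwrite on prefill, roll-plus-update on decode) can be sketched as below. This is an illustrative simplification: using `cache_position.numel()` to distinguish prefill from decode, and assuming the prefill input has at least `kernel` columns, are assumptions for the example.

```python
import torch

def update_conv_state(conv_state, new_conv_state, cache_position):
    kernel = conv_state.shape[-1]
    if cache_position.numel() > 1:
        # Prefill: overwrite the fixed-size window with the last `kernel` columns.
        conv_state.copy_(new_conv_state[..., -kernel:])
    else:
        # Decode: shift the window left by one and write the newest column.
        conv_state.copy_(conv_state.roll(shifts=-1, dims=-1))
        conv_state[..., -1] = new_conv_state[..., -1]
    return conv_state
```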

get_seq_length(layer_idx: int | None = None) int#

Return the sequence length currently accumulated in the attention KV cache.

reorder_cache(beam_idx: torch.LongTensor) None#

Reorder all caches for beam search.
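Beam-search reordering amounts to selecting rows along the batch dimension of every cached tensor. A minimal sketch, assuming dict-based caches as in the examples above:

```python
import torch

def reorder_cache(key_cache, conv_states, beam_idx):
    # beam_idx maps each output beam to the source batch row it continues from.
    for idx in key_cache:
        key_cache[idx] = key_cache[idx].index_select(0, beam_idx)
    for idx in conv_states:
        conv_states[idx] = conv_states[idx].index_select(0, beam_idx)
```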