nemo_automodel.components.models.nemotron_v3.cache#
Module Contents#
Classes#
NemotronHybridCache — Hybrid KV cache for the NemotronH architecture (attention + Mamba2 layers).
API#
- class nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache(
- config,
- batch_size: int,
- dtype: torch.dtype,
- device: torch.device,
- )
Hybrid KV cache for the NemotronH architecture (attention + Mamba2 layers).
Attention layers accumulate key/value tensors (growing sequence dimension). Mamba2 layers maintain fixed-size conv_state and ssm_state tensors. MLP/MoE layers have no caching.
Modeled after `FalconHybridMambaAttentionDynamicCache` from transformers.

Initialization
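The layout described above can be sketched as follows. This is a minimal illustration, not the actual `NemotronHybridCache` implementation: the layer-type labels, shapes, and variable names are assumptions chosen for clarity.

```python
import torch

# Illustrative layer layout: attention layers get growing K/V lists,
# Mamba2 layers get fixed-size conv/ssm state tensors, MLP layers get nothing.
layer_types = ["attention", "mamba", "attention", "mamba"]
batch, conv_dim, conv_kernel, ssm_state_dim = 2, 8, 4, 16

key_cache = [None] * len(layer_types)    # per-layer K, grows along the seq dim
value_cache = [None] * len(layer_types)  # per-layer V, grows along the seq dim

# Fixed-size recurrent state, allocated only for Mamba2 layers.
conv_states = {i: torch.zeros(batch, conv_dim, conv_kernel)
               for i, t in enumerate(layer_types) if t == "mamba"}
ssm_states = {i: torch.zeros(batch, conv_dim, ssm_state_dim)
              for i, t in enumerate(layer_types) if t == "mamba"}
```

Note that only the attention entries grow with generated tokens; the Mamba2 state footprint is constant regardless of sequence length.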
- is_compileable#
False
- update(
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: dict[str, Any] | None = None,
- )
Attention KV cache: append new K/V and return accumulated tensors.
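The append-and-return behavior can be sketched with a standalone function. This is a simplified stand-in, assuming the usual `[batch, heads, seq, head_dim]` layout; the real method operates on the cache's internal per-layer lists.

```python
import torch

def update(key_cache, value_cache, key_states, value_states, layer_idx):
    """Append new K/V along the sequence dim and return the accumulated tensors."""
    if key_cache[layer_idx] is None:
        # First call for this layer (prefill): just store the tensors.
        key_cache[layer_idx] = key_states
        value_cache[layer_idx] = value_states
    else:
        # Subsequent calls (decode): concatenate along seq dim (dim=2).
        key_cache[layer_idx] = torch.cat([key_cache[layer_idx], key_states], dim=2)
        value_cache[layer_idx] = torch.cat([value_cache[layer_idx], value_states], dim=2)
    return key_cache[layer_idx], value_cache[layer_idx]
```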
- update_conv_state(
- layer_idx: int,
- new_conv_state: torch.Tensor,
- cache_position: torch.LongTensor,
- )
Update Mamba conv state: full overwrite (prefill) or roll+update (decode).
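The two branches (full overwrite on prefill, roll-and-update on decode) can be sketched as below. This is an assumption-laden simplification: here the branch is chosen by tensor shape rather than by `cache_position`, and the state layout `[batch, conv_dim, kernel_size]` is illustrative.

```python
import torch

def update_conv_state(conv_state, new_conv_state, cache_position):
    """Hypothetical sketch: overwrite the conv window (prefill) or shift it (decode)."""
    if new_conv_state.dim() == conv_state.dim() and \
            new_conv_state.shape[-1] == conv_state.shape[-1]:
        # Prefill: the new state spans the whole window; overwrite in place.
        conv_state.copy_(new_conv_state)
    else:
        # Decode: shift the window left by one slot and write the new
        # per-token column [batch, conv_dim] into the last position.
        conv_state.copy_(conv_state.roll(shifts=-1, dims=-1))
        conv_state[:, :, -1] = new_conv_state
    return conv_state
```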
- get_seq_length(layer_idx: int | None = None) → int#
Return attention KV cache sequence length.
- reorder_cache(beam_idx: torch.LongTensor) → None#
Reorder all caches for beam search.
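Beam-search reordering amounts to re-indexing the batch dimension of every cached tensor (K/V, conv, and SSM states alike). A minimal sketch of that operation, assuming batch-first tensors:

```python
import torch

def reorder_batch(tensors, beam_idx):
    """Select rows of each cached tensor according to the surviving beam indices."""
    return [t.index_select(0, beam_idx.to(t.device)) for t in tensors]
```

The real method applies this in place to all of the cache's internal tensors so that every layer's state follows its beam.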