nemo_automodel.components.models.nemotron_v3.cache

Module Contents

Classes

| Name | Description |
| --- | --- |
| NemotronHybridCache | Hybrid KV cache for the NemotronH architecture (attention + Mamba2 layers). |

API

class nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache(
config,
batch_size: int,
dtype: torch.dtype,
device: torch.device
)

Hybrid KV cache for the NemotronH architecture (attention + Mamba2 layers).

Attention layers accumulate key/value tensors along a growing sequence dimension. Mamba2 layers maintain fixed-size conv_state and ssm_state tensors. MLP/MoE layers are not cached.

Modeled after FalconHybridMambaAttentionDynamicCache from transformers.
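
To make the contrast between the two kinds of state concrete, the following is a minimal sketch that counts cached elements for one attention layer and one Mamba2 layer. All shapes, dimension names, and helper functions here are illustrative assumptions and are not taken from the NemotronH config or from this module.

```python
import torch

# Illustrative sizes only (assumptions, not NemotronH defaults).
batch = 2
num_kv_heads, head_dim = 4, 8                    # attention K/V shape
conv_dim, conv_kernel = 16, 4                    # Mamba2 conv_state shape
ssm_heads, ssm_head_dim, ssm_state_size = 2, 8, 32  # Mamba2 ssm_state shape


def attention_cache_numel(tokens: int) -> int:
    """Elements held by one attention layer after `tokens` tokens (K plus V)."""
    k = torch.empty(batch, num_kv_heads, tokens, head_dim)
    return 2 * k.numel()


def mamba_cache_numel(tokens: int) -> int:
    """Elements held by one Mamba2 layer; `tokens` is deliberately unused
    because the state size does not depend on sequence length."""
    conv_state = torch.zeros(batch, conv_dim, conv_kernel)
    ssm_state = torch.zeros(batch, ssm_heads, ssm_head_dim, ssm_state_size)
    return conv_state.numel() + ssm_state.numel()


for t in (5, 10):
    print(t, attention_cache_numel(t), mamba_cache_numel(t))
```

Under these assumed shapes the attention count doubles when the sequence length doubles, while the Mamba2 count stays constant.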

conv_kernel_size = config.conv_kernel
conv_states
key_cache: list[Tensor] = []
ssm_states
value_cache: list[Tensor] = []

nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache.get_seq_length(
layer_idx: int | None = None
) -> int

Return attention KV cache sequence length.

nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache.reorder_cache(
beam_idx: torch.LongTensor
) -> None

Reorder all caches for beam search.
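
As a rough illustration of what reordering for beam search typically involves, the sketch below selects rows along the batch dimension of each cached tensor. The helper `reorder_along_batch` is hypothetical, and the assumption that batch is dimension 0 for every cached tensor is not confirmed by this page.

```python
import torch


def reorder_along_batch(tensors, beam_idx):
    """Return new tensors whose batch rows follow `beam_idx`.

    Hypothetical helper: assumes dim 0 is the batch/beam dimension.
    """
    return [t.index_select(0, beam_idx.to(t.device)) for t in tensors]


# Three beams; each row is filled with its own beam number so the
# effect of the reorder is easy to read off.
key = torch.arange(3.0).view(3, 1, 1, 1).expand(3, 2, 4, 8).clone()  # attention-like
conv = torch.arange(3.0).view(3, 1, 1).expand(3, 16, 4).clone()      # conv_state-like

beam_idx = torch.tensor([2, 2, 0])  # beams 0 and 1 now continue from old beam 2
new_key, new_conv = reorder_along_batch([key, conv], beam_idx)
print(new_key[:, 0, 0, 0].tolist())
print(new_conv[:, 0, 0].tolist())
```

The same index is applied to attention and Mamba2 state alike, which is what "all caches" suggests; whether the real method does this in place is not stated here.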

nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache.update(
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: dict[str, typing.Any] | None = None
) -> tuple[torch.Tensor, torch.Tensor]

Attention KV cache: append new K/V and return accumulated tensors.
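
The append-and-return behavior can be sketched as follows. This is a standalone toy function, not the class's implementation: `append_kv` is a hypothetical name, and the `(batch, heads, seq, head_dim)` layout with the sequence on dimension 2 is an assumption.

```python
import torch


def append_kv(key_cache, value_cache, key_states, value_states, layer_idx):
    """Append new K/V for one layer and return the accumulated tensors.

    Assumes layout (batch, heads, seq, head_dim); `None` marks an empty slot.
    """
    if key_cache[layer_idx] is None:
        # First call for this layer (prefill): store the tensors as they are.
        key_cache[layer_idx] = key_states
        value_cache[layer_idx] = value_states
    else:
        # Later calls (decode): grow along the sequence dimension.
        key_cache[layer_idx] = torch.cat([key_cache[layer_idx], key_states], dim=2)
        value_cache[layer_idx] = torch.cat([value_cache[layer_idx], value_states], dim=2)
    return key_cache[layer_idx], value_cache[layer_idx]


num_layers = 2
key_cache = [None] * num_layers
value_cache = [None] * num_layers

# Prefill with 5 tokens, then decode a single token, on layer 1.
k, v = append_kv(key_cache, value_cache,
                 torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8), layer_idx=1)
k, v = append_kv(key_cache, value_cache,
                 torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx=1)
print(k.shape, v.shape)
```

In this sketch the cached sequence length for a layer is simply the size of dimension 2 of its key tensor, which is the quantity a method like get_seq_length would report under the same layout assumption.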

nemo_automodel.components.models.nemotron_v3.cache.NemotronHybridCache.update_conv_state(
layer_idx: int,
new_conv_state: torch.Tensor,
cache_position: torch.LongTensor
) -> torch.Tensor

Update Mamba conv state: full overwrite (prefill) or roll+update (decode).
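
The two update modes can be illustrated with a small standalone function. This is a sketch under stated assumptions: the state is taken to have shape `(batch, conv_dim, kernel)`, prefill is detected by `cache_position` holding more than one position, and the decode input is a single column of shape `(batch, conv_dim)`. None of these details are confirmed by this page, and `update_conv_state_sketch` is a hypothetical name.

```python
import torch


def update_conv_state_sketch(conv_state, new_conv_state, cache_position):
    """Update a fixed-size rolling conv window in place and return it."""
    if cache_position.numel() > 1:
        # Prefill: the caller supplies the whole window, so overwrite it.
        conv_state.copy_(new_conv_state)
    else:
        # Decode: shift the window left by one and write the newest column last.
        conv_state.copy_(conv_state.roll(shifts=-1, dims=-1))
        conv_state[:, :, -1] = new_conv_state
    return conv_state


state = torch.zeros(1, 2, 4)

# Prefill with a full window of 4 positions.
window = torch.arange(8.0).view(1, 2, 4)
update_conv_state_sketch(state, window, cache_position=torch.arange(4))

# Decode one step: the oldest column drops out, the new one enters on the right.
update_conv_state_sketch(state, torch.tensor([[100.0, 200.0]]),
                         cache_position=torch.tensor([4]))
print(state[0].tolist())
```

The window never changes shape, which matches the fixed-size conv_state described for Mamba2 layers.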