nemo_automodel.components.models.qwen3_5.decoder_layer#

Custom Qwen3.5 decoder layer that threads packed-sequence metadata to linear_attn.

HF’s Qwen3_5DecoderLayer.forward calls self.linear_attn with only hidden_states, cache_params, cache_position and attention_mask. For NEAT-packed inputs the linear-attn kernel additionally needs:

  • cu_seqlens – per-document cumulative lengths (FLA’s segment-reset signal for chunk_gated_delta_rule).

  • indices – non-padding token indices in the flattened sequence (used to unpad [B, T, ...] to [1, total_valid, ...] before the kernel and re-pad after; required for B>1 packed batches).

  • position_ids – needed by the CP path to undo PyTorch’s load-balanced shuffle.
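The first two items can be illustrated with a small sketch. It uses plain Python lists in place of tensors and assumes that an "indexed" attention mask stores 0 for padding and a 1-based document id for every real token; the helper names `packing_metadata`, `unpad` and `repad` are hypothetical and are not part of this module.

```python
from itertools import accumulate, groupby


def packing_metadata(indexed_mask):
    """Derive (cu_seqlens, indices) from a [B, T] indexed mask.

    Assumption: 0 marks padding and each document is a contiguous run of
    one non-zero id within its row.
    """
    seq_len = len(indexed_mask[0])
    indices = []      # positions of real tokens in the flattened [B*T] view
    doc_lengths = []  # one entry per document, in flattened order
    for b, row in enumerate(indexed_mask):
        for t, doc_id in enumerate(row):
            if doc_id != 0:
                indices.append(b * seq_len + t)
        for doc_id, run in groupby(row):
            if doc_id != 0:
                doc_lengths.append(sum(1 for _ in run))
    # Cumulative document boundaries over the unpadded token stream.
    cu_seqlens = [0] + list(accumulate(doc_lengths))
    return cu_seqlens, indices


def unpad(batch, indices):
    """Gather the real tokens of a [B, T] batch into one flat list."""
    flat = [tok for row in batch for tok in row]
    return [flat[i] for i in indices]


def repad(packed, indices, batch_size, seq_len, fill=0):
    """Scatter a flat list of real tokens back into a [B, T] batch."""
    flat = [fill] * (batch_size * seq_len)
    for value, i in zip(packed, indices):
        flat[i] = value
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch_size)]


mask = [
    [1, 1, 1, 2, 2, 0],  # two documents, then one padding slot
    [1, 1, 2, 2, 2, 2],  # two documents, no padding
]
tokens = [
    [10, 11, 12, 13, 14, 0],
    [20, 21, 22, 23, 24, 25],
]
cu_seqlens, indices = packing_metadata(mask)
packed = unpad(tokens, indices)
restored = repad(packed, indices, batch_size=2, seq_len=6)
```

In this toy case the last entry of `cu_seqlens` equals the number of entries in `indices`, which is the `total_valid` length the kernel would see after unpadding.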

This subclass derives the packing kwargs from the indexed attention_mask and forwards them, together with position_ids, into linear_attn. patch_hf_model replaces the class of every Qwen3_5DecoderLayer instance with this one at model build time, so this is the only file that needs to account for the kwargs dropped by HF’s decoder layer.
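The pattern described above can be sketched with toy classes. How patch_hf_model performs the swap is not shown here; reassigning `__class__` on existing instances is one common way to do it and is an assumption of this sketch. All `Toy*` names, `derive_packing_kwargs` and `swap_layers` are hypothetical stand-ins, and the packing values are placeholders.

```python
class ToyLinearAttn:
    """Stand-in for linear_attn that records which extra kwargs it received."""

    def __init__(self):
        self.seen_kwargs = None

    def __call__(self, hidden_states, attention_mask=None, **extra):
        self.seen_kwargs = sorted(extra)
        return hidden_states


class ToyDecoderLayer:
    """Stand-in for the upstream layer: forwards only the mask."""

    def __init__(self):
        self.linear_attn = ToyLinearAttn()

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        return self.linear_attn(hidden_states, attention_mask=attention_mask)


def derive_packing_kwargs(attention_mask):
    # Placeholder values; a real derivation would inspect the indexed mask.
    return {"cu_seqlens": None, "indices": None}


class ToyDecoderLayerWithPacking(ToyDecoderLayer):
    """Only forward is overridden; __init__ and attributes are inherited."""

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        packing = derive_packing_kwargs(attention_mask)
        return self.linear_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **packing,
        )


def swap_layers(layers):
    """Re-class existing instances in place so their state is kept."""
    for layer in layers:
        if type(layer) is ToyDecoderLayer:
            layer.__class__ = ToyDecoderLayerWithPacking


layers = [ToyDecoderLayer(), ToyDecoderLayer()]
attn_before = layers[0].linear_attn

layers[0].forward([1.0, 2.0], attention_mask=[1, 1])
kwargs_before = layers[0].linear_attn.seen_kwargs

swap_layers(layers)
layers[0].forward([1.0, 2.0], attention_mask=[1, 1], position_ids=[0, 1])
kwargs_after = layers[0].linear_attn.seen_kwargs
```

Because the swap changes only the class of each existing object, attributes created by the original `__init__` (here `linear_attn`) are reused as they are, which matches the "drop-in" description.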

Module Contents#

Classes#

Qwen3_5DecoderLayerWithPacking

Drop-in subclass of HF Qwen3_5DecoderLayer with packing-aware dispatch.

API#

class nemo_automodel.components.models.qwen3_5.decoder_layer.Qwen3_5DecoderLayerWithPacking#

Bases: transformers.models.qwen3_5.modeling_qwen3_5.Qwen3_5DecoderLayer

Drop-in subclass of HF Qwen3_5DecoderLayer with packing-aware dispatch.

All weights and __init__ are inherited unchanged. Only forward is overridden so the linear_attn call site receives cu_seqlens, indices and position_ids in addition to attention_mask.

forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
cache_position: torch.LongTensor | None = None,
**kwargs,
) → torch.Tensor#