nemo_automodel.components.models.qwen3_5.decoder_layer#
Custom Qwen3.5 decoder layer that threads packed-sequence metadata to linear_attn.
HF’s Qwen3_5DecoderLayer.forward calls self.linear_attn with only
hidden_states, cache_params, cache_position and attention_mask.
For NEAT-packed inputs the linear-attn kernel additionally needs:
- cu_seqlens – per-document cumulative lengths (FLA’s segment-reset signal for chunk_gated_delta_rule).
- indices – non-padding token indices in the flattened sequence (used to unpad [B, T, ...] to [1, total_valid, ...] before the kernel and re-pad after; required for B>1 packed batches).
- position_ids – needed by the CP path to undo PyTorch’s load-balanced shuffle.
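The relationship between an indexed mask, indices and cu_seqlens can be sketched in plain Python. This is an illustration only, not the module's implementation: the helper names (`packing_metadata`, `unpad`, `repad`) are hypothetical, nested lists stand in for tensors, and the mask convention (0 for padding, k for tokens of the k-th document in a row) is an assumption about what "indexed" means here.

```python
def packing_metadata(indexed_mask):
    """Derive (indices, cu_seqlens) from a [B, T] indexed mask (nested lists).

    Assumed convention: 0 marks padding; tokens of the k-th document in a
    row carry the value k (1, 2, 3, ...).
    """
    indices = []      # positions of valid tokens in the flattened B*T sequence
    cu_seqlens = [0]  # cumulative document lengths over the valid tokens
    prev_key = None
    for b, row in enumerate(indexed_mask):
        for t, doc in enumerate(row):
            if doc == 0:
                continue  # skip padding
            key = (b, doc)
            if prev_key is not None and key != prev_key:
                # A new document starts here: close the previous one.
                cu_seqlens.append(len(indices))
            indices.append(b * len(row) + t)
            prev_key = key
    if indices:
        cu_seqlens.append(len(indices))  # close the last document
    return indices, cu_seqlens


def unpad(flat_values, indices):
    """Keep only the valid tokens: [B*T] -> [total_valid]."""
    return [flat_values[i] for i in indices]


def repad(valid_values, indices, total, fill=0):
    """Scatter valid tokens back to their flattened positions: [total_valid] -> [B*T]."""
    out = [fill] * total
    for i, value in zip(indices, valid_values):
        out[i] = value
    return out


# Two rows of length 6; the first row ends with one padding token.
mask = [[1, 1, 1, 2, 2, 0],
        [1, 1, 2, 2, 2, 2]]
indices, cu_seqlens = packing_metadata(mask)
```

Each adjacent pair in `cu_seqlens` delimits one document in the unpadded sequence, which is where a segment-aware kernel would reset its state. With B>1 the documents from every row end up in one flat sequence, which is why `indices` is needed to restore the original `[B, T]` layout afterwards.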
This subclass derives the packing kwargs from the indexed attention_mask and
forwards them, plus position_ids, into linear_attn. patch_hf_model
swaps every Qwen3_5DecoderLayer instance to this class at model build time,
so this is the only file that needs to know about the kwarg drop in HF’s
decoder layer.
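How patch_hf_model performs the swap is not shown on this page. One common way to retarget existing instances to a subclass that only overrides a method is to reassign `__class__`, which leaves all instance state untouched. The toy sketch below assumes that approach; `BaseLayer`, `PackingLayer` and `swap_layer_class` are hypothetical stand-ins, not the NeMo or Transformers classes.

```python
class BaseLayer:
    """Stand-in for a decoder layer whose forward drops extra kwargs."""

    def __init__(self, name):
        self.name = name

    def linear_attn(self, hidden_states, **kwargs):
        # Report which keyword arguments reached the attention call site.
        return sorted(kwargs)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Only attention_mask is forwarded; position_ids and friends are lost.
        return self.linear_attn(hidden_states, attention_mask=attention_mask)


class PackingLayer(BaseLayer):
    """Overrides forward only; __init__ and instance state are inherited."""

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        # Placeholder values: a real layer would derive these from the mask.
        packing = {"cu_seqlens": None, "indices": None}
        return self.linear_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **packing,
        )


def swap_layer_class(layers, old_cls, new_cls):
    """Retarget every exact instance of old_cls to new_cls; return the count."""
    swapped = 0
    for layer in layers:
        if type(layer) is old_cls:
            layer.__class__ = new_cls  # state is kept, method lookup changes
            swapped += 1
    return swapped


layers = [BaseLayer("layer0"), BaseLayer("layer1")]
before = layers[0].forward("h")  # kwargs seen by linear_attn before the swap
swap_layer_class(layers, BaseLayer, PackingLayer)
after = layers[0].forward("h")   # kwargs seen after the swap
```

Because the subclass adds no new attributes, the swap needs no re-initialisation or weight copying, which matches the "all weights and `__init__` are inherited unchanged" design described for the real class.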
Module Contents#
Classes#
Qwen3_5DecoderLayerWithPacking | Drop-in subclass of HF Qwen3_5DecoderLayer with packing-aware dispatch.
API#
- class nemo_automodel.components.models.qwen3_5.decoder_layer.Qwen3_5DecoderLayerWithPacking#
Bases: transformers.models.qwen3_5.modeling_qwen3_5.Qwen3_5DecoderLayer
Drop-in subclass of HF Qwen3_5DecoderLayer with packing-aware dispatch.
All weights and __init__ are inherited unchanged. Only forward is overridden so the linear_attn call site receives cu_seqlens, indices and position_ids in addition to attention_mask.
- forward(
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values=None,
- cache_position: torch.LongTensor | None = None,
- **kwargs,