nemo_automodel._transformers.v4_patches.rotary#

Runtime RoPE patches for legacy v4-style remote-code models.

Module Contents#

Functions#

_to_local

Unwrap DTensor to its local shard for numeric checks.

_safe_rope_forward

Drop-in replacement matching Nemotron-Flash-1B’s native rotary forward.

_compute_flash_inv_freq

Compute inv_freq using Nemotron-Flash-1B’s own NTK/default formula.

_is_nemotron_flash_config

should_fix_rotary_embeddings

Return True when the legacy rotary workaround should run.

fix_rotary_embeddings

Install Nemotron-Flash-1B’s native NTK inv_freq deterministically.

Data#

API#

nemo_automodel._transformers.v4_patches.rotary.logger#

'getLogger(…)'

nemo_automodel._transformers.v4_patches.rotary._to_local(t)#

Unwrap DTensor to its local shard for numeric checks.
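A minimal duck-typed sketch of what such an unwrap can look like. The real helper presumably checks for torch.distributed's DTensor type explicitly; the `hasattr` check and the name `to_local` here are assumptions chosen to keep the sketch dependency-free:

```python
def to_local(t):
    # Hypothetical stand-in for _to_local: a DTensor exposes .to_local(),
    # which returns the shard materialized on this rank; anything else
    # (a plain tensor, or any other object) passes through unchanged.
    return t.to_local() if hasattr(t, "to_local") else t
```

This lets numeric checks (e.g. probing a buffer for NaN/Inf) run on concrete local data regardless of whether the buffer was sharded.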

nemo_automodel._transformers.v4_patches.rotary._safe_rope_forward(self, x, position_ids, **kwargs)#

Drop-in replacement matching Nemotron-Flash-1B’s native rotary forward.

Mirrors modeling_nemotron_flash.LlamaRotaryEmbedding.forward verbatim (including the @torch.no_grad decorator and the autocast disable that keeps the cos/sin computation in FP32), so running this patched forward is semantically identical to letting Flash’s native forward run with the same inv_freq.
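The numeric core of such a forward can be sketched without torch. The function name and list-based layout below are illustrative, not the module's actual code; HF-style rotary forwards compute per-position angles `position * inv_freq` and duplicate them across the full head dimension before taking cos/sin:

```python
import math

def rope_cos_sin(inv_freq, position_ids):
    # Sketch of the rotary forward's math: for each position p, the angle
    # for frequency j is p * inv_freq[j]; the angles are concatenated with
    # themselves to cover the full head dim (the usual HF layout), and
    # cos/sin are taken in full precision (the FP32 path the autocast
    # disable guarantees in the real forward).
    cos, sin = [], []
    for p in position_ids:
        freqs = [p * f for f in inv_freq]
        emb = freqs + freqs  # duplicate the halves, matching the HF convention
        cos.append([math.cos(x) for x in emb])
        sin.append([math.sin(x) for x in emb])
    return cos, sin
```

With identical inv_freq values, two implementations of this computation produce bitwise-comparable cos/sin tables, which is exactly the property the patch relies on.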

nemo_automodel._transformers.v4_patches.rotary._compute_flash_inv_freq(cfg, device, dim)#

Compute inv_freq using Nemotron-Flash-1B’s own NTK/default formula.

Copy of the relevant init branch from modeling_nemotron_flash.LlamaRotaryEmbedding.__init__. Flash’s NTK differs from transformers’ standard:

  • factor = 2 (hardcoded in Flash).

  • Reads config.orig_max_position_embeddings (not original_max_position_embeddings).

  • Scales the base directly (no post-hoc attention_scaling).
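Under those assumptions, the computation can be sketched as follows. The guard and exponent are taken from the standard NTK-aware base-rescaling formula and may differ in detail from Flash's copy, so treat this as an illustration rather than the verbatim branch:

```python
import math

def compute_flash_style_inv_freq(base, dim, max_pos, orig_max_pos, factor=2.0):
    # NTK-style base rescaling (illustrative): the rotary base itself is
    # enlarged when max_pos exceeds orig_max_pos, instead of applying a
    # post-hoc attention_scaling factor. factor=2.0 mirrors the value the
    # docstring says Flash hardcodes.
    if max_pos > orig_max_pos:
        base = base * ((factor * max_pos / orig_max_pos) - (factor - 1)) ** (dim / (dim - 2))
    # Standard RoPE inverse frequencies over the even channel indices.
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
```

Because the scaling is folded into `base`, the returned frequencies already carry the NTK correction and no later multiplication of cos/sin is needed.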

nemo_automodel._transformers.v4_patches.rotary._is_nemotron_flash_config(cfg)#
nemo_automodel._transformers.v4_patches.rotary.should_fix_rotary_embeddings(model_parts)#

Return True when the legacy rotary workaround should run.

nemo_automodel._transformers.v4_patches.rotary.fix_rotary_embeddings(model_parts)#

Install Nemotron-Flash-1B’s native NTK inv_freq deterministically.

Flash’s own LlamaRotaryEmbedding.__init__ (remote code, under trust_remote_code) can land with NaN/Inf inv_freq buffers under transformers 5.x’s meta-device init context, and its NTK formula is non-standard (factor=2, reads config.orig_max_position_embeddings, no post-hoc attention_scaling), so transformers’ own ROPE_INIT_FUNCTIONS does not match it.

The old version of this patch sidestepped that by overwriting inv_freq with a plain-vanilla formula (no NTK) and replacing forward with a vanilla one. That silently downgraded training-time rope semantics relative to Flash’s native forward, which vanilla HF uses when reloading the consolidated checkpoint; the result was Phase 4 HF KL > 1.0, “fixed” by skipping Phase 4.

This revised patch computes inv_freq using Flash’s own NTK formula (copied verbatim from modeling_nemotron_flash.LlamaRotaryEmbedding) and installs it on every Flash rotary found, unconditionally. The forward is also replaced with _safe_rope_forward (now semantically identical to Flash’s native forward), which guards against any init-order oddity in the remote-code class. Training, Phase 3 Automodel reload, and Phase 4 vanilla HF reload all end up computing the same NTK-scaled rope.

Scope: only touches modules whose config is recognized as Nemotron-Flash (via _is_nemotron_flash_config), so non-Flash models are never affected. should_fix_rotary_embeddings further narrows the call site.
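The unconditional install described above can be sketched as a simple walk over the model parts. Every name in this sketch (the helper signature, the injected predicate and compute functions) is hypothetical and only illustrates the control flow, not the module's actual code:

```python
def fix_rotary_embeddings_sketch(model_parts, is_flash_config, compute_inv_freq):
    # Walk every module of every model part; any module that carries an
    # inv_freq buffer and whose config is recognized as Nemotron-Flash gets
    # a freshly computed inv_freq installed unconditionally, so no NaN/Inf
    # probing is needed for the decision.
    patched = 0
    for model in model_parts:
        for module in model.modules():
            cfg = getattr(module, "config", None)
            if cfg is not None and hasattr(module, "inv_freq") and is_flash_config(cfg):
                module.inv_freq = compute_inv_freq(cfg)
                patched += 1
    return patched
```

The real patch additionally swaps each matched module's forward for _safe_rope_forward; that step is omitted here for brevity.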

nemo_automodel._transformers.v4_patches.rotary.__all__#

['fix_rotary_embeddings', 'should_fix_rotary_embeddings']