nemo_automodel.components.training.neftune

NEFTune: Noisy Embeddings Fine-Tuning.

Implements the technique from “NEFTune: Noisy Embeddings Improve Instruction Finetuning” (https://arxiv.org/abs/2310.05914). Adds scaled uniform noise to token embeddings during training to improve generalization, with no additional compute or data overhead.

Module Contents

Classes

NEFTune

Applies NEFTune noise to a model’s embedding layer during training.

Functions

_get_input_embeddings

Find the input embedding layer on a model.

Data

logger

API

nemo_automodel.components.training.neftune.logger

‘getLogger(…)’

class nemo_automodel.components.training.neftune.NEFTune(noise_alpha: float = 5.0)

Applies NEFTune noise to a model’s embedding layer during training.

NEFTune adds uniform random noise, scaled by noise_alpha / sqrt(seq_len * hidden_dim), to the output of the embedding layer. The noise is applied only when the model is in training mode.

Parameters:

noise_alpha – Noise magnitude. Higher values add more noise. Typical values are 5-15. Set to 0 to disable.
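Written out, the perturbation has the following form. This is a sketch that assumes, as in the NEFTune paper, that every noise entry is drawn independently from Uniform(-1, 1); L is the sequence length, d the hidden size, and α corresponds to noise_alpha. The worked numbers use an illustrative sequence length and hidden size, not values taken from this module.

```latex
X' = X + \frac{\alpha}{\sqrt{L\,d}}\,\varepsilon,
\qquad \varepsilon_{ij} \sim \mathcal{U}(-1, 1),
\qquad X, \varepsilon \in \mathbb{R}^{L \times d}

% Worked example: \alpha = 5,\ L = 512,\ d = 4096
\frac{5}{\sqrt{512 \cdot 4096}} = \frac{5}{\sqrt{2^{21}}} \approx \frac{5}{1448.15} \approx 0.00345
```

Because each entry of ε has variance 1/3, the expected squared Frobenius norm of the scaled noise is (α² / (L·d)) · (L·d / 3) = α²/3 per sequence. Under this assumption, noise_alpha sets the total noise magnitude of a sequence, independent of L and d, while the per-entry perturbation shrinks as the sequence or hidden size grows.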

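Example:

```python
from nemo_automodel.components.training.neftune import NEFTune

neftune = NEFTune(noise_alpha=5.0)
neftune.activate(model)  # model: a torch.nn.Module with an input embedding layer
try:
    ...  # training loop
finally:
    neftune.deactivate(model)  # remove the hook even if training raises
```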


_neftune_forward_hook(module: torch.nn.Module, input: tuple, output: torch.Tensor) -> torch.Tensor

Forward hook that adds NEFTune noise to the embedding output during training.
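As a rough sketch of what such a hook does, the code below registers a noise-adding forward hook on a plain `nn.Embedding`. It is not this module's implementation: `make_neftune_hook` is a hypothetical name, and the sketch assumes outputs shaped (batch, seq_len, hidden_dim) with noise drawn from Uniform(-1, 1). It relies on PyTorch's documented behavior that a value returned from a forward hook replaces the module's output.

```python
import math

import torch
from torch import nn


def make_neftune_hook(alpha: float):
    """Return a forward hook that adds NEFTune-style uniform noise (hypothetical sketch)."""

    def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> torch.Tensor:
        # Leave the output untouched in eval mode or when noise is disabled.
        if not module.training or alpha == 0:
            return output
        # Assumes output is (batch, seq_len, hidden_dim).
        seq_len, hidden_dim = output.shape[-2], output.shape[-1]
        scale = alpha / math.sqrt(seq_len * hidden_dim)
        # Uniform(-1, 1) noise, scaled so each entry lies in [-scale, scale].
        noise = torch.empty_like(output).uniform_(-1.0, 1.0) * scale
        return output + noise  # returning a tensor replaces the module output

    return hook


torch.manual_seed(0)
emb = nn.Embedding(100, 16)
ids = torch.randint(0, 100, (2, 8))
handle = emb.register_forward_hook(make_neftune_hook(alpha=5.0))

emb.train()
noisy = emb(ids)  # hook active, training mode: noise added
emb.eval()
clean = emb(ids)  # hook active, eval mode: output unchanged

handle.remove()
emb.train()
after = emb(ids)  # hook removed: clean output even in training mode
```

Checking `module.training` inside the hook, rather than attaching and detaching it around each phase, lets evaluation run on clean embeddings while the hook stays registered.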

activate(model: torch.nn.Module) -> None

Attach the NEFTune noise hook to the model’s input embedding layer.

Parameters:

model – The model whose embeddings will be augmented with noise.

Raises:
  • RuntimeError – If NEFTune is already active on this model.

  • ValueError – If the model has no recognizable embedding layer.

deactivate(model: torch.nn.Module) -> None

Remove the NEFTune noise hook from the model.

Safe to call even if NEFTune is not active (no-op in that case).

Parameters:

model – The model to deactivate NEFTune on.

property is_active: bool

Whether NEFTune noise is currently being applied.
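The activate/deactivate/is_active contract described here can be sketched around the `RemovableHandle` returned by `register_forward_hook`. `NoiseHookSketch` is a hypothetical stand-in, not the library class: it locates the embedding layer only through `get_input_embeddings()` and assumes noise drawn from Uniform(-1, 1), scaled by noise_alpha / sqrt(seq_len * hidden_dim).

```python
import math
from typing import Optional

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


class NoiseHookSketch:
    """Hypothetical stand-in mirroring the documented activate/deactivate contract."""

    def __init__(self, noise_alpha: float = 5.0) -> None:
        self.noise_alpha = noise_alpha
        self._handle: Optional[RemovableHandle] = None

    @property
    def is_active(self) -> bool:
        # Active exactly when a hook handle is held.
        return self._handle is not None

    def _hook(self, module: nn.Module, inputs: tuple, output: torch.Tensor) -> torch.Tensor:
        if not module.training or self.noise_alpha == 0:
            return output
        scale = self.noise_alpha / math.sqrt(output.shape[-2] * output.shape[-1])
        return output + torch.empty_like(output).uniform_(-1.0, 1.0) * scale

    def activate(self, model: nn.Module) -> None:
        if self.is_active:
            raise RuntimeError("noise hook is already active")
        # Assumption: only the get_input_embeddings() lookup is sketched here.
        getter = getattr(model, "get_input_embeddings", None)
        embeddings = getter() if callable(getter) else None
        if embeddings is None:
            raise ValueError("model has no recognizable embedding layer")
        self._handle = embeddings.register_forward_hook(self._hook)

    def deactivate(self, model: nn.Module) -> None:
        # No-op when inactive, matching the documented behavior.
        if self._handle is not None:
            self._handle.remove()
            self._handle = None
```

Keeping the handle on the instance is what makes `deactivate` a safe no-op and lets `is_active` reduce to a `None` check.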

nemo_automodel.components.training.neftune._get_input_embeddings(model: torch.nn.Module) -> Optional[torch.nn.Module]

Find the input embedding layer on a model.

First uses the model’s get_input_embeddings() method if it has one (as Hugging Face models do), then falls back to common attribute names.

Parameters:

model – The model to search.

Returns:

The embedding module, or None if not found.
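A lookup of this shape might look like the following. `find_input_embeddings` and the fallback paths in `CANDIDATE_PATHS` are hypothetical illustrations (attribute paths used by some common Hugging Face architectures), not the attribute list this module actually checks.

```python
from typing import Optional, Sequence

from torch import nn

# Hypothetical fallback paths; the module's real list is not documented here.
CANDIDATE_PATHS: Sequence[str] = ("model.embed_tokens", "transformer.wte", "embed_tokens")


def find_input_embeddings(
    model: nn.Module, paths: Sequence[str] = CANDIDATE_PATHS
) -> Optional[nn.Module]:
    """Sketch: try an HF-style accessor first, then dotted attribute paths."""
    getter = getattr(model, "get_input_embeddings", None)
    if callable(getter):
        try:
            found = getter()
        except NotImplementedError:
            # Hugging Face's default accessor may raise when not overridden.
            found = None
        if isinstance(found, nn.Module):
            return found
    for path in paths:
        obj = model
        for name in path.split("."):
            obj = getattr(obj, name, None)  # None if the attribute is missing
            if obj is None:
                break
        if isinstance(obj, nn.Module):
            return obj
    return None
```

Returning `None` instead of raising leaves the error policy to the caller, which fits how `activate` is documented to raise `ValueError` when no embedding layer is found.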