nemo_automodel.components.training.neftune
NEFTune: Noisy Embeddings Fine-Tuning.
Implements the technique from “NEFTune: Noisy Embeddings Improve Instruction Finetuning” (https://arxiv.org/abs/2310.05914). Adds scaled uniform noise to token embeddings during training to improve generalization, with no additional compute or data overhead.
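The paper (Algorithm 1) describes the perturbation for one sequence of L tokens with embedding dimension d as shown below. This restates the paper's formula; how this module chooses L for padded batches is not specified on this page.

```latex
\tilde{X}_{\mathrm{emb}} = X_{\mathrm{emb}} + \frac{\alpha}{\sqrt{L\,d}}\,\varepsilon,
\qquad \varepsilon \sim \mathrm{Uniform}(-1, 1)^{L \times d}
```

For example, with alpha = 5, L = 512 and d = 4096, every noise entry lies within ±5 / sqrt(2,097,152) ≈ ±0.0035. Longer sequences and wider embeddings therefore receive smaller per-entry noise.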
Module Contents
Classes
NEFTune | Applies NEFTune noise to a model’s embedding layer during training.
Functions
_get_input_embeddings | Find the input embedding layer on a model.
Data
logger
API
- nemo_automodel.components.training.neftune.logger
‘getLogger(…)’
- class nemo_automodel.components.training.neftune.NEFTune(noise_alpha: float = 5.0)
Applies NEFTune noise to a model’s embedding layer during training.
NEFTune adds uniform random noise scaled by alpha / sqrt(seq_len * hidden_dim) to the embedding output. The noise is applied only when the model is in training mode.
- Parameters:
noise_alpha – Noise magnitude. Higher values add more noise. Typical values are 5-15. Set to 0 to disable.
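Example::

    from nemo_automodel.components.training.neftune import NEFTune

    neftune = NEFTune(noise_alpha=5.0)
    neftune.activate(model)  # model: a torch.nn.Module with an input embedding layer
    # ... training loop ...
    neftune.deactivate(model)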
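To make the scaling in the class description concrete, here is a minimal sketch of the noise computation as a standalone function. The name add_neftune_noise is hypothetical, and the sketch assumes embeddings shaped (batch, seq_len, hidden_dim) with seq_len read from the (possibly padded) input; it is not the module's own code.

```python
import math

import torch


def add_neftune_noise(embeddings: torch.Tensor, noise_alpha: float) -> torch.Tensor:
    """Hypothetical helper: add uniform noise scaled by alpha / sqrt(seq_len * hidden_dim).

    Assumes `embeddings` has shape (batch, seq_len, hidden_dim).
    """
    if noise_alpha == 0:
        # alpha = 0 disables the perturbation, so return the input unchanged.
        return embeddings
    seq_len, hidden_dim = embeddings.shape[-2], embeddings.shape[-1]
    magnitude = noise_alpha / math.sqrt(seq_len * hidden_dim)
    # Sample noise uniformly from [-magnitude, magnitude) with the same dtype/device.
    noise = torch.empty_like(embeddings).uniform_(-magnitude, magnitude)
    return embeddings + noise
```

Because the bound shrinks with sqrt(seq_len * hidden_dim), the same alpha yields a comparable total noise norm across sequence lengths and model widths.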
- _neftune_forward_hook(module: torch.nn.Module, input: tuple, output: torch.Tensor)
Forward hook that adds NEFTune noise to embedding output during training.
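A forward hook of this kind can be sketched with PyTorch's nn.Module.register_forward_hook, which replaces a module's output with the hook's return value when that value is not None. make_neftune_hook is a hypothetical helper; the sketch assumes the same uniform-noise scaling as the class description and is not the private _neftune_forward_hook itself.

```python
import math

import torch
from torch import nn


def make_neftune_hook(noise_alpha: float):
    """Hypothetical factory for a forward hook that adds NEFTune-style noise."""

    def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> torch.Tensor:
        # Only perturb embeddings in training mode; eval/inference sees clean outputs.
        if not module.training or noise_alpha == 0:
            return output
        dims = output.size(-2) * output.size(-1)  # seq_len * hidden_dim
        magnitude = noise_alpha / math.sqrt(dims)
        # Returning a tensor from a forward hook replaces the module's output.
        return output + torch.empty_like(output).uniform_(-magnitude, magnitude)

    return hook
```

In a full model, model.train() and model.eval() propagate the training flag to the embedding submodule, so the hook switches on and off with the rest of the model.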
- activate(model: torch.nn.Module) -> None
Attach NEFTune noise hook to the model’s input embedding layer.
- Parameters:
model – The model whose embeddings will be augmented with noise.
- Raises:
RuntimeError – If NEFTune is already active on this model.
ValueError – If the model has no recognizable embedding layer.
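The activate() contract (refuse a second activation, fail loudly when no embedding layer is found) can be sketched as follows. ActivateSketch, its single stored handle, and the embed_tokens fallback attribute are assumptions for illustration; the documented method uses the module's own embedding lookup and noise hook.

```python
from typing import Optional

from torch import nn
from torch.utils.hooks import RemovableHandle


class ActivateSketch:
    """Hypothetical stand-in illustrating the activate() error contract."""

    def __init__(self) -> None:
        self._handle: Optional[RemovableHandle] = None

    def _hook(self, module: nn.Module, inputs: tuple, output):
        return output  # placeholder: noise omitted in this sketch

    def activate(self, model: nn.Module) -> None:
        if self._handle is not None:
            raise RuntimeError("NEFTune is already active on this model.")
        # Assumed lookup: HF-style getter first, then a single fallback attribute.
        if hasattr(model, "get_input_embeddings"):
            embeddings = model.get_input_embeddings()
        else:
            embeddings = getattr(model, "embed_tokens", None)
        if embeddings is None:
            raise ValueError("Could not find an input embedding layer on the model.")
        # Keep the handle so the hook can be removed later.
        self._handle = embeddings.register_forward_hook(self._hook)
```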
- deactivate(model: torch.nn.Module) -> None
Remove the NEFTune noise hook from the model.
Safe to call even if NEFTune is not active (no-op in that case).
- Parameters:
model – The model to deactivate NEFTune on.
- property is_active: bool
Whether NEFTune noise is currently being applied.
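A safe, no-op deactivate() and an is_active flag follow naturally if the hook's RemovableHandle is stored and cleared. The hypothetical HookLifecycleSketch below takes the embedding module directly instead of a model and uses an identity hook as a placeholder for the noise; it illustrates the pattern rather than reproducing NEFTune's implementation.

```python
from typing import Optional

from torch import nn
from torch.utils.hooks import RemovableHandle


class HookLifecycleSketch:
    """Hypothetical sketch of deactivate()/is_active built on a stored hook handle."""

    def __init__(self) -> None:
        self._handle: Optional[RemovableHandle] = None

    def activate(self, embeddings: nn.Module) -> None:
        # Identity hook standing in for the noise hook.
        self._handle = embeddings.register_forward_hook(lambda mod, inp, out: out)

    def deactivate(self) -> None:
        # No-op when nothing is registered, so repeated calls are safe.
        if self._handle is not None:
            self._handle.remove()
            self._handle = None

    @property
    def is_active(self) -> bool:
        return self._handle is not None
```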
- nemo_automodel.components.training.neftune._get_input_embeddings(model: torch.nn.Module)
Find the input embedding layer on a model.
Checks for a get_input_embeddings() method first (Hugging Face models), then falls back to common attribute names.
- Parameters:
model – The model to search.
- Returns:
The embedding module, or None if not found.
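The lookup order described above can be sketched as a hypothetical find_input_embeddings helper. The fallback names in _CANDIDATE_ATTRS (embed_tokens, wte, word_embeddings) are assumptions taken from the inner modules of several common Hugging Face architectures, not the module's documented list. The sketch also only inspects top-level attributes.

```python
from typing import Optional

from torch import nn

# Hypothetical fallback attribute names; the real list is not documented here.
_CANDIDATE_ATTRS = ("embed_tokens", "wte", "word_embeddings")


def find_input_embeddings(model: nn.Module) -> Optional[nn.Module]:
    # 1) Prefer an HF-style get_input_embeddings() method when one exists.
    getter = getattr(model, "get_input_embeddings", None)
    if callable(getter):
        try:
            emb = getter()
        except NotImplementedError:
            emb = None  # some base classes declare the method without implementing it
        if emb is not None:
            return emb
    # 2) Fall back to well-known attribute names.
    for name in _CANDIDATE_ATTRS:
        emb = getattr(model, name, None)
        if isinstance(emb, nn.Module):
            return emb
    # 3) Nothing recognizable: let the caller decide how to fail.
    return None
```

Returning None rather than raising keeps the helper reusable; in this sketch's design, the caller (like activate()) turns a None into a ValueError.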