Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
JitTransform Class#
Description#
The JitTransform class is a PyTorch Lightning callback that applies just-in-time (JIT) compilation to PyTorch models. It uses either torch.compile or the Thunder compiler to optimize the selected modules during training. This is particularly useful for models with complex architectures, where compilation can reduce execution time.
Parameters#
- nemo.lightning.pytorch.callbacks.config#
Type: JitConfig
The configuration object specifying how JIT compilation should be applied. It includes options to select the compiler (torch.compile or Thunder), module selectors to target specific parts of the model, and additional parameters for profiling and compiler behavior.
Usage Example#
Below is an example of how to integrate JitTransform into a PyTorch Lightning Trainer:
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform
from pytorch_lightning import Trainer

# Define the JIT configuration
jit_config = JitConfig(
    use_torch=True,
    torch_kwargs={'mode': 'default'},
    module_selector='*.layer*',
)

# Initialize the JitTransform callback
jit_callback = JitTransform(config=jit_config)

# Initialize the Trainer with the JitTransform callback
trainer = Trainer(callbacks=[jit_callback])

# Proceed with training as usual (model is your LightningModule, defined elsewhere)
trainer.fit(model)
Attributes#
- nemo.lightning.pytorch.callbacks.config#
The JitConfig instance containing the JIT compilation settings.
Methods#
- nemo.lightning.pytorch.callbacks.on_train_epoch_start(trainer, pl_module)#
Called at the start of each training epoch. On the first call, this method applies JIT compilation to the modules selected by the provided configuration. On later calls it does nothing, because the module is already marked as compiled.
Parameters:
trainer (pl.Trainer): The PyTorch Lightning trainer instance.
pl_module (pl.LightningModule): The PyTorch Lightning module being trained.
Detailed Description#
The JitTransform callback hooks into the start of each training epoch. It identifies the target modules within the model based on the module_selector pattern provided in the JitConfig, then compiles them with either torch.compile or Thunder, depending on the configuration.
Compilation occurs only once: after compiling, the callback sets a _compiled flag on the pl_module and checks that flag on every later call, so subsequent epochs do not recompile. If enabled in the configuration, the callback also supports profiling through Thunder’s profiler.
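The compile-once behavior described above can be sketched without PyTorch or NeMo installed. This is a minimal illustration, not NeMo's implementation: SketchJitCallback, DummyModule, and fake_compile are hypothetical stand-ins, and only the hook name on_train_epoch_start and the _compiled flag are taken from the description on this page.

```python
class DummyModule:
    """Hypothetical stand-in for a LightningModule."""


class SketchJitCallback:
    """Illustrative callback that compiles a module at most once."""

    def __init__(self, compile_fn):
        # compile_fn stands in for torch.compile or Thunder (assumption).
        self.compile_fn = compile_fn

    def on_train_epoch_start(self, trainer, pl_module):
        # Skip if an earlier epoch already compiled this module.
        if getattr(pl_module, "_compiled", False):
            return
        self.compile_fn(pl_module)
        # Record that compilation happened so later epochs are no-ops.
        pl_module._compiled = True


compile_calls = []


def fake_compile(module):
    # Record the call instead of doing real compilation.
    compile_calls.append(module)


callback = SketchJitCallback(fake_compile)
module = DummyModule()

# Simulate three epochs; only the first should trigger compilation.
for _epoch in range(3):
    callback.on_train_epoch_start(trainer=None, pl_module=module)

print(len(compile_calls))  # 1
```

Storing the flag on the module rather than on the callback means the "already compiled" state travels with the model, so a second callback instance attached to the same module would also skip it.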
Constraints#
Mutually Exclusive Options: The use_torch and use_thunder options in JitConfig are mutually exclusive. Enabling both simultaneously will raise an assertion error.
Module Selection: The module_selector supports simple wildcard patterns. Complex selectors may require extending the matching logic.
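Both constraints can be illustrated with a small standalone sketch. SketchJitConfig and select_modules are hypothetical names invented for this example, and the use of fnmatch-style matching on dotted module names is an assumption about what "simple wildcard patterns" means. NeMo's actual matching logic may differ.

```python
from dataclasses import dataclass, field
from fnmatch import fnmatchcase


@dataclass
class SketchJitConfig:
    """Hypothetical config mirroring the options named on this page."""

    module_selector: str = ""
    use_torch: bool = False
    torch_kwargs: dict = field(default_factory=dict)
    use_thunder: bool = False

    def __post_init__(self):
        # The two compilers are mutually exclusive.
        assert not (self.use_torch and self.use_thunder), (
            "use_torch and use_thunder cannot both be enabled"
        )


def select_modules(module_names, selector):
    """Return the names that match a simple wildcard selector (assumed fnmatch-style)."""
    return [name for name in module_names if fnmatchcase(name, selector)]


# Dotted names as they might appear in a model's module tree (made-up example).
names = ["encoder.layer0", "encoder.layer1", "encoder.norm", "head"]
selected = select_modules(names, "*.layer*")
print(selected)  # ['encoder.layer0', 'encoder.layer1']

# Enabling both compilers trips the assertion in __post_init__.
try:
    SketchJitConfig(use_torch=True, use_thunder=True)
    both_rejected = False
except AssertionError:
    both_rejected = True
print(both_rejected)  # True
```

Validating in `__post_init__` surfaces a conflicting configuration when the config object is created, before any training starts.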
See Also#
JitConfig