nemo_automodel.components.utils.compile_utils#

Module Contents#

Classes#

- CompileConfig – Configuration for torch.compile.

Functions#

- configure_torch_dynamo – Configure torch._dynamo settings for compilation.
- enable_torch_dynamo_scalar_outputs – Enable torch.dynamo to capture scalar outputs for better Flash Attention + torch.compile compatibility.
- patch_prepare_fa2_from_position_ids – Apply a targeted patch that fixes the prepare_fa2_from_position_ids function for torch.compile compatibility.
- apply_flash_attention_compile_fix – Apply the Flash Attention + torch.compile compatibility fix.
- compile_model – Compile the model with Flash Attention compatibility.
- create_compile_config_from_dict – Create a CompileConfig from a dictionary.
- build_compile_config – Build a compile config from configuration.

Data#

- logger
- _FLASH_ATTENTION_FIX_APPLIED

API#
- nemo_automodel.components.utils.compile_utils.logger#
'getLogger(…)'
- class nemo_automodel.components.utils.compile_utils.CompileConfig(
- enabled: bool = False,
- mode: str = 'default',
- fullgraph: bool = False,
- dynamic: bool = False,
- backend: Optional[str] = None,
- options: Optional[Dict[str, Any]] = None,
- dynamo_cache_size_limit: int = 256,
)#
Configuration for torch.compile.
Initialization
- enabled: bool#
False
- mode: str#
‘default’
- fullgraph: bool#
False
- dynamic: bool#
False
- backend: Optional[str]#
None
- options: Optional[Dict[str, Any]]#
None
- dynamo_cache_size_limit: int#
256
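The fields and defaults above can be read as a plain dataclass. A minimal sketch for illustration (not the library's actual source):

```python
# Illustrative dataclass mirroring CompileConfig's documented fields.
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class CompileConfig:
    enabled: bool = False
    mode: str = "default"             # forwarded to torch.compile(mode=...)
    fullgraph: bool = False
    dynamic: bool = False
    backend: Optional[str] = None     # e.g. "inductor" when set explicitly
    options: Optional[Dict[str, Any]] = None
    dynamo_cache_size_limit: int = 256

cfg = CompileConfig(enabled=True, mode="max-autotune")
print(cfg.enabled, cfg.mode, cfg.dynamo_cache_size_limit)
```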
- nemo_automodel.components.utils.compile_utils.configure_torch_dynamo(
- cache_size_limit: int = 256,
- capture_scalar_outputs: bool = True,
)[source]#
Configure torch._dynamo settings for compilation.
- Parameters:
cache_size_limit – Cache size limit for dynamo compilation
capture_scalar_outputs – Whether to capture scalar outputs for Flash Attention compatibility
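A hedged sketch of what this function plausibly does. `torch._dynamo.config` does expose `cache_size_limit` and `capture_scalar_outputs` attributes; here a `SimpleNamespace` stands in for it so the sketch runs without torch installed:

```python
# Sketch of configure_torch_dynamo; `dynamo_config` stands in for
# torch._dynamo.config (an assumption about the implementation, not its source).
from types import SimpleNamespace

dynamo_config = SimpleNamespace(cache_size_limit=8, capture_scalar_outputs=False)

def configure_torch_dynamo(cache_size_limit: int = 256,
                           capture_scalar_outputs: bool = True) -> None:
    # A larger cache limit avoids silent fallbacks to eager mode when many
    # distinct input shapes each trigger a recompilation.
    dynamo_config.cache_size_limit = cache_size_limit
    # Capturing scalar outputs lets tensor-to-int conversions (used by Flash
    # Attention's varlen path) stay inside the compiled graph.
    dynamo_config.capture_scalar_outputs = capture_scalar_outputs

configure_torch_dynamo()
print(dynamo_config.cache_size_limit)        # 256
print(dynamo_config.capture_scalar_outputs)  # True
```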
- nemo_automodel.components.utils.compile_utils.enable_torch_dynamo_scalar_outputs()[source]#
Enable torch.dynamo to capture scalar outputs for better Flash Attention + torch.compile compatibility.
- nemo_automodel.components.utils.compile_utils.patch_prepare_fa2_from_position_ids()[source]#
Apply a targeted patch that fixes the prepare_fa2_from_position_ids function for torch.compile compatibility.
This is the key function that needs the fix for the max_length computation.
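To make the max_length issue concrete, here is a pure-Python illustration of the quantities prepare_fa2_from_position_ids derives for Flash Attention's varlen kernels (the real Hugging Face helper operates on tensors; this sketch only shows the arithmetic):

```python
# Illustration of cu_seqlens / max_length computation from packed position_ids.
from typing import List, Tuple

def fa2_metadata_from_position_ids(position_ids: List[int]) -> Tuple[List[int], int]:
    # A new packed sequence starts wherever the position counter resets to 0.
    seq_starts = [i for i, p in enumerate(position_ids) if p == 0]
    boundaries = seq_starts + [len(position_ids)]
    # cu_seqlens: cumulative sequence boundaries, as varlen attention expects.
    cu_seqlens = boundaries
    # max_length: length of the longest packed sequence. Converting this from
    # a tensor to a Python int (via .item()) is the step that graph-breaks
    # under torch.compile unless scalar-output capture is enabled.
    max_length = max(b - a for a, b in zip(boundaries, boundaries[1:]))
    return cu_seqlens, max_length

# Two packed sequences of lengths 3 and 2:
print(fa2_metadata_from_position_ids([0, 1, 2, 0, 1]))  # ([0, 3, 5], 3)
```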
- nemo_automodel.components.utils.compile_utils.apply_flash_attention_compile_fix()[source]#
Apply the Flash Attention + torch.compile compatibility fix.
This enables scalar output capture and patches the key function that causes issues. Note: This function is focused solely on Flash Attention compatibility. For dynamo configuration (cache size, etc.), use configure_torch_dynamo() separately.
- nemo_automodel.components.utils.compile_utils.compile_model(
- model: torch.nn.Module,
- config: nemo_automodel.components.utils.compile_utils.CompileConfig,
)[source]#
Compile the model with Flash Attention compatibility.
- Parameters:
model – The model to compile.
config – Compile configuration.
- Returns:
The compiled model.
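How the config fields map onto `torch.compile` keyword arguments can be sketched without torch. The parameter names below (`mode`, `fullgraph`, `dynamic`, `backend`, `options`) are real `torch.compile` parameters, and `mode` and `options` are mutually exclusive there; the exact mapping is an assumption about this module:

```python
# Sketch: translate CompileConfig-style fields into torch.compile kwargs.
from typing import Any, Dict, Optional

def compile_kwargs(enabled: bool, mode: str, fullgraph: bool, dynamic: bool,
                   backend: Optional[str],
                   options: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if not enabled:
        return None  # compile_model would return the model unchanged
    kwargs: Dict[str, Any] = {"fullgraph": fullgraph, "dynamic": dynamic}
    if backend is not None:
        kwargs["backend"] = backend
    if options is not None:
        kwargs["options"] = options  # torch.compile rejects mode+options together
    else:
        kwargs["mode"] = mode
    return kwargs

print(compile_kwargs(True, "max-autotune", False, False, None, None))
# {'fullgraph': False, 'dynamic': False, 'mode': 'max-autotune'}
```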
- nemo_automodel.components.utils.compile_utils.create_compile_config_from_dict(
- config_dict: Dict[str, Any],
)[source]#
Create a CompileConfig from a dictionary.
- Parameters:
config_dict – Dictionary containing compile configuration.
- Returns:
CompileConfig instance.
- nemo_automodel.components.utils.compile_utils.build_compile_config(
- cfg: Optional[Dict[str, Any]],
)[source]#
Build a compile config from configuration.
- Parameters:
cfg – Configuration dictionary for compilation.
- Returns:
CompileConfig instance.
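A sketch of how these two builders plausibly relate: create_compile_config_from_dict keeps only recognized keys, and build_compile_config falls back to a disabled default when no configuration is given (both behaviors are assumptions; the dataclass is trimmed to three fields for brevity):

```python
# Sketch of the two config builders from this module's documented API.
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional

@dataclass
class CompileConfig:  # trimmed for the sketch
    enabled: bool = False
    mode: str = "default"
    dynamo_cache_size_limit: int = 256

def create_compile_config_from_dict(config_dict: Dict[str, Any]) -> CompileConfig:
    # Drop unrecognized keys so stray YAML entries do not raise TypeError.
    known = {f.name for f in fields(CompileConfig)}
    return CompileConfig(**{k: v for k, v in config_dict.items() if k in known})

def build_compile_config(cfg: Optional[Dict[str, Any]]) -> CompileConfig:
    if cfg is None:
        return CompileConfig()  # compilation disabled by default
    return create_compile_config_from_dict(cfg)

print(build_compile_config(None).enabled)  # False
print(build_compile_config({"enabled": True, "junk": 1}).mode)  # default
```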
- nemo_automodel.components.utils.compile_utils._FLASH_ATTENTION_FIX_APPLIED#
'apply_flash_attention_compile_fix(…)'