nemo_automodel.components.utils.compile_utils#

Module Contents#

Classes#

CompileConfig

Configuration for torch.compile.

Functions#

configure_torch_dynamo

Configure torch._dynamo settings for compilation.

enable_torch_dynamo_scalar_outputs

Enable torch._dynamo to capture scalar outputs for better Flash Attention + torch.compile compatibility.

patch_prepare_fa2_from_position_ids

Apply a simple targeted patch to fix the prepare_fa2_from_position_ids function for torch.compile compatibility.

apply_flash_attention_compile_fix

Apply the Flash Attention + torch.compile compatibility fix.

compile_model

Compile the model with Flash Attention compatibility.

create_compile_config_from_dict

Create a CompileConfig from a dictionary.

build_compile_config

Build a compile config from configuration.

Data#

logger

_FLASH_ATTENTION_FIX_APPLIED

API#

nemo_automodel.components.utils.compile_utils.logger#

'getLogger(…)'

class nemo_automodel.components.utils.compile_utils.CompileConfig(
enabled: bool = False,
mode: str = 'default',
fullgraph: bool = False,
dynamic: bool = False,
backend: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
dynamo_cache_size_limit: int = 256,
)[source]#

Configuration for torch.compile.

Initialization

enabled: bool#

False

mode: str#

'default'

fullgraph: bool#

False

dynamic: bool#

False

backend: Optional[str]#

None

options: Optional[Dict[str, Any]]#

None

dynamo_cache_size_limit: int#

256

to_dict() → Dict[str, Any][source]#

Convert to dictionary.
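
A minimal construction sketch based on the signature above; the non-default values are illustrative:

```python
from nemo_automodel.components.utils.compile_utils import CompileConfig

# Enable compilation; unspecified fields keep the defaults listed above.
config = CompileConfig(enabled=True, fullgraph=True, dynamo_cache_size_limit=512)

# Round-trip to a plain dict, e.g. for logging or serialization.
print(config.to_dict())
```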

nemo_automodel.components.utils.compile_utils.configure_torch_dynamo(
cache_size_limit: int = 256,
capture_scalar_outputs: bool = True,
)[source]#

Configure torch._dynamo settings for compilation.

Parameters:
  • cache_size_limit – Cache size limit for dynamo compilation

  • capture_scalar_outputs – Whether to capture scalar outputs for Flash Attention compatibility
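
A usage sketch based on the parameters above; the larger cache limit is illustrative:

```python
from nemo_automodel.components.utils.compile_utils import configure_torch_dynamo

# Raise the dynamo compilation cache limit; scalar-output capture stays on
# (its documented default) for Flash Attention compatibility.
configure_torch_dynamo(cache_size_limit=512, capture_scalar_outputs=True)
```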

nemo_automodel.components.utils.compile_utils.enable_torch_dynamo_scalar_outputs()[source]#

Enable torch._dynamo to capture scalar outputs for better Flash Attention + torch.compile compatibility.
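
A usage sketch. Presumably this helper sets the standard torch._dynamo.config.capture_scalar_outputs flag, though that is an assumption about its implementation:

```python
from nemo_automodel.components.utils.compile_utils import enable_torch_dynamo_scalar_outputs

# Call before torch.compile so data-dependent scalars (e.g. Tensor.item())
# can be captured during tracing.
enable_torch_dynamo_scalar_outputs()
```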

nemo_automodel.components.utils.compile_utils.patch_prepare_fa2_from_position_ids()[source]#

Apply a simple targeted patch to fix the prepare_fa2_from_position_ids function for torch.compile compatibility.

This is the key function that needs the fix for the max_length computation.
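
A usage sketch; presumably the patch must be installed before the model is compiled for it to take effect:

```python
from nemo_automodel.components.utils.compile_utils import patch_prepare_fa2_from_position_ids

# Install the patch before compiling so the fixed max_length computation
# is the one torch.compile traces.
patch_prepare_fa2_from_position_ids()
```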

nemo_automodel.components.utils.compile_utils.apply_flash_attention_compile_fix()[source]#

Apply the Flash Attention + torch.compile compatibility fix.

This enables scalar output capture and patches the key function that causes issues. Note: This function is focused solely on Flash Attention compatibility. For dynamo configuration (cache size, etc.), use configure_torch_dynamo() separately.
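
A sketch of the split described in the note above, with the fix applied alongside a separate dynamo configuration call:

```python
from nemo_automodel.components.utils.compile_utils import (
    apply_flash_attention_compile_fix,
    configure_torch_dynamo,
)

# Flash Attention compatibility: scalar-output capture plus the
# prepare_fa2_from_position_ids patch.
apply_flash_attention_compile_fix()

# Dynamo settings (cache size, etc.) are configured separately, per the note above.
configure_torch_dynamo(cache_size_limit=256)
```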

nemo_automodel.components.utils.compile_utils.compile_model(
model: torch.nn.Module,
config: nemo_automodel.components.utils.compile_utils.CompileConfig,
) → torch.nn.Module[source]#

Compile the model with Flash Attention compatibility.

Parameters:
  • model – The model to compile.

  • config – Compile configuration.

Returns:

The compiled model.
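
An end-to-end sketch; the nn.Linear model is an illustrative stand-in:

```python
import torch.nn as nn

from nemo_automodel.components.utils.compile_utils import CompileConfig, compile_model

model = nn.Linear(16, 16)  # illustrative stand-in for a real model
config = CompileConfig(enabled=True, mode='default')

# Wraps compilation with the Flash Attention compatibility handling
# described above and returns the compiled model.
model = compile_model(model, config)
```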

nemo_automodel.components.utils.compile_utils.create_compile_config_from_dict(
config_dict: Dict[str, Any],
) → nemo_automodel.components.utils.compile_utils.CompileConfig[source]#

Create a CompileConfig from a dictionary.

Parameters:

config_dict – Dictionary containing compile configuration.

Returns:

CompileConfig instance.
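
A usage sketch, assuming the dictionary keys mirror the CompileConfig field names documented above:

```python
from nemo_automodel.components.utils.compile_utils import create_compile_config_from_dict

# Keys are assumed to mirror the CompileConfig fields.
config = create_compile_config_from_dict({'enabled': True, 'mode': 'max-autotune'})
```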

nemo_automodel.components.utils.compile_utils.build_compile_config(
cfg: Optional[Dict[str, Any]],
) → nemo_automodel.components.utils.compile_utils.CompileConfig[source]#

Build a compile config from configuration.

Parameters:

cfg – Configuration dictionary for compilation.

Returns:

CompileConfig instance.
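
A usage sketch; the behavior for a None cfg is an assumption based on the Optional type hint:

```python
from nemo_automodel.components.utils.compile_utils import build_compile_config

# With a dict, fields come from the configuration.
config = build_compile_config({'enabled': True, 'dynamic': True})

# cfg is Optional; None presumably yields a default (disabled) config.
default_config = build_compile_config(None)
```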

nemo_automodel.components.utils.compile_utils._FLASH_ATTENTION_FIX_APPLIED#

'apply_flash_attention_compile_fix(…)'