nemo_automodel.components.utils.compile_utils


Module Contents

Classes

| Name | Description |
| --- | --- |
| CompileConfig | Configuration for torch.compile. |

Functions

| Name | Description |
| --- | --- |
| apply_flash_attention_compile_fix | Apply the Flash Attention + torch.compile compatibility fix. |
| build_compile_config | Build a compile config from configuration. |
| compile_model | Compile the model with Flash Attention compatibility. |
| configure_torch_dynamo | Configure torch._dynamo settings for compilation. |
| create_compile_config_from_dict | Create a CompileConfig from a dictionary. |
| enable_torch_dynamo_scalar_outputs | Enable torch.dynamo to capture scalar outputs for better Flash Attention + torch.compile compatibility. |
| patch_prepare_fa2_from_position_ids | Apply a simple targeted patch to fix the prepare_fa2_from_position_ids function for torch.compile compatibility. |

Data

_FLASH_ATTENTION_FIX_APPLIED

logger

API

class nemo_automodel.components.utils.compile_utils.CompileConfig(
enabled: bool = False,
mode: str = 'default',
fullgraph: bool = False,
dynamic: bool = False,
backend: typing.Optional[str] = None,
options: typing.Optional[typing.Dict[str, typing.Any]] = None,
dynamo_cache_size_limit: int = 256
)
Dataclass

Configuration for torch.compile.

options
Optional[Dict[str, Any]] = options or {}
nemo_automodel.components.utils.compile_utils.CompileConfig.to_dict() -> typing.Dict[str, typing.Any]

Convert to dictionary.
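
As a rough illustration of the fields listed above, the following stand-in dataclass mirrors the documented signature and the `options or {}` normalization. `CompileConfigSketch` is a hypothetical name used only for this sketch, and the `to_dict` body (a plain `dataclasses.asdict` dump) is an assumption; the real class may behave differently.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class CompileConfigSketch:
    """Hypothetical stand-in that mirrors the documented CompileConfig fields."""

    enabled: bool = False
    mode: str = "default"
    fullgraph: bool = False
    dynamic: bool = False
    backend: Optional[str] = None
    options: Optional[Dict[str, Any]] = None
    dynamo_cache_size_limit: int = 256

    def __post_init__(self) -> None:
        # Mirrors the documented `options or {}` normalization.
        self.options = self.options or {}

    def to_dict(self) -> Dict[str, Any]:
        # Assumption: a field-by-field dump; the real method may differ.
        return asdict(self)


cfg = CompileConfigSketch(enabled=True, mode="max-autotune")
as_dict = cfg.to_dict()
```

With the real class, the equivalent calls would presumably be `CompileConfig(enabled=True, mode="max-autotune").to_dict()`.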

nemo_automodel.components.utils.compile_utils.apply_flash_attention_compile_fix()

Apply the Flash Attention + torch.compile compatibility fix.

This enables scalar output capture and patches the key function that causes issues.

Note: This function is focused solely on Flash Attention compatibility. For dynamo configuration (cache size, etc.), use configure_torch_dynamo() separately.

nemo_automodel.components.utils.compile_utils.build_compile_config(
cfg: typing.Optional[typing.Dict[str, typing.Any]]
) -> nemo_automodel.components.utils.compile_utils.CompileConfig

Build a compile config from configuration.

Parameters:

cfg
Optional[Dict[str, Any]]

Configuration dictionary for compilation.

Returns: CompileConfig

CompileConfig instance.

nemo_automodel.components.utils.compile_utils.compile_model(
model: torch.nn.Module,
config: nemo_automodel.components.utils.compile_utils.CompileConfig
) -> torch.nn.Module

Compile the model with Flash Attention compatibility.

Parameters:

model
nn.Module

The model to compile.

config
CompileConfig

Compile configuration.

Returns: nn.Module

The compiled model.
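
The page does not show how the configuration fields reach `torch.compile`, so the sketch below is only one plausible mapping. The helper name `torch_compile_kwargs` is hypothetical, and preferring `options` over `mode` is an assumption made here because `torch.compile` does not accept both at once.

```python
from typing import Any, Dict, Optional


def torch_compile_kwargs(
    mode: str = "default",
    fullgraph: bool = False,
    dynamic: bool = False,
    backend: Optional[str] = None,
    options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Translate CompileConfig-style fields into torch.compile keyword arguments."""
    kwargs: Dict[str, Any] = {"fullgraph": fullgraph, "dynamic": dynamic}
    if backend is not None:
        # Omit the key when unset so torch.compile uses its own default backend.
        kwargs["backend"] = backend
    if options:
        # torch.compile raises if both `mode` and `options` are given,
        # so this sketch passes only `options` when it is non-empty.
        kwargs["options"] = options
    else:
        kwargs["mode"] = mode
    return kwargs


default_kwargs = torch_compile_kwargs()
tuned_kwargs = torch_compile_kwargs(backend="inductor", options={"max_autotune": True})
```

With PyTorch installed, the result could be applied as `torch.compile(model, **torch_compile_kwargs(mode="reduce-overhead"))`.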

nemo_automodel.components.utils.compile_utils.configure_torch_dynamo(
cache_size_limit: int = 256,
capture_scalar_outputs: bool = True
)

Configure torch._dynamo settings for compilation.

Parameters:

cache_size_limit
int, defaults to 256

Cache size limit for dynamo compilation

capture_scalar_outputs
bool, defaults to True

Whether to capture scalar outputs for Flash Attention compatibility
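
The two parameters correspond by name to the `cache_size_limit` and `capture_scalar_outputs` attributes of `torch._dynamo.config`. The sketch below assumes the function simply assigns them; `apply_dynamo_settings` is a hypothetical helper, and a `SimpleNamespace` stands in for the real config module so the example runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any


def apply_dynamo_settings(
    dynamo_config: Any,
    cache_size_limit: int = 256,
    capture_scalar_outputs: bool = True,
) -> None:
    """Assign the two settings on a config object (assumed behavior, not the library source)."""
    dynamo_config.cache_size_limit = cache_size_limit
    dynamo_config.capture_scalar_outputs = capture_scalar_outputs


# Stand-in object so the sketch runs without PyTorch installed.
fake_config = SimpleNamespace(cache_size_limit=8, capture_scalar_outputs=False)
apply_dynamo_settings(fake_config)

# With PyTorch available, the same helper could target the real module:
#   import torch._dynamo
#   apply_dynamo_settings(torch._dynamo.config)
```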

nemo_automodel.components.utils.compile_utils.create_compile_config_from_dict(
config_dict: typing.Dict[str, typing.Any]
) -> nemo_automodel.components.utils.compile_utils.CompileConfig

Create a CompileConfig from a dictionary.

Parameters:

config_dict
Dict[str, Any]

Dictionary containing compile configuration.

Returns: CompileConfig

CompileConfig instance.
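
A dictionary-to-config conversion of this kind can be sketched as follows. `CompileConfigSketch` and `config_from_dict` are hypothetical stand-ins, and rejecting unknown keys is a design choice of this sketch, not documented behavior of the library.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class CompileConfigSketch:
    """Hypothetical stand-in that mirrors the documented CompileConfig fields."""

    enabled: bool = False
    mode: str = "default"
    fullgraph: bool = False
    dynamic: bool = False
    backend: Optional[str] = None
    options: Optional[Dict[str, Any]] = None
    dynamo_cache_size_limit: int = 256

    def __post_init__(self) -> None:
        # Mirrors the documented `options or {}` normalization.
        self.options = self.options or {}


def config_from_dict(config_dict: Dict[str, Any]) -> CompileConfigSketch:
    """Build the stand-in config, failing loudly on keys it does not know."""
    known = {f.name for f in fields(CompileConfigSketch)}
    unknown = sorted(set(config_dict) - known)
    if unknown:
        # Assumption: the real function may ignore or handle extra keys differently.
        raise ValueError(f"Unknown compile config keys: {unknown}")
    return CompileConfigSketch(**config_dict)


cfg = config_from_dict({"enabled": True, "mode": "reduce-overhead"})
```

Keys that are omitted fall back to the defaults shown in the class signature.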

nemo_automodel.components.utils.compile_utils.enable_torch_dynamo_scalar_outputs()

Enable torch._dynamo to capture scalar outputs, which improves Flash Attention + torch.compile compatibility.

nemo_automodel.components.utils.compile_utils.patch_prepare_fa2_from_position_ids()

Apply a simple targeted patch to fix the prepare_fa2_from_position_ids function for torch.compile compatibility.

This is the key function that needs the fix for the max_length computation.

nemo_automodel.components.utils.compile_utils._FLASH_ATTENTION_FIX_APPLIED = apply_flash_attention_compile_fix()
nemo_automodel.components.utils.compile_utils.logger = logging.getLogger(__name__)