bridge.peft.walk_utils
Walking utilities for PyTorch module transformation.
This module provides utilities for recursively applying transformations to PyTorch modules, handling complex hierarchies including lists, dictionaries, and nested structures.
Examples

Basic module transformation:

    >>> def add_tag(module, name=None, **kwargs):
    ...     module.tag = f"transformed_{name}"
    ...     return module
    >>>
    >>> model = nn.Sequential(nn.Linear(10, 5), nn.ReLU())
    >>> transformed = walk(model, add_tag)

Conditional transformation:

    >>> def freeze_linear(module, **kwargs):
    ...     if isinstance(module, nn.Linear):
    ...         for param in module.parameters():
    ...             param.requires_grad = False
    ...     return module
    >>>
    >>> frozen_model = walk(model, freeze_linear)
Module Contents
Classes
HasBool | Protocol for objects that can be evaluated as boolean.
Functions
map | Applies a function to a PyTorch module or a collection of modules.
walk | Recursively apply a function to a module or collection.
forall | Checks if a predicate holds for all modules in a given module or its children, optionally recursively.
_map_module | Applies a transformation function to a module and optionally to its child modules.
_map_module_list | Apply a transformation function to a list of modules.
_map_module_dict | Applies a transformation function to a ModuleDict of modules.
_create_list_wrapper | Create a wrapper for a list of modules, preserving the original type.
_get_func_kwargs | Extract kwargs that match the function signature.
Data
_TModule
ModuleFunc
ModulePredicate
API
- class bridge.peft.walk_utils.HasBool
Bases: typing.Protocol
Protocol for objects that can be evaluated as boolean.
- __bool__() → bool
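To illustrate what a boolean-evaluable protocol allows, here is a minimal sketch using only the standard library. `HasBoolSketch`, `Verdict`, and `all_truthy` are hypothetical names introduced for this example and are not part of `walk_utils`; the point is only that a predicate may return any object defining `__bool__`, not just `True` or `False`.

```python
from typing import Iterable, Protocol, runtime_checkable


@runtime_checkable
class HasBoolSketch(Protocol):
    """Hypothetical stand-in mirroring the documented HasBool protocol."""

    def __bool__(self) -> bool: ...


class Verdict:
    """Example predicate result that carries a reason alongside its truth value."""

    def __init__(self, ok: bool, reason: str = "") -> None:
        self.ok = ok
        self.reason = reason

    def __bool__(self) -> bool:
        return self.ok


def all_truthy(results: Iterable[HasBoolSketch]) -> bool:
    # bool() calls __bool__, so any conforming object can be used here.
    return all(bool(r) for r in results)


passed = all_truthy([Verdict(True), Verdict(True)])
failed = all_truthy([Verdict(True), Verdict(False, "frozen layer")])
```

A richer result object like `Verdict` can keep diagnostic information while still being usable wherever a plain boolean is expected.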
- bridge.peft.walk_utils._TModule
'TypeVar(...)'
- bridge.peft.walk_utils.ModuleFunc
None
- bridge.peft.walk_utils.ModulePredicate
None
- bridge.peft.walk_utils.map(
    module: bridge.peft.walk_utils._TModule,
    func: bridge.peft.walk_utils.ModuleFunc,
    leaf_only: bool = False,
    **kwargs,
)
Applies a function to a PyTorch module or a collection of modules.
This function can be used to modify modules in place, for example by changing their attributes, applying normalization, or performing other custom transformations. It supports individual modules, lists of modules, and dictionaries of modules. If leaf_only is set to True, the function is applied only to modules that do not have parameters.
- Parameters:
module – The module or collection of modules to which the function will be applied.
func – A callable that takes a module (and optionally additional keyword arguments) and returns a transformed module. The signature should be func(module, **kwargs).
leaf_only – If True, the function will only be applied to modules that do not have any parameters. Defaults to False.
**kwargs – Additional keyword arguments that will be passed to func.
- Returns:
The transformed module or collection of modules.
Examples

    >>> import torch.nn as nn
    >>> from megatron.bridge.peft.walk_utils import map
    >>> # Example: Adding a custom attribute to all modules
    >>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
    >>> def add_id(m, module_id=0):
    ...     m.custom_id = module_id
    ...     return m
    >>> model = map(model, add_id, module_id=42)
- bridge.peft.walk_utils.walk(
    module: bridge.peft.walk_utils._TModule,
    func: bridge.peft.walk_utils.ModuleFunc,
    leaf_only: bool = False,
    **kwargs,
)
Recursively apply a function to a module or collection.
This function is similar to map, but it also applies the function recursively to all child modules. This is useful for transformations that need to consider the module hierarchy.
- Parameters:
module – The module or collection to recursively apply to.
func – The function to apply.
leaf_only – If True, only apply to modules without parameters. Defaults to False.
**kwargs – Additional kwargs to pass to the function.
- Returns:
The transformed module or collection.
Examples

    >>> import torch.nn as nn
    >>> from megatron.bridge.peft.walk_utils import walk
    >>> # Example: Freezing all parameters in a model
    >>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
    >>> def freeze_params(m):
    ...     for param in m.parameters(recurse=False):
    ...         param.requires_grad = False
    ...     return m
    >>> frozen_model = walk(model, freeze_params)
- bridge.peft.walk_utils.forall(
    module: torch.nn.Module,
    func: bridge.peft.walk_utils.ModulePredicate,
    recurse: bool = False,
)
Checks if a predicate holds for all modules in a given module or its children, optionally recursively.
This function applies a predicate to the modules and reports whether every one of them satisfies it. If recurse is True, it checks all child modules recursively.
- Parameters:
module (nn.Module) – The root module to check.
func (ModulePredicate) – A predicate function that takes a module as input and returns a boolean or an object that can be evaluated as a boolean.
recurse (bool) – If True, applies the predicate recursively to all child modules. Defaults to False.
- Returns:
True if all modules satisfy the predicate, False otherwise.
- Return type:
bool
Examples

    >>> import torch.nn as nn
    >>> from megatron.bridge.peft.walk_utils import forall
    >>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
    >>> predicate = lambda m: isinstance(m, (nn.Linear, nn.Sequential, nn.ReLU))
    >>> print(forall(model, predicate, recurse=True))
    True
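The difference between a shallow and a recursive predicate check can be sketched without PyTorch. The snippet below is an illustration under stated assumptions, not the library's implementation: `ToyModule` is a hypothetical stand-in that exposes a `children()` method like `nn.Module`, and `forall_sketch` assumes that the non-recursive mode checks the root and its direct children only.

```python
class ToyModule:
    """Hypothetical stand-in exposing a children() method like nn.Module."""

    def __init__(self, kind, *children):
        self.kind = kind
        self._children = list(children)

    def children(self):
        return iter(self._children)


def forall_sketch(module, predicate, recurse=False):
    """Return True if predicate holds for module and its (optionally nested) children."""
    if not bool(predicate(module)):
        return False
    for child in module.children():
        if recurse:
            # Descend into the whole subtree rooted at this child.
            if not forall_sketch(child, predicate, recurse=True):
                return False
        elif not bool(predicate(child)):
            # Assumption: the shallow mode inspects direct children only.
            return False
    return True


model = ToyModule("seq", ToyModule("linear"), ToyModule("seq", ToyModule("dropout")))


def allowed(m):
    return m.kind in {"seq", "linear"}


shallow = forall_sketch(model, allowed)             # nested "dropout" is never visited
deep = forall_sketch(model, allowed, recurse=True)  # nested "dropout" fails the check
```

Because the predicate result is passed through `bool()`, any object that defines `__bool__` can be returned from the predicate.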
- bridge.peft.walk_utils._map_module(
    module: bridge.peft.walk_utils._TModule,
    func: bridge.peft.walk_utils.ModuleFunc,
    recurse=False,
    leaf_only=False,
    transformed_modules=None,
    **kwargs,
)
Applies a transformation function to a module and optionally to its child modules.
- Parameters:
module (nn.Module) – The module to which the function will be applied.
func (ModuleFunc) – The function that will be applied to the module.
recurse (bool, optional) – Whether to apply the function recursively to child modules.
leaf_only (bool, optional) – Whether to apply the function only to modules without parameters.
transformed_modules (set, optional) – A set to keep track of modules that have already been transformed.
**kwargs (dict) – Additional keyword arguments that will be passed to the transformation function.
- Returns:
The transformed module.
- Return type:
nn.Module
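The role of a set such as transformed_modules can be shown with a small dependency-free sketch. Everything here is hypothetical and written for illustration: `ToyModule` stands in for `nn.Module`, and `walk_once` assumes that already-seen objects are identified by `id()`. Unlike the documented function, this sketch mutates modules in place and does not replace children with the return value of `func`.

```python
class ToyModule:
    """Hypothetical stand-in exposing a children() method like nn.Module."""

    def __init__(self, name, *children):
        self.name = name
        self._children = list(children)
        self.visits = 0

    def children(self):
        return iter(self._children)


def walk_once(module, func, visited=None):
    """Apply func to module and its descendants, skipping objects already seen."""
    if visited is None:
        visited = set()
    if id(module) in visited:
        # A module reachable through several parents is handled only once.
        return module
    visited.add(id(module))
    func(module)
    for child in module.children():
        walk_once(child, func, visited)
    return module


def count_visit(m):
    m.visits += 1


shared = ToyModule("shared")
root = ToyModule("root", ToyModule("a", shared), ToyModule("b", shared))
walk_once(root, count_visit)
```

Without the visited set, `shared` would be processed twice, which matters for transformations that are not idempotent, such as wrapping a layer in an adapter.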
- bridge.peft.walk_utils._map_module_list(
    module_list: bridge.peft.walk_utils._TModule,
    func: bridge.peft.walk_utils.ModuleFunc,
    recurse=False,
    leaf_only=False,
    transformed_modules=None,
    **kwargs,
)
Apply a transformation function to a list of modules.
- bridge.peft.walk_utils._map_module_dict(
    module_dict: bridge.peft.walk_utils._TModule,
    func: bridge.peft.walk_utils.ModuleFunc,
    recurse: bool = False,
    leaf_only: bool = False,
    transformed_modules=None,
    **kwargs,
)
Applies a transformation function to a ModuleDict of modules.
- Parameters:
module_dict (nn.ModuleDict) – The ModuleDict of modules to which the function will be applied.
func (ModuleFunc) – The function that will be applied to the modules.
recurse (bool, optional) – Whether to apply the function recursively to child modules.
leaf_only (bool, optional) – Whether to apply the function only to modules without parameters.
**kwargs (dict) – Additional keyword arguments that will be passed to the transformation function.
- Returns:
The ModuleDict of transformed modules.
- Return type:
nn.ModuleDict
- bridge.peft.walk_utils._create_list_wrapper(module_list, to_add)
Create a wrapper for a list of modules, preserving the original type.
- bridge.peft.walk_utils._get_func_kwargs(func, **kwargs)
Extract kwargs that match the function signature.
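One plausible way to extract the kwargs that match a function signature is to inspect the callable with `inspect.signature`. The sketch below is an assumption about the general technique, not a copy of the library code; `filter_kwargs_sketch` is a hypothetical name, and `add_id` mirrors the callback used in the map example.

```python
import inspect


def filter_kwargs_sketch(func, **kwargs):
    """Keep only the kwargs whose names appear in func's signature (hypothetical helper)."""
    params = inspect.signature(func).parameters
    # Assumption: matching is done purely by parameter name.
    return {key: value for key, value in kwargs.items() if key in params}


def add_id(m, module_id=0):
    m["custom_id"] = module_id
    return m


# Extra bookkeeping arguments such as name or i are dropped before the call.
selected = filter_kwargs_sketch(add_id, module_id=42, name="layer0", i=3)
result = add_id({}, **selected)
```

Filtering in this way lets a traversal offer optional context (for example a name or an index) to callbacks without forcing every callback to accept it.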