bridge.peft.walk_utils

Walking utilities for PyTorch module transformation.

This module provides utilities for recursively applying transformations to PyTorch modules, handling complex hierarchies including lists, dictionaries, and nested structures.

Examples

Basic module transformation:

>>> import torch.nn as nn
>>> from megatron.bridge.peft.walk_utils import walk
>>> def add_tag(module, name=None, **kwargs):
...     module.tag = f"transformed_{name}"
...     return module
>>>
>>> model = nn.Sequential(nn.Linear(10, 5), nn.ReLU())
>>> transformed = walk(model, add_tag)

Conditional transformation:

>>> def freeze_linear(module, **kwargs):
...     if isinstance(module, nn.Linear):
...         for param in module.parameters():
...             param.requires_grad = False
...     return module
>>>
>>> frozen_model = walk(model, freeze_linear)

Module Contents

Classes

HasBool

Protocol for objects that can be evaluated as boolean.

Functions

map

Applies a function to a PyTorch module or a collection of modules.

walk

Recursively apply a function to a module or collection.

forall

Checks if a predicate holds for all modules in a given module or its children, optionally recursively.

_map_module

Applies a transformation function to a module and optionally to its child modules.

_map_module_list

Apply a transformation function to a list of modules.

_map_module_dict

Applies a transformation function to a ModuleDict of modules.

_create_list_wrapper

Create a wrapper for a list of modules, preserving the original type.

_get_func_kwargs

Extract kwargs that match the function signature.

Data

_TModule

ModuleFunc

ModulePredicate

API

class bridge.peft.walk_utils.HasBool

Bases: typing.Protocol

Protocol for objects that can be evaluated as boolean.
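Here is a minimal sketch of how a protocol like this is declared and satisfied structurally. Any object whose class defines `__bool__` qualifies, without inheriting from the protocol. The `@runtime_checkable` decorator is an assumption added so the sketch can use `isinstance`; the documentation does not say whether `HasBool` itself is runtime-checkable. `AllFrozen` is a hypothetical predicate result.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable  # assumption: lets isinstance() check for the protocol
class HasBool(Protocol):
    """Structural type: anything whose class defines __bool__."""

    def __bool__(self) -> bool: ...


class AllFrozen:
    """Hypothetical predicate result that carries detail but still acts as a bool."""

    def __init__(self, unfrozen_names):
        self.unfrozen_names = list(unfrozen_names)

    def __bool__(self) -> bool:
        # Truthy only when no trainable parameters were found
        return not self.unfrozen_names


result = AllFrozen(["fc.weight"])
print(isinstance(result, HasBool))  # True: AllFrozen defines __bool__
print(bool(result))                 # False: one parameter is still trainable
```

A predicate that returns such an object lets a caller inspect why a check failed. A function like `forall` only needs the object's truth value.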

__bool__() -> bool

bridge.peft.walk_utils._TModule

'TypeVar(…)'

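bridge.peft.walk_utils.ModuleFunc

Type alias for the function passed to map and walk. It takes a module, plus optional keyword arguments, and returns the (possibly transformed) module.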

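bridge.peft.walk_utils.ModulePredicate

Type alias for the predicate passed to forall. It takes a module and returns a bool or an object that can be evaluated as a bool (see HasBool).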

bridge.peft.walk_utils.map(
module: bridge.peft.walk_utils._TModule,
func: bridge.peft.walk_utils.ModuleFunc,
leaf_only: bool = False,
**kwargs,
) -> bridge.peft.walk_utils._TModule

Applies a function to a PyTorch module or a collection of modules.

This function can be used to modify modules in place, for example by changing their attributes, applying normalization, or performing any other custom transformation. It supports individual modules, lists of modules, and dictionaries of modules. Unlike walk, it does not descend into child modules. If leaf_only is True, the function is applied only to modules that have no parameters.

Parameters:
  • module – The module or collection of modules to which the function will be applied.

  • func – A callable that takes a module (and optionally additional keyword arguments) and returns a transformed module. The signature should be func(module, **kwargs).

  • leaf_only – If True, the function will only be applied to modules that do not have any parameters. Defaults to False.

  • **kwargs – Additional keyword arguments that will be passed to func.

Returns:

The transformed module or collection of modules.

Examples

>>> import torch.nn as nn
>>> from megatron.bridge.peft.walk_utils import map
>>>
>>> # Example: Adding a custom attribute to all modules
>>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
>>> def add_id(m, module_id=0):
...     m.custom_id = module_id
...     return m
>>> model = map(model, add_id, module_id=42)
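The dispatch described above can be sketched without PyTorch. The sketch shows how a single module, a list, and a dictionary are each handled. It is a hypothetical illustration, not the library's code. `Node` stands in for `nn.Module`. Recognizing dictionaries via `collections.abc.Mapping` and lists via `Iterable` is an assumption; a real `nn.ModuleDict` is not a `Mapping`. Nothing here recurses into children, matching the non-recursive behavior of `map`.

```python
from collections.abc import Iterable, Mapping


class Node:
    """Hypothetical stand-in for nn.Module."""

    def __init__(self, name):
        self.name = name


def map_sketch(obj, func, **kwargs):
    """Apply func to one node, to each list item, or to each dict value (no recursion)."""
    if isinstance(obj, Mapping):
        # Dict-like: transform every value and keep its key
        return {key: func(value, **kwargs) for key, value in obj.items()}
    if isinstance(obj, Iterable):
        # List-like: transform every item and rebuild the same container type
        return type(obj)(func(item, **kwargs) for item in obj)
    # Single module: transform it directly
    return func(obj, **kwargs)


def add_id(m, module_id=0):
    m.custom_id = module_id
    return m


single = map_sketch(Node("fc"), add_id, module_id=42)
many = map_sketch([Node("a"), Node("b")], add_id, module_id=7)
named = map_sketch({"enc": Node("enc")}, add_id, module_id=1)
print(single.custom_id, [n.custom_id for n in many], named["enc"].custom_id)  # 42 [7, 7] 1
```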

bridge.peft.walk_utils.walk(
module: bridge.peft.walk_utils._TModule,
func: bridge.peft.walk_utils.ModuleFunc,
leaf_only: bool = False,
**kwargs,
) -> bridge.peft.walk_utils._TModule

Recursively apply a function to a module or collection.

This function is similar to map, but it also applies the function recursively to all child modules. Use it when a transformation must reach every submodule in the hierarchy rather than only the top-level module or collection.

Parameters:
  • module – The module or collection to recursively apply to.

  • func – The function to apply.

  • leaf_only – If True, only apply to modules without parameters. Defaults to False.

  • **kwargs – Additional kwargs to pass to the function.

Returns:

The transformed module or collection.

Examples

>>> import torch.nn as nn
>>> from megatron.bridge.peft.walk_utils import walk
>>>
>>> # Example: Freezing all parameters in a model
>>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
>>> def freeze_params(m):
...     for param in m.parameters(recurse=False):
...         param.requires_grad = False
...     return m
>>> frozen_model = walk(model, freeze_params)
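The recursion that separates `walk` from `map` can be sketched without PyTorch. `Node` and `walk_sketch` are hypothetical stand-ins for `nn.Module` and `walk`. Two behaviors are choices made for this sketch rather than documented behavior: children are visited before their parent, and whatever `func` returns is re-attached under the child's original name.

```python
class Node:
    """Hypothetical stand-in for nn.Module: just a name and named children."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = {child.name: child for child in children}

    def named_children(self):
        # Return a copy so callers can reassign children while iterating
        return list(self.children.items())


def walk_sketch(node, func):
    """Apply func to every node in the tree, children before their parent."""
    for name, child in node.named_children():
        # Re-attach the (possibly replaced) child under the same name
        node.children[name] = walk_sketch(child, func)
    return func(node)


visited = []


def record(node):
    visited.append(node.name)
    return node


def swap_activation(node):
    # Returning a different object replaces the node inside its parent
    return Node("gelu") if node.name == "act" else node


model = Node("seq", [Node("fc1"), Node("act"), Node("fc2")])
walk_sketch(model, record)
print(visited)  # ['fc1', 'act', 'fc2', 'seq']
model = walk_sketch(model, swap_activation)
print(model.children["act"].name)  # gelu
```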

bridge.peft.walk_utils.forall(
module: torch.nn.Module,
func: bridge.peft.walk_utils.ModulePredicate,
recurse: bool = False,
) -> bool

Checks if a predicate holds for all modules in a given module or its children, optionally recursively.

This function applies a predicate to the modules being checked and returns True only if every one of them satisfies it. If recurse is True, all child modules are checked recursively.

Parameters:
  • module (nn.Module) – The root module to check.

  • func (ModulePredicate) – A predicate function that takes a module as input and returns a boolean or an object that can be evaluated as a boolean.

  • recurse (bool) – If True, applies the predicate recursively to all child modules. Defaults to False.

Returns:

True if all modules satisfy the predicate, False otherwise.

Return type:

bool

Examples

>>> import torch.nn as nn
>>> from megatron.bridge.peft.walk_utils import forall
>>>
>>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
>>> predicate = lambda m: isinstance(m, (nn.Linear, nn.Sequential, nn.ReLU))
>>> print(forall(model, predicate, recurse=True))
True
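Here is a torch-free sketch of the check. It assumes that with `recurse=False` only the root module is tested. With `recurse=True`, the root and every descendant are tested, in the order `nn.Module.modules()` would yield them. `Node` and `forall_sketch` are hypothetical. `all()` stops at the first module that fails the predicate.

```python
class Node:
    """Hypothetical stand-in for nn.Module with a type tag and children."""

    def __init__(self, kind, children=()):
        self.kind = kind
        self.children = list(children)

    def modules(self):
        # Like nn.Module.modules(): yields self first, then all descendants
        yield self
        for child in self.children:
            yield from child.modules()


def forall_sketch(module, func, recurse=False):
    """Return True if func(m) is truthy for the root (and descendants if recurse)."""
    if not recurse:
        return bool(func(module))
    # all() short-circuits on the first module that fails the predicate
    return all(bool(func(m)) for m in module.modules())


def is_known_layer(m):
    return m.kind in {"Linear", "ReLU", "Sequential"}


def is_linear(m):
    return m.kind == "Linear"


model = Node("Sequential", [Node("Linear"), Node("ReLU"), Node("Linear")])
print(forall_sketch(model, is_known_layer, recurse=True))  # True
print(forall_sketch(model, is_linear, recurse=True))       # False
print(forall_sketch(model.children[0], is_linear))         # True
```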

bridge.peft.walk_utils._map_module(
module: bridge.peft.walk_utils._TModule,
func: bridge.peft.walk_utils.ModuleFunc,
recurse=False,
leaf_only=False,
transformed_modules=None,
**kwargs,
) -> bridge.peft.walk_utils._TModule

Applies a transformation function to a module and optionally to its child modules.

Parameters:
  • module (nn.Module) – The module to which the function will be applied.

  • func (ModuleFunc) – The function that will be applied to the module.

  • recurse (bool, optional) – Whether to apply the function recursively to child modules.

  • leaf_only (bool, optional) – Whether to apply the function only to modules without parameters.

  • transformed_modules (set, optional) – A set to keep track of modules that have already been transformed.

  • **kwargs (dict) – Additional keyword arguments that will be passed to the transformation function.

Returns:

The transformed module.

Return type:

nn.Module
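A set like `transformed_modules` is typically used to transform a submodule only once, even when it is reachable through more than one parent (for example, a layer shared between two blocks). Here is a minimal sketch of that bookkeeping. It assumes modules are tracked by `id()` and that `func` runs on a parent before its children. `Node` and `map_module_sketch` are hypothetical stand-ins for `nn.Module` and `_map_module`.

```python
class Node:
    """Hypothetical stand-in for nn.Module: a name and named children."""

    def __init__(self, name, children=None):
        self.name = name
        self.children = dict(children or {})


def map_module_sketch(node, func, recurse=False, transformed=None):
    """Apply func to node (and its children if recurse), at most once per object."""
    if transformed is None:
        transformed = set()
    if id(node) in transformed:
        # Shared submodule already handled through another parent
        return node
    transformed.add(id(node))
    new_node = func(node)
    if recurse:
        for name, child in list(new_node.children.items()):
            new_node.children[name] = map_module_sketch(child, func, recurse, transformed)
    return new_node


calls = []


def count(node):
    calls.append(node.name)
    return node


shared = Node("shared_proj")
model = Node("root", {"enc": Node("enc", {"proj": shared}), "dec": Node("dec", {"proj": shared})})
map_module_sketch(model, count, recurse=True)
print(calls)  # ['root', 'enc', 'shared_proj', 'dec']
```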

bridge.peft.walk_utils._map_module_list(
module_list: bridge.peft.walk_utils._TModule,
func: bridge.peft.walk_utils.ModuleFunc,
recurse=False,
leaf_only=False,
transformed_modules=None,
**kwargs,
) -> bridge.peft.walk_utils._TModule

Apply a transformation function to a list of modules.
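Here is a sketch of the list case. `func` is applied to each item, and the results are collected into a container of the same type, so a `list` subclass stays that subclass. `Tagged` and `LayerList` are hypothetical, and passing each item's position as `i` is a choice made for this sketch. Rebuilding with `type(module_list)(...)` assumes the container's constructor accepts an iterable, as `list`, `tuple`, and `nn.ModuleList` do.

```python
class Tagged:
    """Hypothetical stand-in for an nn.Module."""

    def __init__(self, name):
        self.name = name
        self.tag = None


class LayerList(list):
    """Hypothetical list subclass standing in for a container like nn.ModuleList."""


def map_module_list_sketch(module_list, func, **kwargs):
    """Apply func to each item, passing its index, and rebuild the same container type."""
    transformed = [func(module, i=i, **kwargs) for i, module in enumerate(module_list)]
    return type(module_list)(transformed)


def tag(module, i=0, prefix="layer"):
    module.tag = f"{prefix}_{i}"
    return module


layers = map_module_list_sketch(LayerList([Tagged("a"), Tagged("b")]), tag, prefix="block")
print(type(layers).__name__, [m.tag for m in layers])  # LayerList ['block_0', 'block_1']
```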

bridge.peft.walk_utils._map_module_dict(
module_dict: bridge.peft.walk_utils._TModule,
func: bridge.peft.walk_utils.ModuleFunc,
recurse: bool = False,
leaf_only: bool = False,
transformed_modules=None,
**kwargs,
) -> bridge.peft.walk_utils._TModule

Applies a transformation function to a ModuleDict of modules.

Parameters:
  • module_dict (nn.ModuleDict) – The ModuleDict of modules to which the function will be applied.

  • func (ModuleFunc) – The function that will be applied to the modules.

  • recurse (bool, optional) – Whether to apply the function recursively to child modules.

  • leaf_only (bool, optional) – Whether to apply the function only to modules without parameters.

  • transformed_modules (set, optional) – A set to keep track of modules that have already been transformed.

  • **kwargs (dict) – Additional keyword arguments that will be passed to the transformation function.

Returns:

The ModuleDict of transformed modules.

Return type:

nn.ModuleDict
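Here is a sketch of the dictionary case. Each value is transformed and stored back under its original key, so key order is preserved. A plain `dict` stands in for `nn.ModuleDict`. Forwarding the key to `func` as `name` is an assumption of this sketch, chosen to match the `name` parameter of the module-level `add_tag` example.

```python
class Tagged:
    """Hypothetical stand-in for an nn.Module."""

    def __init__(self):
        self.tag = None


def map_module_dict_sketch(module_dict, func, **kwargs):
    """Transform each value and store it back under the same key, in place."""
    for name, module in list(module_dict.items()):
        module_dict[name] = func(module, name=name, **kwargs)
    return module_dict


def add_tag(module, name=None, **kwargs):
    module.tag = f"transformed_{name}"
    return module


heads = map_module_dict_sketch({"cls": Tagged(), "reg": Tagged()}, add_tag)
print({key: m.tag for key, m in heads.items()})
# {'cls': 'transformed_cls', 'reg': 'transformed_reg'}
```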

bridge.peft.walk_utils._create_list_wrapper(module_list, to_add)

Create a wrapper for a list of modules, preserving the original type.

bridge.peft.walk_utils._get_func_kwargs(func, **kwargs)

Extract kwargs that match the function signature.
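Here is a sketch of this kind of filtering using the standard `inspect` module. It lets a caller forward one set of keyword arguments to functions with different signatures, such as `freeze_params(m)` and `add_id(m, module_id=0)` from the examples above. Treating a function that declares `**kwargs` as accepting everything is an assumption of this sketch, not documented behavior. `get_func_kwargs_sketch` is a hypothetical name.

```python
import inspect


def get_func_kwargs_sketch(func, **kwargs):
    """Keep only the keyword arguments that func can accept by name."""
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        # func declares **kwargs, so it can take every keyword argument
        return dict(kwargs)
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {key: value for key, value in kwargs.items() if key in accepted}


def freeze_params(m):
    return m


def add_id(m, module_id=0):
    return m


def add_tag(module, name=None, **kwargs):
    return module


print(get_func_kwargs_sketch(freeze_params, name="fc1", module_id=42))  # {}
print(get_func_kwargs_sketch(add_id, name="fc1", module_id=42))         # {'module_id': 42}
print(get_func_kwargs_sketch(add_tag, name="fc1", module_id=42))        # {'name': 'fc1', 'module_id': 42}
```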