bridge.models.hf_pretrained.state#

Module Contents#

Classes#

StateDict

A state dict accessor that provides a unified interface for querying model checkpoints.

StateSource

Abstract base class for a source of model state.

DictStateSource

A state source backed by an in-memory Python dictionary.

SafeTensorsStateSource

A state source backed by a directory of .safetensors files.

API#

class bridge.models.hf_pretrained.state.StateDict(
source: Dict[str, torch.Tensor] | bridge.models.hf_pretrained.state.StateSource,
)#

Bases: collections.abc.Mapping[str, torch.Tensor]

A state dict accessor that provides a unified interface for querying model checkpoints.

StateDict allows for efficient and flexible access to tensor data from various sources, such as in-memory dictionaries or directories of .safetensors files. A key feature is its ability to query and load only the required tensors without loading the entire checkpoint into memory, making it highly memory-efficient for large models.

It supports a flexible, pandas-like querying interface that allows for accessing tensors by exact name, a list of names, glob patterns, or regular expressions. This makes it easy to inspect and manipulate model checkpoints.

.. rubric:: Examples

Set up an example StateDict from an in-memory dictionary:

>>> import torch
>>> import re
>>> d = {
...     "model.layer.0.weight": torch.randn(10, 10),
...     "model.layer.0.bias": torch.randn(10),
...     "model.layer.1.weight": torch.randn(10, 10),
...     "model.layer.1.bias": torch.randn(10),
... }
>>> state = StateDict(d)

1. Access a single tensor by exact key

>>> state["model.layer.0.weight"].shape
torch.Size([10, 10])

2. Access multiple tensors with a list of strings

>>> list(state[["model.layer.0.weight", "model.layer.1.weight"]].keys())
['model.layer.0.weight', 'model.layer.1.weight']

3. Access with a glob pattern

>>> sorted(list(state.glob("model.layer.*.bias").keys()))
['model.layer.0.bias', 'model.layer.1.bias']

4. Access with a compiled regex pattern

>>> regex = re.compile(r"model\.layer\.0\..*")
>>> sorted(list(state[regex].keys()))
['model.layer.0.bias', 'model.layer.0.weight']

The same querying flexibility applies to checkpoints on disk. The following is a conceptual example of using StateDict with a SafeTensorsStateSource to query a sharded checkpoint without loading all of it into memory.

.. code-block:: python

    # Assume SafeTensorsStateSource is available
    # from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource

    # Imagine a directory 'my_model_checkpoint/' with sharded weights.
    state_from_disk = StateDict(SafeTensorsStateSource('my_model_checkpoint/'))

    # You can query it just like the in-memory dictionary. Only the required
    # tensors (e.g., all weight tensors) will be loaded from disk.
    weights = state_from_disk.glob("model.layer.*.weight")

Initialization

Initializes the StateDict query accessor.

Parameters:

source – The source of the tensor data. This can be a standard Python dictionary mapping tensor names to torch.Tensor objects, or an instance of a StateSource subclass (e.g., SafeTensorsStateSource) that reads tensors from disk on demand instead of holding them all in memory.

source: StateSource#

None

_get_all_keys() List[str]#

Get all available tensor keys from the underlying source.

_load_tensors(
keys_to_load: List[str],
) Dict[str, torch.Tensor]#

Load specified tensors from the underlying source.

_match_keys(
pattern: Union[str, Pattern],
) List[str]#

Match keys against a glob pattern or regex.

__getitem__(
key: Union[str, List[str], Pattern],
) Union[torch.Tensor, Dict[str, torch.Tensor]]#

Accesses state dict entries using various key types.

This method allows for retrieving tensors using:

  • A single string for an exact key match.

  • A list of strings for multiple exact key matches.

  • A string with glob-style wildcards (*, ?, []).

  • A compiled regular expression object.

Parameters:

key – A single key string, a list of keys, a glob pattern string, or a compiled regular expression.

Returns:

  • A single torch.Tensor if key is a string that matches exactly one key and does not contain wildcards.

  • A Dict[str, torch.Tensor] for all other cases (list of keys, glob pattern, or regex), mapping the matched keys to their corresponding tensors.

Raises:

KeyError – If the key (or any key in a list) is not found, or if a pattern matches no keys.

.. rubric:: Examples

>>> d = {
...     "model.embed_tokens.weight": torch.randn(10, 1),
...     "model.layers.0.mlp.weight": torch.randn(10, 1),
...     "model.layers.0.self_attn.q_proj.weight": torch.randn(10, 1),
...     "lm_head.weight": torch.randn(10, 1),
... }
>>> state = StateDict(d)

Exact match (returns a single tensor)

>>> tensor = state["model.embed_tokens.weight"]
>>> isinstance(tensor, torch.Tensor)
True

List of keys (returns a dict of tensors)

>>> tensors = state[["model.embed_tokens.weight", "lm_head.weight"]]
>>> sorted(tensors.keys())
['lm_head.weight', 'model.embed_tokens.weight']

Glob pattern (returns a dict of tensors)

>>> layer_0_weights = state["model.layers.0.*.weight"]
>>> sorted(layer_0_weights.keys())
['model.layers.0.mlp.weight', 'model.layers.0.self_attn.q_proj.weight']

Regex pattern (returns a dict of tensors)

>>> import re
>>> attn_weights = state[re.compile(r".*self_attn.*")]
>>> list(attn_weights.keys())
['model.layers.0.self_attn.q_proj.weight']
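The return and error rules documented for `__getitem__` can be illustrated with a small stand-alone sketch. It is not the library's implementation: `lookup` is a hypothetical helper, plain strings stand in for tensors, and using `fnmatch` for globs and `Pattern.search` for regexes are assumptions made for illustration.

```python
import fnmatch
import re
from typing import Dict, List, Union

GLOB_CHARS = set("*?[")


def lookup(store: Dict[str, str], key: Union[str, List[str], re.Pattern]):
    """Hypothetical helper mirroring the documented return and error rules."""
    if isinstance(key, str) and not GLOB_CHARS & set(key):
        return store[key]  # exact key: single value, KeyError if missing
    if isinstance(key, list):
        return {k: store[k] for k in key}  # KeyError on the first missing key
    if isinstance(key, str):
        matched = [k for k in store if fnmatch.fnmatchcase(k, key)]
    else:
        # Assumption: a regex matches anywhere in the key (search, not fullmatch).
        matched = [k for k in store if key.search(k)]
    if not matched:
        raise KeyError(f"no keys match {key!r}")
    return {k: store[k] for k in matched}


store = {
    "model.embed_tokens.weight": "E",
    "model.layers.0.mlp.weight": "M",
    "model.layers.0.self_attn.q_proj.weight": "Q",
    "lm_head.weight": "H",
}
exact = lookup(store, "lm_head.weight")              # single value
globbed = lookup(store, "model.layers.0.*.weight")   # dict of matches
by_regex = lookup(store, re.compile(r"self_attn"))   # dict of matches
```

Note that a string is treated as a glob only when it contains a wildcard character, which is why an exact key returns a bare value while every other query form returns a dictionary.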

regex(pattern: str) Dict[str, torch.Tensor]#

Queries the state dict with a regular expression pattern.

This is a convenience method that compiles the pattern string and uses it to retrieve all matching tensors.

Parameters:

pattern – The regular expression string to match against tensor keys.

Returns:

A dictionary mapping matching tensor names to their torch.Tensor objects.

.. rubric:: Examples

>>> d = {
...     "model.layers.0.self_attn.weight": torch.randn(1, 1),
...     "model.layers.1.self_attn.weight": torch.randn(1, 1),
...     "model.layers.1.mlp.weight": torch.randn(1, 1),
... }
>>> state = StateDict(d)

>>> attention_weights = state.regex(r"model\.layers\.\d+\.self_attn.*")
>>> sorted(attention_weights.keys())
['model.layers.0.self_attn.weight', 'model.layers.1.self_attn.weight']

glob(pattern: str) Dict[str, torch.Tensor]#

Queries the state dict with a glob pattern.

This is a convenience method for pattern matching using Unix shell-style wildcards.

Parameters:

pattern – The glob pattern string to match against tensor keys.

Returns:

A dictionary mapping matching tensor names to their torch.Tensor objects.

.. rubric:: Examples

>>> d = {
...     "model.layers.0.mlp.weight": torch.randn(1, 1),
...     "model.layers.0.mlp.bias": torch.randn(1, 1),
...     "model.layers.1.mlp.weight": torch.randn(1, 1),
... }
>>> state = StateDict(d)

Get all mlp weights and biases from the first layer

>>> layer_0_mlp = state.glob("model.layers.0.mlp.*")
>>> sorted(layer_0_mlp.keys())
['model.layers.0.mlp.bias', 'model.layers.0.mlp.weight']

__call__() Dict[str, torch.Tensor]#

Loads and returns the entire state dict as a dictionary.

.. note::

    This method loads all tensors from the source into memory. For large models, this can be memory-intensive. Prefer using pattern-based or single-key lookups for more efficient access if you only need a subset of the state dict.

Returns:

A dictionary containing all tensor names and their corresponding torch.Tensor objects.

keys() List[str]#

Get all state dict keys.

items() List[tuple]#

Get all state dict items.

__contains__(key: str) bool#

Check if a key exists in the state dict.

__repr__() str#

String representation.

get(key: str, default=None) Optional[torch.Tensor]#

Gets a tensor from the state dict. Returns default if the key is not found. Note: This method is for single key lookup and does not support patterns.

__iter__() Iterable[str]#

Iterate over state dict keys.

__len__() int#

Get number of entries in the state dict.

has_glob(pattern: str) bool#

Checks whether any tensor key matches the given glob pattern. The check is forwarded to the underlying StateSource, which may have an optimized implementation that avoids iterating over all keys.

Parameters:

pattern – The glob pattern to match against tensor keys.

Returns:

True if a matching key is found, False otherwise.
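A minimal sketch of the early-exit idea behind `has_glob`, assuming keys can be produced lazily. The free function `has_glob` and the `key_stream` generator below are hypothetical stand-ins, not the library's code; they only show that the scan can stop at the first match.

```python
import fnmatch
from typing import Iterable, Iterator, List


def has_glob(keys: Iterable[str], pattern: str) -> bool:
    """Return True as soon as one key matches; later keys are never read."""
    return any(fnmatch.fnmatchcase(k, pattern) for k in keys)


def key_stream(seen: List[str]) -> Iterator[str]:
    """Hypothetical lazy key source that records which keys were actually read."""
    for k in ("model.layers.0.mlp.weight", "model.layers.0.mlp.bias", "lm_head.weight"):
        seen.append(k)
        yield k


seen: List[str] = []
found = has_glob(key_stream(seen), "model.layers.*.weight")
```

Because `any` stops consuming the generator at the first truthy result, `seen` holds only the first key after the call.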

class bridge.models.hf_pretrained.state.StateSource#

Bases: abc.ABC, collections.abc.Mapping[str, torch.Tensor]

Abstract base class for a source of model state.

This class defines a standard interface for StateDict to access tensor data, abstracting away the details of how and where the data is stored. Subclasses can implement loading from different storage backends, such as in-memory dictionaries or files on disk. This allows StateDict to handle various checkpoint formats in a uniform way.

abstractmethod get_all_keys() List[str]#

Returns a list of all available tensor keys in the source.

abstractmethod load_tensors(keys: List[str]) Dict[str, torch.Tensor]#

Loads the specified tensors from the source.

__getitem__(key: str) torch.Tensor#

Loads a single tensor by key.

__iter__() Iterable[str]#

Iterates over all tensor keys.

__len__() int#

Returns the total number of tensors in the source.

has_glob(pattern: str) bool#

Checks if any tensor key matches the given glob pattern. This default implementation is not efficient for all sources, as it may load all keys. Subclasses should override this method if a more performant implementation is available.
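To make the contract concrete, here is a hedged sketch of a custom backend. `SourceSketch` is a stand-in that imitates the documented interface (two abstract methods plus the Mapping protocol) and is not the library class; `LazySource` is a hypothetical subclass in which lists of floats stand in for tensors.

```python
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Dict, Iterator, List


class SourceSketch(ABC, Mapping):
    """Stand-in for the documented interface; not the library class."""

    @abstractmethod
    def get_all_keys(self) -> List[str]: ...

    @abstractmethod
    def load_tensors(self, keys: List[str]) -> Dict[str, object]: ...

    def __getitem__(self, key: str):
        return self.load_tensors([key])[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.get_all_keys())

    def __len__(self) -> int:
        return len(self.get_all_keys())


class LazySource(SourceSketch):
    """Hypothetical backend that materializes a value only when requested."""

    def __init__(self, sizes: Dict[str, int]):
        self._sizes = sizes
        self.loaded: List[str] = []  # records every key that was materialized

    def get_all_keys(self) -> List[str]:
        return list(self._sizes)

    def load_tensors(self, keys: List[str]) -> Dict[str, object]:
        self.loaded.extend(keys)
        return {k: [0.0] * self._sizes[k] for k in keys}


src = LazySource({"a.weight": 4, "a.bias": 2})
bias = src["a.bias"]  # only this entry is materialized
```

A subclass only has to supply the two abstract methods; single-key access, iteration, and length follow from them.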

class bridge.models.hf_pretrained.state.DictStateSource(state_dict: Dict[str, torch.Tensor])#

Bases: bridge.models.hf_pretrained.state.StateSource

A state source backed by an in-memory Python dictionary.

This is the simplest StateSource implementation. It’s used when the entire model state dict is already loaded into a dictionary in memory.

Parameters:

state_dict – A dictionary mapping tensor names (str) to torch.Tensor objects.

Initialization

get_all_keys() List[str]#

load_tensors(keys: List[str]) Dict[str, torch.Tensor]#

class bridge.models.hf_pretrained.state.SafeTensorsStateSource(path: Union[str, pathlib.Path])#

Bases: bridge.models.hf_pretrained.state.StateSource

A state source backed by a directory of .safetensors files.

This source is designed for efficiently loading tensors from checkpoints saved in the Safetensors format, which is common for large models that are often “sharded” into multiple files.

It can handle two common scenarios:

  1. A directory containing multiple .safetensors files.

  2. A directory containing a model.safetensors.index.json file, which maps tensor names to the specific .safetensors file they reside in. This is the standard format used by Hugging Face Transformers.

Using this source allows StateDict to query for tensor keys and load only the necessary files and tensors from disk, avoiding high memory usage.

Parameters:

path – The path to the directory containing the .safetensors files and/or the index file. Can also be a Hugging Face Hub model ID.

Initialization

property path: pathlib.Path#

The local path to the checkpoint files. If the initial path is a Hugging Face Hub model ID, this property will handle downloading the necessary files and return the local cache path.

property key_to_filename_map: Dict[str, str]#

Provides a mapping from tensor keys to the safetensor filename they are stored in.

This map is constructed either from model.safetensors.index.json if it exists, or by scanning all .safetensors files in the directory. The result is cached for efficiency.
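For the index-file case, the map corresponds to the `weight_map` entry of `model.safetensors.index.json`. The sketch below reads it with the standard library only; `read_index` is a hypothetical helper, the shard names are made up for the demonstration, and the fallback of scanning `.safetensors` headers is not covered.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


def read_index(checkpoint_dir: Path) -> Dict[str, str]:
    """Hypothetical helper: tensor name -> shard filename, taken from the index file."""
    index_path = checkpoint_dir / "model.safetensors.index.json"
    with index_path.open() as f:
        return json.load(f)["weight_map"]


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp)
    # Write a tiny index in the Hugging Face layout (illustrative values).
    index = {
        "metadata": {"total_size": 1024},
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors",
        },
    }
    (ckpt / "model.safetensors.index.json").write_text(json.dumps(index))

    key_to_file = read_index(ckpt)
    shards = sorted(set(key_to_file.values()))  # distinct shard files referenced
```

With such a map in hand, a loader can open only the shard files that contain the requested keys.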

static _resolve_path(
model_name_or_path: Union[str, pathlib.Path],
) pathlib.Path#

Resolves a model name or path to a local directory. If the path is not a local directory, it is treated as a Hugging Face Hub model ID, and the corresponding files are downloaded.

get_all_keys() List[str]#

load_tensors(
keys_to_load: List[str],
) Dict[str, torch.Tensor]#

has_glob(pattern: str) bool#

Efficiently checks if any tensor key matches the given glob pattern.

This method avoids loading all tensor keys into memory at once. It scans the checkpoint index or file headers and returns as soon as a match is found.

Parameters:

pattern – The glob pattern to match against tensor keys.

Returns:

True if a matching key is found, False otherwise.

save_generator(
generator: Iterable[Tuple[str, torch.Tensor]],
output_path: Union[str, pathlib.Path],
strict: bool = True,
)#

Saves tensors from a generator to .safetensors files, preserving the original sharding structure in a memory-efficient, streaming fashion.

This method reads the sharding information (which tensor belongs to which file) from the source checkpoint. It then consumes a generator of tensors, buffering them in memory only until a complete file shard can be written to disk. This approach minimizes peak memory usage compared to collecting all tensors first.

If the original checkpoint had a model.safetensors.index.json file, a new one will be created for the saved tensors.

Parameters:
  • generator – An iterable of (tensor_name, tensor) tuples.

  • output_path – The directory where the new safetensor files and index will be saved.

  • strict – If True (default), raises a KeyError if the generator yields a tensor name not found in the original model’s sharding structure. If False, it prints a warning and skips the tensor.
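The buffering strategy described above can be sketched without any tensor library. `stream_by_shard` and `write_shard` are hypothetical names, integers stand in for tensors, and flushing incomplete shards at the end of the stream is an assumption rather than documented behavior.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Tuple


def stream_by_shard(
    generator: Iterable[Tuple[str, object]],
    key_to_file: Dict[str, str],
    write_shard: Callable[[str, Dict[str, object]], None],
    strict: bool = True,
) -> List[str]:
    """Buffer tensors per shard and write each shard once it is complete."""
    expected = defaultdict(set)
    for key, filename in key_to_file.items():
        expected[filename].add(key)

    buffers: Dict[str, Dict[str, object]] = defaultdict(dict)
    written: List[str] = []
    for name, tensor in generator:
        if name not in key_to_file:
            if strict:
                raise KeyError(name)
            print(f"warning: skipping unknown tensor {name!r}")
            continue
        filename = key_to_file[name]
        buffers[filename][name] = tensor
        if buffers[filename].keys() == expected[filename]:
            # Shard complete: write it and release its buffer immediately.
            write_shard(filename, buffers.pop(filename))
            written.append(filename)
    # Assumption: shards left incomplete at end of stream are still written.
    for filename, tensors in buffers.items():
        write_shard(filename, tensors)
        written.append(filename)
    return written


saved: Dict[str, Dict[str, object]] = {}


def write_shard(filename: str, tensors: Dict[str, object]) -> None:
    saved[filename] = tensors  # a real writer would serialize to disk here


key_to_file = {"a.w": "s1", "b.w": "s2", "a.b": "s1"}
order = stream_by_shard([("a.w", 1), ("b.w", 2), ("a.b", 3)], key_to_file, write_shard)
```

Peak memory in this scheme is bounded by the shards that are partially filled at any moment, not by the whole checkpoint, so a generator that yields tensors roughly in shard order keeps the buffers small.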

_get_key_to_filename_map() Optional[Dict[str, str]]#

static _cached_get_key_to_filename_map(
model_name_or_path: Union[str, pathlib.Path],
) Optional[Dict[str, str]]#

Static, cached method to get the key-to-filename map.