bridge.models.hf_pretrained.base#

Module Contents#

Classes#

PreTrainedBase

Abstract base class for all pretrained models.

API#

class bridge.models.hf_pretrained.base.PreTrainedBase(**kwargs)#

Bases: abc.ABC

Abstract base class for all pretrained models.

This class provides a generic mechanism for managing model artifacts (e.g., config, tokenizer) with lazy loading. Subclasses that are decorated with @dataclass can define artifacts as fields with metadata specifying a loader method. The model itself is handled via a dedicated property that relies on the abstract _load_model method.

Example

    @dataclass
    class MyModel(PreTrainedBase):
        config: AutoConfig = field(
            init=False, metadata=artifact(loader="_load_config")
        )

        def _load_model(self) -> "PreTrainedModel":
            # Implementation for the loading logic
            ...

Initialization

model_name_or_path: Union[str, pathlib.Path]#

None

ARTIFACTS: ClassVar[List[str]]#

[]

OPTIONAL_ARTIFACTS: ClassVar[List[str]]#

[]

get_artifacts() → Dict[str, str]#

Get the artifacts dictionary mapping artifact names to their attribute names.
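
A hedged sketch of the returned mapping, assuming a subclass that declares a single config artifact as in the class example above; the exact keys and attribute names depend on the subclass's field metadata:

    # Hypothetical subclass instance; the printed mapping is illustrative only.
    artifacts = my_model.get_artifacts()
    print(artifacts)  # e.g. {"config": "config"}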

_copy_custom_modeling_files(
source_path: Union[str, pathlib.Path],
target_path: Union[str, pathlib.Path],
) → None#

Copy custom modeling files from source to target directory.

This preserves custom modeling files that were used during model loading with trust_remote_code=True, ensuring the saved model can be loaded properly.

Parameters:
  • source_path – Source directory containing custom modeling files

  • target_path – Target directory to copy files to

save_artifacts(
save_directory: Union[str, pathlib.Path],
original_source_path: Optional[Union[str, pathlib.Path]] = None,
)#

Saves all loaded, generic artifacts that have a save_pretrained method to the specified directory. Note: This does not save the model attribute.

If the model was loaded with trust_remote_code=True, this method will also attempt to preserve any custom modeling files to ensure the saved model can be loaded properly.
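
A minimal usage sketch, assuming lm is an instance of a concrete subclass (see the hypothetical MyCausalLM sketch under the abstract methods below) and that the target directory is a placeholder path:

    # Saves config/tokenizer-style artifacts only; the model weights are
    # handled separately via the `model` property and the caller's own logic.
    lm.save_artifacts(
        "./exported_artifacts",
        original_source_path=lm.model_name_or_path,
    )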

abstractmethod _load_model() → transformers.PreTrainedModel#

Subclasses must implement this to load the main model.

abstractmethod _load_config() → transformers.AutoConfig#

Subclasses must implement this to load the model config.
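
For illustration, a minimal sketch of a concrete subclass that satisfies both abstract hooks using the Hugging Face auto classes; the class name MyCausalLM and the import path are assumptions, not part of this module:

    from dataclasses import dataclass

    from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel

    from megatron.bridge.models.hf_pretrained.base import PreTrainedBase


    @dataclass
    class MyCausalLM(PreTrainedBase):
        def _load_model(self) -> PreTrainedModel:
            # Materializes the weights from the configured path or hub id.
            return AutoModelForCausalLM.from_pretrained(self.model_name_or_path)

        def _load_config(self) -> AutoConfig:
            # Reads only the configuration; no weights are loaded here.
            return AutoConfig.from_pretrained(self.model_name_or_path)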

property model: transformers.PreTrainedModel#

Lazily loads and returns the underlying model.

property config: transformers.AutoConfig#

Lazily loads and returns the model config.
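
Neither property touches the checkpoint until first access. A hedged sketch using the hypothetical MyCausalLM subclass above, assuming the base class accepts model_name_or_path as a keyword argument (its **kwargs constructor suggests it does):

    # Construction is cheap: nothing is read from disk yet.
    lm = MyCausalLM(model_name_or_path="gpt2")

    cfg = lm.config  # first access triggers _load_config()
    net = lm.model   # first access triggers _load_model()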

property state: megatron.bridge.models.hf_pretrained.state.StateDict#

Get the state dict accessor for pandas-like querying.

This accessor can be backed by either a fully loaded model in memory or a “.safetensors” checkpoint on disk, enabling lazy loading of tensors.

Examples

    model.state()                    # Get full state dict
    model.state["key"]               # Get single entry
    model.state[["key1", "key2"]]    # Get multiple entries
    model.state["*.weight"]          # Glob pattern
    model.state.regex(r".*\.bias$")  # Regex pattern