bridge.models.hf_pretrained.base

Module Contents

Classes

PreTrainedBase

Abstract base class for all pretrained models.

API

class bridge.models.hf_pretrained.base.PreTrainedBase(**kwargs)

Bases: abc.ABC

Abstract base class for all pretrained models.

This class provides a generic mechanism for managing model artifacts (e.g., the config and tokenizer) with lazy loading. Subclasses decorated with @dataclass can declare artifacts as fields whose metadata names a loader method. The model itself is handled separately, through a dedicated model property that delegates to the abstract _load_model method.

Example

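```python
from dataclasses import dataclass, field

from transformers import AutoConfig, PreTrainedModel

# PreTrainedBase and the artifact() helper come from the bridge package;
# the import path of artifact() is not shown in this reference.


@dataclass
class MyModel(PreTrainedBase):
    config: AutoConfig = field(
        init=False,
        metadata=artifact(loader="_load_config"),
    )

    def _load_model(self) -> "PreTrainedModel":
        # Implementation for the loading logic
        ...
```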
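The lazy-loading mechanism described above can be illustrated with a minimal, self-contained sketch. It is not the PreTrainedBase implementation: LazyArtifactsBase, artifact_meta(), and the __getattr__-based dispatch are hypothetical names and choices, used only to show how a loader name stored in dataclass field metadata can drive loading on first access.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict


def artifact_meta(loader: str) -> Dict[str, str]:
    # Hypothetical helper: stores the loader method's name in field metadata.
    return {"loader": loader}


class LazyArtifactsBase:
    """Hypothetical stand-in for the lazy-artifact machinery of PreTrainedBase."""

    def __getattr__(self, name: str) -> Any:
        # Python calls __getattr__ only when normal lookup fails, i.e. when
        # the artifact has not been loaded onto the instance yet.
        if name.startswith("__"):
            raise AttributeError(name)  # avoid recursing on dunder lookups
        for f in fields(self):
            if f.name == name and "loader" in f.metadata:
                value = getattr(self, f.metadata["loader"])()
                setattr(self, name, value)  # cache: the loader runs only once
                return value
        raise AttributeError(name)


@dataclass
class MyModel(LazyArtifactsBase):
    model_name_or_path: str
    config: dict = field(init=False, repr=False, metadata=artifact_meta("_load_config"))
    load_count: int = field(default=0, init=False)

    def _load_config(self) -> dict:
        # Stand-in for a real loader such as reading a config from disk.
        self.load_count += 1
        return {"name": self.model_name_or_path, "hidden_size": 8}


m = MyModel("dummy-model")
print(m.load_count)                        # 0: nothing loaded yet
print(m.config["hidden_size"])             # 8: first access runs _load_config
print(m.config is m.config, m.load_count)  # True 1: later accesses hit the cache
```

Because the field has init=False and no default, the dataclass leaves no class attribute behind, so the first access falls through to __getattr__; after that the cached instance attribute is found directly. The real class may use a different hook.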


model_name_or_path: Union[str, pathlib.Path]

Default: None

ARTIFACTS: ClassVar[List[str]]

Default: []

OPTIONAL_ARTIFACTS: ClassVar[List[str]]

Default: []

get_artifacts() -> Dict[str, str]

Get the artifacts dictionary mapping artifact names to their attribute names.

save_artifacts(save_directory: Union[str, pathlib.Path])

Saves every loaded, generic artifact that has a save_pretrained method to save_directory. Note: this method does not save the model attribute.
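The filtering rules described above (only loaded artifacts, only those with save_pretrained, never the model) can be sketched over a plain dictionary. This is a hypothetical illustration, not the library's code: save_loaded_artifacts and DummyTokenizer are invented names, and creating the target directory is an assumption.

```python
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Union


class DummyTokenizer:
    """Hypothetical artifact that exposes a save_pretrained method."""

    def save_pretrained(self, save_directory: Union[str, Path]) -> None:
        Path(save_directory, "tokenizer.json").write_text("{}")


def save_loaded_artifacts(
    loaded: Dict[str, Any], save_directory: Union[str, Path]
) -> List[str]:
    """Hypothetical sketch of save_artifacts; returns the names it saved."""
    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)  # assumption: create the directory
    saved = []
    for name, obj in loaded.items():
        if name == "model" or obj is None:
            continue  # skip the model and anything not loaded yet
        save = getattr(obj, "save_pretrained", None)
        if callable(save):  # only artifacts that know how to save themselves
            save(save_directory)
            saved.append(name)
    return saved


with tempfile.TemporaryDirectory() as tmp:
    names = save_loaded_artifacts(
        {"model": object(), "tokenizer": DummyTokenizer(), "config": None, "note": "text"},
        tmp,
    )
    print(names)  # ['tokenizer']
```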

abstractmethod _load_model() -> transformers.PreTrainedModel

Subclasses must implement this to load the main model.

abstractmethod _load_config() -> transformers.AutoConfig

Subclasses must implement this to load the model config.

property model: transformers.PreTrainedModel

Lazily loads and returns the underlying model.
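A cached property that defers to an abstract loader is one common way to get this behavior. The sketch below is hypothetical: it assumes the model is cached in a private _model attribute and uses a placeholder dictionary instead of a transformers.PreTrainedModel. LazyModelBase and CountingModel are invented names.

```python
import abc
from typing import Any, Optional


class LazyModelBase(abc.ABC):
    """Hypothetical sketch of a lazily loaded `model` property."""

    def __init__(self) -> None:
        self._model: Optional[Any] = None  # assumption: private cache slot

    @abc.abstractmethod
    def _load_model(self) -> Any:
        """Subclasses supply the (expensive) loading logic."""

    @property
    def model(self) -> Any:
        # Load on first access only; later accesses return the cached object.
        if self._model is None:
            self._model = self._load_model()
        return self._model


class CountingModel(LazyModelBase):
    def __init__(self) -> None:
        super().__init__()
        self.calls = 0

    def _load_model(self) -> Any:
        self.calls += 1
        return {"weights": [0.0, 1.0]}  # placeholder for a real model object


m = CountingModel()
print(m.calls)  # 0: nothing is loaded at construction time
_ = m.model
_ = m.model
print(m.calls)  # 1: the loader ran exactly once
```

The config property can follow the same pattern, delegating to _load_config.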

property config: transformers.AutoConfig

Lazily loads and returns the model config.

property state: megatron.bridge.models.hf_pretrained.state.StateDict

Get the state dict accessor for pandas-like querying.

This accessor can be backed either by a fully loaded model in memory or by a .safetensors checkpoint on disk; the on-disk backing lets tensors be loaded lazily.

Examples

```python
model.state()                    # Get full state dict
model.state["key"]               # Get single entry
model.state[["key1", "key2"]]    # Get multiple entries
model.state["*.weight"]          # Glob pattern
model.state.regex(r".*\.bias$")  # Regex pattern
```
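The query forms listed above can be mimicked over a plain dictionary to show how such an accessor might dispatch on the key type. This is a hypothetical sketch, not the StateDict class: TinyStateDict is an invented name, treating any key containing *, ? or [ as a glob is an assumed heuristic, and the real accessor may return a different (pandas-like) structure.

```python
import fnmatch
import re
from typing import Any, Dict, List, Union


class TinyStateDict:
    """Hypothetical in-memory stand-in for the StateDict accessor."""

    def __init__(self, tensors: Dict[str, Any]) -> None:
        self._tensors = dict(tensors)

    def __call__(self) -> Dict[str, Any]:
        # state() -> the full mapping (a copy, so callers cannot mutate ours)
        return dict(self._tensors)

    def __getitem__(self, key: Union[str, List[str]]) -> Any:
        if isinstance(key, list):
            # A list of keys -> a dict containing just those entries
            return {k: self._tensors[k] for k in key}
        if any(ch in key for ch in "*?["):
            # Assumed heuristic: wildcard characters mean a glob pattern
            return {k: v for k, v in self._tensors.items() if fnmatch.fnmatchcase(k, key)}
        # Otherwise an exact key lookup
        return self._tensors[key]

    def regex(self, pattern: str) -> Dict[str, Any]:
        rx = re.compile(pattern)
        return {k: v for k, v in self._tensors.items() if rx.search(k)}


state = TinyStateDict({
    "layer.0.weight": [1.0],
    "layer.0.bias": [0.1],
    "layer.1.weight": [2.0],
})
print(state["layer.0.bias"])              # [0.1]
print(sorted(state["*.weight"]))          # ['layer.0.weight', 'layer.1.weight']
print(sorted(state.regex(r".*\.bias$")))  # ['layer.0.bias']
```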