bridge.models.hf_pretrained.base
Module Contents
Classes
PreTrainedBase: Abstract base class for all pretrained models.
API
- class bridge.models.hf_pretrained.base.PreTrainedBase(**kwargs)
Bases: abc.ABC
Abstract base class for all pretrained models.
This class provides a generic mechanism for managing model artifacts (e.g., config, tokenizer) with lazy loading. Subclasses decorated with `@dataclass` can define artifacts as fields whose metadata specifies a loader method. The `model` itself is handled via a dedicated property that relies on the abstract `_load_model` method.
Example:
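```python
from dataclasses import dataclass, field
from transformers import AutoConfig, PreTrainedModel
# PreTrainedBase is documented on this page; artifact() is not documented here.

@dataclass
class MyModel(PreTrainedBase):
    config: AutoConfig = field(init=False, metadata=artifact(loader="_load_config"))

    def _load_config(self) -> AutoConfig:
        # Named by the artifact metadata above; also abstract on the base class
        ...

    def _load_model(self) -> PreTrainedModel:
        # Implementation for the loading logic
        ...
```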
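To make the field-metadata mechanism concrete, here is a minimal, self-contained sketch of how a loader name stored in dataclass field metadata can drive lazy loading. It is not the library's implementation: `artifact`, `LazyArtifacts`, `ToyArtifactModel`, and the `artifact_loader` metadata key are hypothetical names chosen for illustration, and a plain dict stands in for `AutoConfig`.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict


def artifact(loader: str) -> Dict[str, Any]:
    # Hypothetical helper: records the loader method name in the field metadata.
    return {"artifact_loader": loader}


class LazyArtifacts:
    """Hypothetical base class: loads an artifact field on first access."""

    def __getattr__(self, name: str) -> Any:
        # Python calls __getattr__ only when normal lookup fails, i.e. before the
        # artifact has been loaded. Dunder names are rejected to avoid recursion.
        if name.startswith("__"):
            raise AttributeError(name)
        for f in fields(self):
            loader = f.metadata.get("artifact_loader")
            if f.name == name and loader is not None:
                value = getattr(self, loader)()
                setattr(self, name, value)  # cache on the instance; loader runs once
                return value
        raise AttributeError(name)


@dataclass
class ToyArtifactModel(LazyArtifacts):
    model_name_or_path: str
    # No default and init=False: the dataclass leaves this attribute unset.
    config: Dict[str, Any] = field(init=False, metadata=artifact(loader="_load_config"))
    load_count: int = field(default=0, init=False)

    def _load_config(self) -> Dict[str, Any]:
        self.load_count += 1
        return {"name": self.model_name_or_path, "hidden_size": 16}


m = ToyArtifactModel("demo")
print(m.load_count)                         # 0: nothing loaded at construction
print(m.config["hidden_size"])              # 16: first access calls _load_config
print(m.config is m.config, m.load_count)   # True 1
```

Relying on `__getattr__` means already-loaded artifacts are ordinary instance attributes with no lookup overhead; only the first access pays for the loader call.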
Initialization
- model_name_or_path: Union[str, pathlib.Path] = None
- ARTIFACTS: ClassVar[List[str]] = []
- OPTIONAL_ARTIFACTS: ClassVar[List[str]] = []
- get_artifacts() -> Dict[str, str]
Get the artifacts dictionary mapping artifact names to their attribute names.
- save_artifacts(save_directory: Union[str, pathlib.Path])
Saves all loaded, generic artifacts that have a `save_pretrained` method to the specified directory. Note: this does not save the `model` attribute.
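The save behavior described above can be sketched against a plain dict of artifacts so that it runs without `transformers`. The function `save_loaded_artifacts`, the `DummyTokenizer` class, and the use of `None` to mark an artifact that has not been loaded are assumptions for illustration; the documented method takes only the target directory.

```python
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Union


class DummyTokenizer:
    """Hypothetical artifact following the save_pretrained(directory) convention."""

    def save_pretrained(self, save_directory: Union[str, Path]) -> None:
        Path(save_directory, "tokenizer.json").write_text("{}")


def save_loaded_artifacts(
    artifacts: Dict[str, Optional[Any]], save_directory: Union[str, Path]
) -> List[str]:
    """Save each loaded artifact that has a save_pretrained method; return saved names."""
    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)
    saved: List[str] = []
    for name, obj in artifacts.items():
        if name == "model" or obj is None:
            continue  # the model is never saved here; None marks "not loaded yet"
        save = getattr(obj, "save_pretrained", None)
        if callable(save):
            save(save_directory)
            saved.append(name)
    return saved


with tempfile.TemporaryDirectory() as tmp:
    names = save_loaded_artifacts(
        {"tokenizer": DummyTokenizer(), "generation_config": None, "notes": "plain text"},
        tmp,
    )
    print(names)                                           # ['tokenizer']
    print(sorted(p.name for p in Path(tmp).iterdir()))     # ['tokenizer.json']
```

Duck typing on `save_pretrained` keeps the method generic: any artifact that follows the Hugging Face saving convention is handled without the base class knowing its type.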
- abstractmethod _load_model() -> transformers.PreTrainedModel
Subclasses must implement this to load the main model.
- abstractmethod _load_config() -> transformers.AutoConfig
Subclasses must implement this to load the model config.
- property model: transformers.PreTrainedModel
Lazily loads and returns the underlying model.
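The lazy `model` property pairs a cached attribute with the abstract loader. The sketch below uses only the standard library; `LazyModelBase`, `ToyLazyModel`, and the `_model` cache attribute are hypothetical, and the actual class may cache the model differently.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class LazyModelBase(ABC):
    """Hypothetical base: `model` is loaded on first access, then cached."""

    def __init__(self) -> None:
        self._model: Optional[Any] = None

    @abstractmethod
    def _load_model(self) -> Any:
        """Subclasses implement the (possibly expensive) loading step."""

    @property
    def model(self) -> Any:
        if self._model is None:
            self._model = self._load_model()  # first access triggers the load
        return self._model


class ToyLazyModel(LazyModelBase):
    def __init__(self) -> None:
        super().__init__()
        self.load_count = 0

    def _load_model(self) -> Any:
        self.load_count += 1
        return {"weights": [0.1, 0.2]}


toy = ToyLazyModel()
print(toy.load_count)                           # 0
print(toy.model["weights"])                     # [0.1, 0.2]
print(toy.model is toy.model, toy.load_count)   # True 1
```

Because `_load_model` is abstract, instantiating the base class directly raises `TypeError`, which forces every subclass to supply a loader.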
- property config: transformers.AutoConfig
Lazily loads and returns the model config.
- property state: megatron.bridge.models.hf_pretrained.state.StateDict
Get the state dict accessor for pandas-like querying.
This accessor can be backed by either a fully loaded model in memory or a “.safetensors” checkpoint on disk, enabling lazy loading of tensors.
Examples:
```python
model.state()                      # Get full state dict
model.state["key"]                 # Get single entry
model.state[["key1", "key2"]]      # Get multiple entries
model.state["*.weight"]            # Glob pattern
model.state.regex(r".*\.bias$")    # Regex pattern
```
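The glob and regex selectors differ in the usual way: in a glob, `*` matches any run of characters (dots included), while a regex needs `.*` and an escaped `\.`. The sketch below applies both example patterns to a plain dict of parameter names using the standard `fnmatch` and `re` modules. It is an assumption that `StateDict` matches keys with equivalent semantics, and `select_glob` and `select_regex` are hypothetical helpers.

```python
import fnmatch
import re
from typing import Dict, List, Tuple

# Hypothetical parameter names mapped to shapes, standing in for tensors.
state: Dict[str, Tuple[int, ...]] = {
    "embed.weight": (100, 16),
    "layer.0.linear.weight": (16, 16),
    "layer.0.linear.bias": (16,),
    "head.bias": (100,),
}


def select_glob(keys: List[str], pattern: str) -> List[str]:
    # fnmatchcase: shell-style match against the whole key, case-sensitive.
    return [k for k in keys if fnmatch.fnmatchcase(k, pattern)]


def select_regex(keys: List[str], pattern: str) -> List[str]:
    # re.search: the pattern may match anywhere; anchors such as $ constrain it.
    return [k for k in keys if re.search(pattern, k)]


keys = list(state)
print(select_glob(keys, "*.weight"))      # ['embed.weight', 'layer.0.linear.weight']
print(select_regex(keys, r".*\.bias$"))   # ['layer.0.linear.bias', 'head.bias']
```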