nemo_automodel.components.models.common.bidirectional

View as Markdown

Module Contents

Classes

| Name | Description |
|---|---|
| EncoderStateDictAdapter | Adapter for encoder model state dicts. |

Data

__all__

API

class nemo_automodel.components.models.common.bidirectional.EncoderStateDictAdapter()

Bases: StateDictAdapter

Adapter for encoder model state dicts.

The internal format uses a `model.` prefix on all keys; the HF format does not. This adapter strips or adds the `model.` prefix as needed, including for PEFT-wrapped keys (`base_model.model.model.X` <-> `base_model.model.X`).

_MODEL_PREFIX = 'model.'
_PEFT_MODEL_PREFIX = _PEFT_PREFIX + _MODEL_PREFIX  (i.e. 'base_model.model.model.')
_PEFT_PREFIX = 'base_model.model.'
nemo_automodel.components.models.common.bidirectional.EncoderStateDictAdapter._add_model_prefix(key)
nemo_automodel.components.models.common.bidirectional.EncoderStateDictAdapter._strip_model_prefix(key)
nemo_automodel.components.models.common.bidirectional.EncoderStateDictAdapter.convert_single_tensor_to_hf(fqn, tensor, kwargs = {})
nemo_automodel.components.models.common.bidirectional.EncoderStateDictAdapter.from_hf(hf_state_dict, device_mesh = None, kwargs = {})
nemo_automodel.components.models.common.bidirectional.EncoderStateDictAdapter.to_hf(state_dict, kwargs = {})
nemo_automodel.components.models.common.bidirectional.__all__ = ['EncoderStateDictAdapter']