nemo_automodel.components.distributed.pipelining.hf_utils
Module Contents
Functions
patch_hf_model_for_pp | Patch a HF model/module to produce pipeline-compatible forward.
validate_hf_model_for_pipeline_support | Validate if a model is compatible with torch.distributed.pipelining.
Data
API
- nemo_automodel.components.distributed.pipelining.hf_utils.logger
‘getLogger(…)’
- nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_inner(
- model_class_name: str = 'AutoModel',
- nemo_automodel.components.distributed.pipelining.hf_utils.create_pipeline_forward_causal_lm() -> Callable
- nemo_automodel.components.distributed.pipelining.hf_utils.patch_hf_model_for_pp(
- model,
- patch_inner_model: bool = True,
- patch_causal_lm_model: bool = True,
Patch a HF model/module to produce pipeline-compatible forward.
If the model has a .model attribute (e.g., LlamaForCausalLM), both the inner model and the outer wrapper are patched.
Otherwise, the module itself is patched.
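The dispatch described above can be illustrated with a minimal, self-contained sketch. It is not the library's implementation: `sketch_patch_for_pp`, the stub classes, and the placeholder forwards are hypothetical names introduced for illustration only. It also assumes that the two flags gate the inner and outer patch independently, and that a bare module is patched only when `patch_inner_model` is true; the source does not confirm either detail.

```python
import types


class Inner:
    """Stub standing in for an inner model or a bare module (hypothetical)."""

    def forward(self, hidden):
        return ("orig-inner", hidden)


class Wrapper:
    """Stub standing in for a wrapper that holds an inner model as .model."""

    def __init__(self):
        self.model = Inner()

    def forward(self, hidden):
        return ("orig-outer", self.model.forward(hidden))


def sketch_patch_for_pp(model, patch_inner_model=True, patch_causal_lm_model=True):
    """Illustrative stand-in for the documented dispatch; returns what was patched."""

    def inner_forward(self, hidden):
        # Placeholder for a pipeline-compatible forward of the inner model.
        return ("inner", hidden)

    def outer_forward(self, hidden):
        # Placeholder for a pipeline-compatible forward of the wrapper:
        # it simply delegates to whatever forward the inner model has.
        return ("outer", self.model.forward(hidden))

    patched = []
    inner = getattr(model, "model", None)
    if inner is not None:
        # The model has .model: patch the inner model and/or the wrapper.
        if patch_inner_model:
            inner.forward = types.MethodType(inner_forward, inner)
            patched.append("inner")
        if patch_causal_lm_model:
            model.forward = types.MethodType(outer_forward, model)
            patched.append("outer")
    elif patch_inner_model:
        # No .model attribute: patch the module itself (assumed flag).
        model.forward = types.MethodType(inner_forward, model)
        patched.append("self")
    return patched
```

Binding with `types.MethodType` replaces `forward` on the instance only, so other instances of the same class keep their original behavior. With the stubs above, calling `sketch_patch_for_pp(Wrapper())` patches both levels, while calling it on a bare `Inner()` patches only that object.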
- nemo_automodel.components.distributed.pipelining.hf_utils.init_hf_model_buffers(
- model: torch.nn.Module,
- device: torch.device,
- nemo_automodel.components.distributed.pipelining.hf_utils.validate_hf_model_for_pipeline_support(model: torch.nn.Module) -> None
Validate if a model is compatible with torch.distributed.pipelining.