nemo_rl.models.policy.workers.patches

Module Contents

Functions

_get_transformer_engine_file

Return absolute path to a Transformer Engine file or raise if it cannot be found.

apply_transformer_engine_patch

Apply the patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.

apply_torch_aten_alias_tensor_patch

Register a sharding rule for torch.ops.aten.alias.default.

API

nemo_rl.models.policy.workers.patches._get_transformer_engine_file(relative_path: str) -> str

Return absolute path to a Transformer Engine file or raise if it cannot be found.

The relative_path should be a POSIX-style path under the transformer_engine package root, e.g. “pytorch/triton/permutation.py”.
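The lookup described here can be sketched with the standard library's importlib.metadata, which finds an installed distribution's files without importing the package. This is a minimal sketch, not this module's implementation: the helper name get_package_file is hypothetical, and the default distribution name "transformer_engine" is an assumption (the installed distribution may be registered under a different name).

```python
from importlib import metadata
from pathlib import Path, PurePosixPath


def get_package_file(
    relative_path: str,
    dist_name: str = "transformer_engine",  # assumption: installed distribution name
    package_dir: str = "transformer_engine",  # package root inside site-packages
) -> str:
    """Hypothetical helper: resolve a file inside an installed package without importing it."""
    rel = PurePosixPath(relative_path)
    # Only accept POSIX-style paths that stay under the package root.
    if rel.is_absolute() or ".." in rel.parts:
        raise ValueError(f"expected a relative POSIX path, got {relative_path!r}")
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError as exc:
        raise FileNotFoundError(f"distribution {dist_name!r} is not installed") from exc
    # locate_file() maps a path relative to the install location (for example
    # site-packages) to a filesystem path; it does not import anything.
    candidate = Path(dist.locate_file(f"{package_dir}/{rel.as_posix()}"))
    if not candidate.is_file():
        raise FileNotFoundError(
            f"{relative_path!r} not found under {package_dir!r} in {dist_name!r}"
        )
    return str(candidate.resolve())
```

For example, get_package_file("pytorch/triton/permutation.py") would return an absolute path ending in transformer_engine/pytorch/triton/permutation.py when that layout is installed, and raise FileNotFoundError otherwise. Raising instead of returning None keeps a missing or relocated file from being silently skipped.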

nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch()

Apply the patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.

This locates the target file via importlib.metadata instead of importing transformer_engine, to avoid side effects during initialization. If the permutation module has already been imported, it is reloaded so that the patched source takes effect.
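The "patch on disk, then reload if already imported" pattern can be sketched as follows. This is an illustration under stated assumptions, not the actual patch: the helper name patch_file_and_reload is hypothetical, and the old/new snippets are placeholders, since the contents of the upstream TransformerEngine change are not reproduced here.

```python
import importlib
import importlib.util
import sys
from pathlib import Path


def patch_file_and_reload(path: str, old: str, new: str, module_name: str) -> bool:
    """Hypothetical helper: rewrite `old` -> `new` in a source file, then reload.

    Returns True if the file was modified, False if it was already patched.
    """
    source = Path(path)
    text = source.read_text()
    if new in text:
        # Already patched (for example by another worker); stay idempotent.
        return False
    if old not in text:
        raise RuntimeError(f"expected snippet not found in {path}; upstream file may have changed")
    source.write_text(text.replace(old, new, 1))
    # Drop cached bytecode: a same-size rewrite within the same second could
    # otherwise pass the .pyc mtime/size check and load stale code.
    Path(importlib.util.cache_from_source(str(source))).unlink(missing_ok=True)
    # Reload only if the module was already imported; a fresh import will
    # read the patched source anyway.
    module = sys.modules.get(module_name)
    if module is not None:
        importlib.reload(module)
    return True
```

Checking for the replacement text first makes repeated calls safe, and failing loudly when the original snippet is missing surfaces upstream changes instead of silently skipping the patch. Note that rewriting a file inside site-packages requires write access to that directory.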

nemo_rl.models.policy.workers.patches.apply_torch_aten_alias_tensor_patch()

Register a sharding rule for torch.ops.aten.alias.default.

Work around ‘NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered’, which is raised in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix. This patch can be removed once torch is upgraded to a version that includes that fix.
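Why registering a rule makes the error go away can be illustrated with a toy model of a per-operator strategy table. This is deliberately not torch code: the names SHARDING_STRATEGIES, register_strategy, propagate, and apply_alias_patch are hypothetical, and placements are plain strings. PyTorch's real DTensor registration helpers are private and vary between versions. The rule shown (the output keeps the input's placement) is a natural choice for a view-like op such as alias, assumed here rather than taken from the upstream fix.

```python
from typing import Callable, Dict, Tuple

Placement = str  # toy stand-in, e.g. "Shard(0)" or "Replicate()"
Strategy = Callable[[Tuple[Placement, ...]], Placement]

# Toy registry keyed by operator name (hypothetical, not torch's table).
SHARDING_STRATEGIES: Dict[str, Strategy] = {}


def register_strategy(op_name: str) -> Callable[[Strategy], Strategy]:
    """Decorator that records a sharding rule for one operator."""
    def decorator(fn: Strategy) -> Strategy:
        SHARDING_STRATEGIES[op_name] = fn
        return fn
    return decorator


def propagate(op_name: str, input_placements: Tuple[Placement, ...]) -> Placement:
    """Look up the operator's rule; unregistered operators raise, as in the error above."""
    strategy = SHARDING_STRATEGIES.get(op_name)
    if strategy is None:
        raise NotImplementedError(
            f"Operator {op_name} does not have a sharding strategy registered"
        )
    return strategy(input_placements)


def apply_alias_patch() -> None:
    """Register a rule for alias once; calling it again is a no-op."""
    if "aten.alias.default" not in SHARDING_STRATEGIES:
        # alias returns a view of its input, so the output keeps the
        # input's placement and no communication is needed.
        register_strategy("aten.alias.default")(lambda placements: placements[0])
```

Before apply_alias_patch() runs, propagate("aten.alias.default", ("Shard(0)",)) raises NotImplementedError; afterwards it returns "Shard(0)". Guarding the registration keeps the patch idempotent if several workers call it.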