nemo_rl.models.policy.workers.patches#
Module Contents#
Functions#
- `_get_transformer_engine_file`: Return absolute path to a Transformer Engine file or raise if it cannot be found.
- `apply_transformer_engine_patch`: Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.
- `apply_torch_aten_alias_tensor_patch`: Register a sharding rule for `torch.ops.aten.alias.default`.
API#
- nemo_rl.models.policy.workers.patches._get_transformer_engine_file(relative_path: str) → str#
Return absolute path to a Transformer Engine file or raise if it cannot be found.
The `relative_path` should be a POSIX-style path under the `transformer_engine` package root, e.g. `"pytorch/triton/permutation.py"`.
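A minimal usage sketch, for illustration only (the helper is private to this module); the example path is the one from the docstring:

```python
from nemo_rl.models.policy.workers.patches import _get_transformer_engine_file

# Resolve the on-disk location of a Transformer Engine source file without
# importing transformer_engine itself. Raises if the file cannot be found.
path = _get_transformer_engine_file("pytorch/triton/permutation.py")
print(path)  # absolute path ending in .../transformer_engine/pytorch/triton/permutation.py
```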
- nemo_rl.models.policy.workers.patches.apply_transformer_engine_patch()#
Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files.
This locates the target file via `importlib` metadata instead of importing `transformer_engine`, to avoid side effects during initialization. If the permutation module has already been imported, it is reloaded so that the patched source takes effect.
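A minimal call-site sketch, assuming the patch is applied during worker initialization before Transformer Engine is used:

```python
from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch

# Apply the fix from TransformerEngine PR #2286. The target file is located
# via importlib metadata (transformer_engine is not imported here), and the
# permutation module is reloaded if it was already imported.
apply_transformer_engine_patch()

import transformer_engine.pytorch  # safe to import after the patch is applied
```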
- nemo_rl.models.policy.workers.patches.apply_torch_aten_alias_tensor_patch()#
Register a sharding rule for `torch.ops.aten.alias.default`.
Works around `NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered` in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix. We can remove this patch once we upgrade torch to a version that includes that fix.
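For illustration, the sketch below shows one hypothetical way to register such a rule using PyTorch's experimental `register_sharding` API. This is an assumption about one possible mechanism, not necessarily how the patch itself (or the upstream fix) is implemented:

```python
import torch
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import register_sharding


# Hypothetical sketch: aten.alias is a no-op view, so the output can simply
# follow the input's placement. Accept replication plus sharding on any dim.
@register_sharding(torch.ops.aten.alias.default)
def _alias_sharding(x):
    acceptable = [([Replicate()], [Replicate()])]
    for dim in range(x.ndim):
        acceptable.append(([Shard(dim)], [Shard(dim)]))
    return acceptable
```

In practice, one would simply call `apply_torch_aten_alias_tensor_patch()` once per process before any DTensor operation dispatches `aten.alias.default`.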