nemo_rl.modelopt.models.policy.workers.megatron_quant_policy_worker#

Module Contents#

Classes#

MegatronQuantPolicyWorker

API#

class nemo_rl.modelopt.models.policy.workers.megatron_quant_policy_worker.MegatronQuantPolicyWorker(config, *args, **kwargs)#

Bases: nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorkerImpl

_quantize(model)#

Quantize the model if it is not already quantized.
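A minimal sketch of the idempotency check, assuming ModelOpt's mtq.quantize entry point and a scan for existing TensorQuantizer modules; the helper name, config, and calibration loop are illustrative, not the worker's actual internals:

```python
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import TensorQuantizer


def quantize_once(model, quant_cfg, calib_loop):
    """Quantize only if no TensorQuantizer has been inserted yet."""
    if any(isinstance(m, TensorQuantizer) for m in model.modules()):
        # Already quantized, e.g. restored from a quantized checkpoint.
        return model
    # mtq.quantize inserts TensorQuantizer modules and calibrates them
    # by running the provided forward loop.
    return mtq.quantize(model, quant_cfg, forward_loop=calib_loop)
```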

_patch_validate_model_paths()#

Patch validate_model_paths to handle quantized checkpoint paths.

In cases like distillation where the teacher model is the same as the student model, we need to save an extra quantized checkpoint. This patch checks for modelopt state and redirects to a _quantized suffix path. It also handles pre-quantized model symlinks.
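A sketch of the redirection logic, under the assumption that a quantized checkpoint is recognized by a modelopt_state artifact in its directory; the helper name is hypothetical:

```python
from pathlib import Path


def resolve_checkpoint_path(path: str) -> str:
    """Prefer a `<name>_quantized` sibling that carries ModelOpt state."""
    base = Path(path)
    quantized = base.with_name(base.name + "_quantized")
    if (quantized / "modelopt_state").exists():
        return str(quantized)
    # Pre-quantized checkpoints may be provided through a symlink.
    return str(base.resolve()) if base.is_symlink() else str(base)
```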

_patch_setup_model_and_optimizer()#

Patch setup_model_and_optimizer to restore modelopt state when loading quantized checkpoints.

_restore_modelopt_state_pre_load(state, model)#

Restore ModelOpt state into the model before load_checkpoint runs.

Forwarded as the pre_load_checkpoint_hook to setup_model_and_optimizer and setup_reference_model_state via the _pre_load_checkpoint_hook instance attribute. Quantizers must exist on the model graph before load_checkpoint populates their amax/scale buffers.
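A minimal sketch of such a hook, assuming ModelOpt's restore_from_modelopt_state re-inserts the quantizer modules from a saved state dict:

```python
import modelopt.torch.opt as mto


def restore_modelopt_state_pre_load(state, model):
    """Rebuild quantizer modules before the checkpoint is loaded.

    `state` is the saved ModelOpt state read from the checkpoint;
    restoring it re-inserts the TensorQuantizers so that
    load_checkpoint can fill their amax/scale buffers.
    """
    return mto.restore_from_modelopt_state(model, state)
```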

hide_tensor_quantizers()#

Context manager that temporarily hides TensorQuantizer modules from module iteration.
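One way to implement such hiding, as a sketch: temporarily detach TensorQuantizer entries from each parent's _modules registry so named_modules() and state_dict() no longer see them, then reattach on exit. This is an assumption about the mechanism, not the worker's actual code:

```python
from contextlib import contextmanager

from modelopt.torch.quantization.nn import TensorQuantizer


@contextmanager
def hide_tensor_quantizers(model):
    # Collect first, then detach, so the registry is never mutated
    # while modules() is still iterating over it.
    hidden = [
        (parent, name, child)
        for parent in model.modules()
        for name, child in list(parent._modules.items())
        if isinstance(child, TensorQuantizer)
    ]
    for parent, name, _ in hidden:
        del parent._modules[name]
    try:
        yield
    finally:
        for parent, name, child in hidden:
            parent._modules[name] = child
```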

enable_forward_pre_hook()#

Enable the forward pre-hook with TensorQuantizer modules hidden from module iteration.

disable_forward_pre_hook(param_sync=True)#

Disable the forward pre-hook with TensorQuantizer modules hidden from module iteration.
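Both overrides plausibly compose with the hide_tensor_quantizers context manager documented above; a sketch, assuming the base class provides the plain implementations:

```python
from nemo_rl.models.policy.workers.megatron_policy_worker import (
    MegatronPolicyWorkerImpl,
)


class MegatronQuantPolicyWorker(MegatronPolicyWorkerImpl):
    def enable_forward_pre_hook(self):
        # Hide quantizers so the base class's module walk skips them.
        with self.hide_tensor_quantizers():
            super().enable_forward_pre_hook()

    def disable_forward_pre_hook(self, param_sync=True):
        with self.hide_tensor_quantizers():
            super().disable_forward_pre_hook(param_sync=param_sync)
```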

disable_quantization()#

Context manager that temporarily disables quantization.
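A sketch of such a context manager, assuming TensorQuantizer's enable()/disable() toggles and an is_enabled property; only quantizers that were enabled on entry are re-enabled, so previously disabled ones stay disabled:

```python
from contextlib import contextmanager

from modelopt.torch.quantization.nn import TensorQuantizer


@contextmanager
def disable_quantization(model):
    enabled = [
        m for m in model.modules()
        if isinstance(m, TensorQuantizer) and m.is_enabled
    ]
    for q in enabled:
        q.disable()
    try:
        yield
    finally:
        for q in enabled:
            q.enable()
```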

_hide_extra_state()#

Patch model.state_dict() to exclude _extra_state keys.

ModelOpt appends quantization calibration data (amax/scale) to TE’s serialized extra state, making it larger than the non-quantized reference model’s copy. These are calibration metadata, not learned weights, and can also be resized by TE during forward passes. Filtering them out lets the base class swap/restore skip them cleanly.
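One way to apply the patch temporarily, as a sketch: shadow the bound state_dict with a filtering wrapper on the instance, then delete the shadow to fall back to the class implementation. The context-manager framing is an assumption:

```python
import functools
from contextlib import contextmanager


@contextmanager
def hide_extra_state(model):
    original_state_dict = model.state_dict

    @functools.wraps(original_state_dict)
    def filtered_state_dict(*args, **kwargs):
        sd = original_state_dict(*args, **kwargs)
        # Drop TE's serialized extra state (calibration metadata,
        # not learned weights).
        return {k: v for k, v in sd.items() if not k.endswith("_extra_state")}

    model.state_dict = filtered_state_dict
    try:
        yield
    finally:
        del model.state_dict  # restore the class implementation
```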

use_reference_model() → Generator[None, None, None]#

Context manager that temporarily swaps the reference model and active model.

without_model_config()#

Context manager that temporarily removes the config attribute from TensorQuantizer modules.

Used by use_reference_model and save_checkpoint. Both flows traverse the module tree (e.g. for state-dict swapping or checkpoint serialization), where the unrelated config attribute on TensorQuantizer instances is detected as a model config and triggers spurious validation/serialization errors. We strip it for the duration of the call and restore it on exit.
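A sketch of the attribute stripping, assuming config lives as a plain instance attribute on each TensorQuantizer:

```python
from contextlib import contextmanager

from modelopt.torch.quantization.nn import TensorQuantizer


@contextmanager
def without_model_config(model):
    stripped = []
    for m in model.modules():
        # Pop from __dict__ so only instance attributes are touched.
        if isinstance(m, TensorQuantizer) and "config" in m.__dict__:
            stripped.append((m, m.__dict__.pop("config")))
    try:
        yield
    finally:
        for m, cfg in stripped:
            m.config = cfg
```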

get_quantizer_stats() → dict#

Return summary statistics for all enabled TensorQuantizers.

Useful for verifying that calibration ran and amax values are valid.
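A sketch of the kind of summary this could return, assuming TensorQuantizer exposes is_enabled and an amax buffer after calibration; the exact keys are illustrative:

```python
import torch
from modelopt.torch.quantization.nn import TensorQuantizer


def get_quantizer_stats(model) -> dict:
    stats = {}
    for name, m in model.named_modules():
        if not (isinstance(m, TensorQuantizer) and m.is_enabled):
            continue
        amax = m.amax
        stats[name] = {
            "has_amax": amax is not None,
            "amax_max": float(amax.max()) if amax is not None else None,
            "finite": bool(torch.isfinite(amax).all()) if amax is not None else False,
        }
    return stats
```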

abstractmethod generate(**kwargs)#

Quantized Megatron generation is not supported.

ModelOpt unconditionally patches flash_decode_and_prefill on quantized attention modules, which breaks the Megatron generation path.

save_checkpoint(*args, **kwargs)#

Save the checkpoint.
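Given that without_model_config is documented above as serving save_checkpoint, the override plausibly wraps the base implementation in that context manager; a sketch, with the exact composition an assumption:

```python
from nemo_rl.models.policy.workers.megatron_policy_worker import (
    MegatronPolicyWorkerImpl,
)


class MegatronQuantPolicyWorker(MegatronPolicyWorkerImpl):
    def save_checkpoint(self, *args, **kwargs):
        # Strip the unrelated `config` attribute from TensorQuantizers
        # so serialization does not mistake it for a model config.
        with self.without_model_config():
            return super().save_checkpoint(*args, **kwargs)
```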

static _find_weight_quantizer(module, param_weight)#

Find the enabled weight quantizer that corresponds to param_weight.

Uses ModelOpt’s QuantModule.iter_weights_for_calibration to discover (weight, weight_quantizer) pairs, then matches by identity. This handles standard weight / weight_quantizer as well as custom names like gate_up_proj / gate_up_proj_weight_quantizer.

Returns the matching TensorQuantizer or None.
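A sketch of the identity match, taking the (weight, weight_quantizer) pairs described above at face value:

```python
def find_weight_quantizer(module, param_weight):
    # Match by tensor identity rather than attribute name, so custom
    # layouts like gate_up_proj / gate_up_proj_weight_quantizer work.
    for weight, quantizer in module.iter_weights_for_calibration():
        if weight is param_weight and quantizer is not None:
            if quantizer.is_enabled:
                return quantizer
    return None
```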

_iter_params_with_optional_kv_scales(kv_scales=None)#

Pre-fold weights on the fly via lazy proxy tasks.

Wraps each conversion task so that reading task.param_weight returns weight_quantizer(weight) instead of the raw weight. The folded tensor is computed lazily when export_hf_weights accesses it, so only one extra weight-sized tensor exists at a time — O(1) extra memory.

Raises:

RuntimeError – If weight folding fails for a specific parameter, with context about which parameter caused the failure.
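A sketch of the lazy-proxy idea; the task wrapper, its attributes, and the error context are assumptions drawn from the description above:

```python
class _FoldedWeightTask:
    """Hypothetical proxy: folds the weight only when it is read."""

    def __init__(self, task, quantizer, name):
        self._task = task
        self._quantizer = quantizer
        self._name = name

    def __getattr__(self, attr):
        # Delegate everything except param_weight to the real task.
        return getattr(self._task, attr)

    @property
    def param_weight(self):
        try:
            # Computed only when the exporter reads it, so at most one
            # folded weight-sized tensor is alive at a time.
            return self._quantizer(self._task.param_weight)
        except Exception as e:
            raise RuntimeError(
                f"Weight folding failed for parameter {self._name!r}"
            ) from e
```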