nemo_rl.models.generation.fp8#

Module Contents#

Classes#

FP8Config

FP8State

Functions#

my_init

my_run_engine_core

monkey_patch_vllm_ray_executor

get_vllm_qkv_scale_names

Get vLLM-compatible parameter names for Q/K/V FP8 scales.

convert_calibration_to_vllm_format

Convert NeMo-RL calibration results to vLLM parameter format.

apply_fp8_patches

init_fp8

is_fp8_model

_get_params_in_layers

_get_module_from_param_name

_is_fp8_weight

load_weights

cast_tensor_to_fp8_blockwise

process_weights_after_loading

This function is used to process the weights after loading for a Linear layer.

process_weights_after_loading_moe

This function is used to process the weights after loading for a FusedMoE layer.

process_weights_after_loading_kv

Modified version of BaseKVCacheMethod.process_weights_after_loading.

_per_token_group_quant_fp8

_per_token_group_quant_fp8_colmajor

per_token_group_quant_fp8

Data#

API#

nemo_rl.models.generation.fp8.FP8_BLOCK_QUANT_KWARGS#

None

class nemo_rl.models.generation.fp8.FP8Config#
use_weight_pow2_scale: bool#

False

use_activation_pow2_scale: bool#

False

num_first_layers_in_bf16: int#

0

num_last_layers_in_bf16: int#

0

model_parallel_size: int#

None

kv_cache_dtype: str#

'auto'

use_fp8_weights: bool#

True
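As a minimal construction sketch (assuming FP8Config is a plain dataclass taking keyword arguments; the values below are illustrative, not recommended settings):

```python
from nemo_rl.models.generation.fp8 import FP8Config

# Illustrative values only; the defaults are listed in the fields above.
cfg = FP8Config(
    use_weight_pow2_scale=False,
    use_activation_pow2_scale=False,
    num_first_layers_in_bf16=1,   # keep the first transformer layer in BF16
    num_last_layers_in_bf16=1,    # keep the last transformer layer in BF16
    model_parallel_size=8,
    kv_cache_dtype="fp8",         # default is "auto"
    use_fp8_weights=True,
)
```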

class nemo_rl.models.generation.fp8.FP8State#
seen_params: set#

field(...)

fp8_param_names: set#

field(...)

vllm_patches: list#

field(...)

nemo_rl.models.generation.fp8.global_fp8_config: nemo_rl.models.generation.fp8.FP8Config#

None

nemo_rl.models.generation.fp8.fp8_state: nemo_rl.models.generation.fp8.FP8State#

FP8State(...)

nemo_rl.models.generation.fp8.fp8_patches_applied#

False

nemo_rl.models.generation.fp8.original_run_engine_core#

None

nemo_rl.models.generation.fp8.original_init#

None

nemo_rl.models.generation.fp8.my_init(*args, **kwargs)#
nemo_rl.models.generation.fp8.my_run_engine_core(*args, **kwargs)#
nemo_rl.models.generation.fp8.monkey_patch_vllm_ray_executor(fp8_config)#
nemo_rl.models.generation.fp8.get_vllm_qkv_scale_names(layer_idx: int) → dict[str, str]#

Get vLLM-compatible parameter names for Q/K/V FP8 scales.

This function centralizes the naming convention for Q/K/V scale parameters that vLLM expects. These names must match vLLM’s internal parameter structure.

Parameters:

layer_idx – The transformer layer index (0-based)

Returns:

Dictionary mapping scale types to vLLM parameter names:

  • 'q_scale': Q activation scale name

  • 'k_scale': K activation scale name

  • 'v_scale': V activation scale name

Return type:

dict[str, str]

Note:

The q_scale name has an extra '.attn.' component compared to k_scale/v_scale. This matches vLLM's parameter remapping logic in vllm.model_executor.model_loader.weight_utils.maybe_remap_kv_scale_name.

Example:

>>> get_vllm_qkv_scale_names(0)
{'q_scale': 'model.layers.0.self_attn.attn.q_scale',
 'k_scale': 'model.layers.0.self_attn.k_scale',
 'v_scale': 'model.layers.0.self_attn.v_scale'}

nemo_rl.models.generation.fp8.convert_calibration_to_vllm_format(
calibration_results: dict[str, dict[str, float]],
) → dict[str, float]#

Convert NeMo-RL calibration results to vLLM parameter format.

Currently this is only used by the Megatron policy worker. Once the DTensor path supports the FP8 KV cache, this function can be reused there.

This function transforms the calibration output format (with layer_N keys) into the flat dictionary format that vLLM expects for parameter loading.

Parameters:

calibration_results – Dict with keys like "layer_0", "layer_1", etc. Each value is a dict with the keys "q_scale", "k_scale", and "v_scale" mapped to float scale values.

Returns:

Flat dictionary mapping vLLM parameter names to scale values. Keys follow vLLM’s naming convention as defined in get_vllm_qkv_scale_names.

Example:

>>> calib = {
...     "layer_0": {"q_scale": 1.0, "k_scale": 2.0, "v_scale": 3.0},
...     "layer_1": {"q_scale": 1.5, "k_scale": 2.5, "v_scale": 3.5},
... }
>>> convert_calibration_to_vllm_format(calib)
{'model.layers.0.self_attn.attn.q_scale': 1.0,
 'model.layers.0.self_attn.k_scale': 2.0,
 'model.layers.0.self_attn.v_scale': 3.0,
 'model.layers.1.self_attn.attn.q_scale': 1.5,
 'model.layers.1.self_attn.k_scale': 2.5,
 'model.layers.1.self_attn.v_scale': 3.5}

nemo_rl.models.generation.fp8.apply_fp8_patches(self, fp8_config)#
nemo_rl.models.generation.fp8.init_fp8(vllm_cfg, model_name, model_parallel_size)#
nemo_rl.models.generation.fp8.is_fp8_model(vllm_config)#
nemo_rl.models.generation.fp8._get_params_in_layers(param_names, layers)#
nemo_rl.models.generation.fp8._get_module_from_param_name(model, name: str)#
nemo_rl.models.generation.fp8._is_fp8_weight(name, model)#
nemo_rl.models.generation.fp8.load_weights(weights, model_runner)#
nemo_rl.models.generation.fp8.cast_tensor_to_fp8_blockwise(data_hp, weight_block_size)#
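cast_tensor_to_fp8_blockwise quantizes a high-precision tensor with one scale per weight block. As a rough illustration of what blockwise casting computes, here is a pure-PyTorch sketch, assuming a 2D tensor whose dimensions divide evenly by the block size (the function name and exact numerics are hypothetical):

```python
import torch

def blockwise_fp8_reference(w: torch.Tensor, block=(128, 128)):
    """Hypothetical reference: one scale per (block[0] x block[1]) tile
    of a 2D weight, cast to float8_e4m3fn."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    rows, cols = w.shape
    bs0, bs1 = block
    # Assume dimensions are divisible by the block size for simplicity.
    tiles = w.reshape(rows // bs0, bs0, cols // bs1, bs1)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).float()
    scale = (amax / fp8_max).clamp(min=1e-12)   # per-tile scale
    q = (tiles / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scale.squeeze(1).squeeze(-1)
```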
nemo_rl.models.generation.fp8.process_weights_after_loading(self, layer) → None#

This function is used to process the weights after loading for a Linear layer.

Compared to the original process_weights_after_loading in vLLM, we avoid creating new torch.nn.Parameter objects, because doing so removes the weight_loader attribute, which we need for refit.
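The difference matters because vLLM attaches a weight_loader callable as a Python attribute on each parameter, and rebinding the module attribute to a fresh torch.nn.Parameter drops it. A minimal sketch of the two behaviors, using simplified stand-ins rather than the actual vLLM code:

```python
import torch

quantized = torch.randn(4, 4)  # stand-in for the FP8-processed weight

# Stock behavior (simplified): a fresh Parameter loses custom attributes.
layer = torch.nn.Linear(4, 4, bias=False)
layer.weight.weight_loader = lambda p, w: p.data.copy_(w)  # stand-in loader
layer.weight = torch.nn.Parameter(quantized, requires_grad=False)
print(hasattr(layer.weight, "weight_loader"))  # False: attribute is gone

# Patched behavior: mutate the existing Parameter in place instead.
layer = torch.nn.Linear(4, 4, bias=False)
layer.weight.weight_loader = lambda p, w: p.data.copy_(w)
layer.weight.data = quantized
print(hasattr(layer.weight, "weight_loader"))  # True: attribute survives
```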

nemo_rl.models.generation.fp8.process_weights_after_loading_moe(self, layer) → None#

This function is used to process the weights after loading for a FusedMoE layer.

Compared to the original process_weights_after_loading in vLLM, we avoid creating new torch.nn.Parameter objects, because doing so removes the weight_loader attribute, which we need for refit.

nemo_rl.models.generation.fp8.process_weights_after_loading_kv(self, layer) → None#

Modified version of BaseKVCacheMethod.process_weights_after_loading.

Does not delete the k_scale, v_scale, q_scale, and prob_scale parameters, so they can be updated dynamically during refit.
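Because the scale parameters are kept, a refit can copy freshly calibrated scales into the live engine. A hedged sketch of that update, using a plain dict as a stand-in for the model's named parameters (the loop is illustrative, not the actual refit path):

```python
import torch
from nemo_rl.models.generation.fp8 import convert_calibration_to_vllm_format

calib = {"layer_0": {"q_scale": 1.0, "k_scale": 2.0, "v_scale": 3.0}}
vllm_scales = convert_calibration_to_vllm_format(calib)

# Stand-in for dict(model.named_parameters()) of a running vLLM model;
# the scale parameters are still present because
# process_weights_after_loading_kv does not delete them.
params = {name: torch.nn.Parameter(torch.ones(())) for name in vllm_scales}

for name, scale in vllm_scales.items():
    params[name].data.fill_(scale)  # push the new calibrated scale in place
```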

nemo_rl.models.generation.fp8._per_token_group_quant_fp8(
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
y_num_columns,
y_row_stride,
eps,
fp8_min,
fp8_max,
BLOCK: vllm.triton_utils.tl.constexpr,
)#
nemo_rl.models.generation.fp8._per_token_group_quant_fp8_colmajor(
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
y_num_columns,
y_row_stride,
y_s_col_stride,
eps,
fp8_min,
fp8_max,
BLOCK: vllm.triton_utils.tl.constexpr,
)#
nemo_rl.models.generation.fp8.per_token_group_quant_fp8(
*args,
**kwargs,
) → tuple[torch.Tensor, torch.Tensor]#
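per_token_group_quant_fp8 wraps the two Triton kernels above. As a hedged pure-PyTorch reference for what per-token-group quantization computes, one scale per contiguous group_size slice of the last dimension (the helper below is hypothetical, mirroring the kernels' group_size/eps/fp8_min/fp8_max parameters rather than reproducing their exact numerics):

```python
import torch

def per_token_group_quant_reference(x: torch.Tensor, group_size: int,
                                    eps: float = 1e-10):
    """Hypothetical reference: one FP8 scale per contiguous group_size
    slice of the last dimension of x."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    assert x.shape[-1] % group_size == 0
    groups = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / finfo.max).clamp(min=eps)   # per-group scale
    q = (groups / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape_as(x), scale.squeeze(-1)
```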