nemo_rl.models.generation.fp8#

Module Contents#

Classes#

Functions#

Data#

API#

nemo_rl.models.generation.fp8.FP8_BLOCK_QUANT_KWARGS#

None
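
`FP8_BLOCK_QUANT_KWARGS` collects the keyword arguments used to request blockwise FP8 quantization. Its actual value is not reproduced here; purely as an illustration (keys and values are assumptions, not taken from this module), a Hugging Face/vLLM-style FP8 block-quant config has roughly this shape:

```python
# Hypothetical illustration only: the real FP8_BLOCK_QUANT_KWARGS is defined
# in the module and may differ. This mirrors a typical FP8 quantization config
# with dynamic activation scales and 128x128 weight blocks.
example_fp8_block_quant_kwargs = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],
}
```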

class nemo_rl.models.generation.fp8.FP8Config#
use_weight_pow2_scale: bool#

False

use_activation_pow2_scale: bool#

False

num_first_layers_in_bf16: int#

0

num_last_layers_in_bf16: int#

0

model_parallel_size: int#

None
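
A minimal construction sketch, assuming `FP8Config` is a dataclass that accepts the fields above as keyword arguments (the values shown are examples, not recommendations):

```python
from nemo_rl.models.generation.fp8 import FP8Config

# Sketch: power-of-two weight scales, keep the first and last two transformer
# layers in bf16, and record the engine's model-parallel size.
cfg = FP8Config(
    use_weight_pow2_scale=True,
    use_activation_pow2_scale=False,
    num_first_layers_in_bf16=2,
    num_last_layers_in_bf16=2,
    model_parallel_size=8,
)
```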

class nemo_rl.models.generation.fp8.FP8State#
seen_params: set#

‘field(…)’

fp8_param_names: set#

‘field(…)’

vllm_patches: list#

‘field(…)’
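
The `field(…)` defaults indicate mutable default factories, so each `FP8State` instance gets its own containers. The declaration pattern looks roughly like this (illustrative sketch, not the module's source):

```python
from dataclasses import dataclass, field

@dataclass
class FP8StateSketch:
    # Mutable defaults must use default_factory so instances do not share
    # one set/list object.
    seen_params: set = field(default_factory=set)
    fp8_param_names: set = field(default_factory=set)
    vllm_patches: list = field(default_factory=list)
```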

nemo_rl.models.generation.fp8.global_fp8_config: nemo_rl.models.generation.fp8.FP8Config#

None

nemo_rl.models.generation.fp8.fp8_state: nemo_rl.models.generation.fp8.FP8State#

‘FP8State(…)’

nemo_rl.models.generation.fp8.fp8_patches_applied#

False

nemo_rl.models.generation.fp8.original_run_engine_core#

None

nemo_rl.models.generation.fp8.original_init#

None

nemo_rl.models.generation.fp8.my_init(*args, **kwargs)#
nemo_rl.models.generation.fp8.my_run_engine_core(*args, **kwargs)#
nemo_rl.models.generation.fp8.monkey_patch_vllm_ray_executor(fp8_config)#
nemo_rl.models.generation.fp8.apply_fp8_patches(self, fp8_config)#
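
`my_init` and `my_run_engine_core`, together with `original_init`, `original_run_engine_core`, and `fp8_patches_applied` above, follow the usual save-and-wrap monkey-patching pattern: remember the original callable, install a wrapper that delegates to it, and guard patching with a flag so it runs only once. A generic sketch of that pattern with a dummy target class (the wrapper body and patch target are assumptions, not this module's actual logic):

```python
class Engine:
    """Stand-in for the vLLM object whose __init__ gets patched."""

    def __init__(self):
        self.ready = True


patches_applied = False
saved_init = None


def patched_init(*args, **kwargs):
    # FP8-specific setup would happen here, then delegate to the original.
    return saved_init(*args, **kwargs)


def apply_patches_once():
    global patches_applied, saved_init
    if patches_applied:  # idempotent: never double-wrap
        return
    saved_init = Engine.__init__
    Engine.__init__ = patched_init
    patches_applied = True


apply_patches_once()
Engine()  # construction now routes through patched_init
```
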
nemo_rl.models.generation.fp8.init_fp8(vllm_cfg, model_name, model_parallel_size)#
nemo_rl.models.generation.fp8.is_fp8_model(vllm_config)#
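
A hedged sketch of what an FP8-model check can look like; the attribute path below (`model_config.quantization`) is an assumption about vLLM's config object, not necessarily what `is_fp8_model` inspects:

```python
def is_fp8_model_sketch(vllm_config) -> bool:
    # Assumed layout: vLLM's VllmConfig carries a model_config whose
    # `quantization` field names the quant method ("fp8", "awq", ...).
    model_config = getattr(vllm_config, "model_config", None)
    return getattr(model_config, "quantization", None) == "fp8"
```
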
nemo_rl.models.generation.fp8._get_params_in_layers(param_names, layers)#
nemo_rl.models.generation.fp8._get_module_from_param_name(model, name: str)#
nemo_rl.models.generation.fp8._is_fp8_weight(name, model)#
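
`_get_module_from_param_name` resolves a dotted parameter name back to its owning module. A common implementation of that lookup, shown as a sketch (the real helper may handle suffixes such as scale parameters differently):

```python
import torch.nn as nn


def get_module_from_param_name_sketch(model: nn.Module, name: str) -> nn.Module:
    # "model.layers.0.mlp.gate_proj.weight" -> walk to ...gate_proj, dropping
    # the final attribute ("weight"), which is the parameter itself.
    module = model
    for part in name.split(".")[:-1]:
        module = getattr(module, part)
    return module
```
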
nemo_rl.models.generation.fp8.load_weights(weights, model_runner)#
nemo_rl.models.generation.fp8.cast_tensor_to_fp8_blockwise(data_hp, weight_block_size)#
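
`cast_tensor_to_fp8_blockwise` quantizes a high-precision tensor into FP8 blocks with one scale per block. A self-contained reference sketch of blockwise FP8 casting using plain per-block absmax scaling (the module's implementation may differ, e.g. power-of-two scales via `use_weight_pow2_scale` or padding for ragged blocks):

```python
import torch


def cast_to_fp8_blockwise_sketch(data_hp: torch.Tensor, weight_block_size):
    """Quantize a 2D tensor to float8_e4m3fn in (bm x bn) blocks.

    Returns the FP8 tensor and one float32 scale per block. Assumes both
    dimensions are divisible by the block size.
    """
    bm, bn = weight_block_size
    rows, cols = data_hp.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # View as (row_blocks, bm, col_blocks, bn) so each block can be reduced.
    blocks = data_hp.reshape(rows // bm, bm, cols // bn, bn)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scales = amax.float() / fp8_max

    q = (blocks / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scales.reshape(rows // bm, cols // bn)
```
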
nemo_rl.models.generation.fp8.process_weights_after_loading(self, layer) -> None#
nemo_rl.models.generation.fp8._per_token_group_quant_fp8(
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
y_num_columns,
y_row_stride,
eps,
fp8_min,
fp8_max,
BLOCK: vllm.triton_utils.tl.constexpr,
)#
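
`_per_token_group_quant_fp8` is a Triton kernel that quantizes each contiguous group of `group_size` elements in a row to FP8, producing one scale per group; `eps` floors the per-group max and `fp8_min`/`fp8_max` clamp the result. A plain-PyTorch reference of the same per-token-group scheme, as a sketch (row-major scales; the kernel's exact rounding and layout may differ):

```python
import torch


def per_token_group_quant_fp8_reference(y: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Reference per-token-group FP8 quantization: one scale per group."""
    fp8 = torch.finfo(torch.float8_e4m3fn)
    rows, cols = y.shape
    groups = y.reshape(rows, cols // group_size, group_size)

    # Per-(token, group) scale: absmax / fp8_max, floored by eps.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).float()
    scales = amax / fp8.max

    q = (groups / scales).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scales.squeeze(-1)
```
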
nemo_rl.models.generation.fp8._per_token_group_quant_fp8_colmajor(
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
y_num_columns,
y_row_stride,
y_s_col_stride,
eps,
fp8_min,
fp8_max,
BLOCK: vllm.triton_utils.tl.constexpr,
)#
nemo_rl.models.generation.fp8.per_token_group_quant_fp8(
*args,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]#
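
A hedged usage sketch for the public wrapper, assuming it keeps the upstream vLLM call shape `(x, group_size)` (the override itself only exposes `*args, **kwargs`, so the argument names here are an assumption):

```python
import torch

from nemo_rl.models.generation.fp8 import per_token_group_quant_fp8

x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")

# Assumed call shape: quantize each row in groups of 128 columns, getting back
# the FP8 activation tensor and one scale per (token, group).
x_fp8, x_scales = per_token_group_quant_fp8(x, 128)
print(x_fp8.dtype, x_scales.shape)
```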