nemo_rl.models.generation.fp8#
Module Contents#
Classes#
Functions#
Data#
API#
- nemo_rl.models.generation.fp8.FP8_BLOCK_QUANT_KWARGS#
None
- class nemo_rl.models.generation.fp8.FP8Config#
- use_weight_pow2_scale: bool#
False
- use_activation_pow2_scale: bool#
False
- num_first_layers_in_bf16: int#
0
- num_last_layers_in_bf16: int#
0
- model_parallel_size: int#
None
- class nemo_rl.models.generation.fp8.FP8State#
- seen_params: set#
‘field(…)’
- fp8_param_names: set#
‘field(…)’
- vllm_patches: list#
‘field(…)’
- nemo_rl.models.generation.fp8.global_fp8_config: nemo_rl.models.generation.fp8.FP8Config#
None
- nemo_rl.models.generation.fp8.fp8_state: nemo_rl.models.generation.fp8.FP8State#
‘FP8State(…)’
- nemo_rl.models.generation.fp8.fp8_patches_applied#
False
- nemo_rl.models.generation.fp8.original_run_engine_core#
None
- nemo_rl.models.generation.fp8.original_init#
None
- nemo_rl.models.generation.fp8.my_init(*args, **kwargs)#
- nemo_rl.models.generation.fp8.my_run_engine_core(*args, **kwargs)#
- nemo_rl.models.generation.fp8.monkey_patch_vllm_ray_executor(fp8_config)#
- nemo_rl.models.generation.fp8.apply_fp8_patches(self, fp8_config)#
- nemo_rl.models.generation.fp8.init_fp8(vllm_cfg, model_name, model_parallel_size)#
- nemo_rl.models.generation.fp8.is_fp8_model(vllm_config)#
- nemo_rl.models.generation.fp8._get_params_in_layers(param_names, layers)#
- nemo_rl.models.generation.fp8._get_module_from_param_name(model, name: str)#
- nemo_rl.models.generation.fp8._is_fp8_weight(name, model)#
- nemo_rl.models.generation.fp8.load_weights(weights, model_runner)#
- nemo_rl.models.generation.fp8.cast_tensor_to_fp8_blockwise(data_hp, weight_block_size)#
- nemo_rl.models.generation.fp8.process_weights_after_loading(self, layer) → None#
- nemo_rl.models.generation.fp8._per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    BLOCK: vllm.triton_utils.tl.constexpr,
  )#
- nemo_rl.models.generation.fp8._per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    BLOCK: vllm.triton_utils.tl.constexpr,
  )#
- nemo_rl.models.generation.fp8.per_token_group_quant_fp8(
- *args,
- **kwargs,