core.inference.utils#
Module Contents#
Classes#
InferenceMode | Process-wide flag indicating whether an inference engine is currently using the model.
Counter | A simple counter class
Functions#
device_memory_summary | One-line GPU memory summary for torch_memory_saver logging.
get_attention_mask | Constructs an attention mask given the input sequence length.
_init_moe_expert_cache | Initialize the cache of MoE layers once
set_moe_metadata_sync | Set _runs_metadata_sync on inference dispatchers.
set_decode_expert_padding | Toggle MoE drop-and-pad for decode.
check_flashinfer_jit_cache_installed | Verify that the flashinfer-jit-cache package is installed.
tensor_swap | Swap x[src_idxs] and x[dst_idxs]
await_process_call | Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure.
Data#
API#
- class core.inference.utils.InferenceMode#
Process-wide flag indicating whether an inference engine is currently using the model.
Modules that need to distinguish between inference and non-inference (e.g. training, RL logprobs) paths should read InferenceMode.is_active() rather than relying on self.training, torch.is_grad_enabled(), or inference_context is not None.
- _is_active: bool#
False
- classmethod is_active() bool#
Return True while an inference engine is currently using the model.
- classmethod set_active() None#
Mark the inference engine as active. Idempotent.
- classmethod unset_active() None#
Mark the inference engine as inactive. Idempotent.
- classmethod active()#
Context manager: set the flag for the duration of the with block.
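The flag semantics described above can be illustrated with a small stand-in class. This is a minimal sketch, not the library's implementation: the name InferenceModeSketch and the helper forward_path are hypothetical, and the choice to unset the flag on exit (rather than restore a previous value) is an assumption.

```python
from contextlib import contextmanager


class InferenceModeSketch:
    """Hypothetical stand-in for a process-wide inference flag."""

    _is_active: bool = False

    @classmethod
    def is_active(cls) -> bool:
        return cls._is_active

    @classmethod
    def set_active(cls) -> None:
        # Idempotent: setting an already-set flag changes nothing.
        cls._is_active = True

    @classmethod
    def unset_active(cls) -> None:
        # Idempotent: clearing an already-cleared flag changes nothing.
        cls._is_active = False

    @classmethod
    @contextmanager
    def active(cls):
        # Assumption: the flag is cleared on exit, even if the body raises.
        cls.set_active()
        try:
            yield
        finally:
            cls.unset_active()


def forward_path() -> str:
    """Hypothetical module code that branches on the flag."""
    return "inference" if InferenceModeSketch.is_active() else "non-inference"


with InferenceModeSketch.active():
    inside = forward_path()
outside = forward_path()
```

A module that reads the flag this way does not need to inspect self.training or gradient state to decide which path to take.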
- core.inference.utils.device_memory_summary() str#
One-line GPU memory summary for torch_memory_saver logging.
- class core.inference.utils.Counter(start: int = 0)#
A simple counter class.
This class is responsible for assigning request IDs to incoming requests.
Initialization
- __next__() int#
- reset() None#
Reset counter
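A counter with this interface could look like the sketch below. The class name CounterSketch is hypothetical, and resetting to 0 (rather than to the start value) is an assumption, since this page does not state which value reset() restores.

```python
class CounterSketch:
    """Hypothetical counter that hands out increasing integer request IDs."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        # Return the current value, then advance.
        current = self.counter
        self.counter += 1
        return current

    def reset(self) -> None:
        # Assumption: reset returns the counter to 0.
        self.counter = 0


request_ids = CounterSketch()
first_id = next(request_ids)
second_id = next(request_ids)
```

Because the class defines __next__, the built-in next() can be used to obtain each new request ID.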
- core.inference.utils.get_attention_mask(seq_length: int) torch.Tensor#
Constructs an attention mask given the input sequence length.
- core.inference.utils.moe_layer_cache#
None
- core.inference.utils._moe_metadata_sync_initialized#
False
- core.inference.utils._init_moe_expert_cache(model)#
Initialize the cache of MoE layers once
- core.inference.utils.set_moe_metadata_sync(model) None#
Set _runs_metadata_sync on inference dispatchers.
Exactly one dispatcher per model — the first MoE layer — fires update_metadata each step. All subsequent layers skip it to avoid redundant collective calls. Must be called once after the model is built and put into eval mode.
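The "first layer only" rule can be sketched with plain Python objects. Everything here other than the _runs_metadata_sync attribute name is hypothetical: the FakeDispatcher and FakeMoELayer classes, the token_dispatcher attribute, and the mark_first_dispatcher helper are illustrative assumptions, not the library's types.

```python
class FakeDispatcher:
    """Hypothetical dispatcher carrying the per-layer sync flag."""

    def __init__(self) -> None:
        self._runs_metadata_sync = False


class FakeMoELayer:
    """Hypothetical MoE layer that owns one dispatcher."""

    def __init__(self) -> None:
        self.token_dispatcher = FakeDispatcher()


def mark_first_dispatcher(moe_layers) -> None:
    # Only the first MoE layer's dispatcher runs the metadata sync;
    # every later layer skips it.
    for index, layer in enumerate(moe_layers):
        layer.token_dispatcher._runs_metadata_sync = index == 0


layers = [FakeMoELayer() for _ in range(4)]
mark_first_dispatcher(layers)
flags = [layer.token_dispatcher._runs_metadata_sync for layer in layers]
```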
- core.inference.utils.set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int = None)#
Toggle MoE drop-and-pad for decode.
Applies capacity_factor to the router and all token dispatchers so decode runs with fixed shapes (CUDA graph-safe). When enabling (set_to=True), clears variable-size dispatcher metadata from prefill. For no-drop decode, use capacity_factor = num_moe_experts / moe_router_topk.
Args:
model: Module containing MoE layers.
set_to: Enable (True) or disable (False) padding.
capacity_factor: Capacity scaling shared by router and dispatchers.
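The no-drop setting can be checked with a little arithmetic. The sketch below assumes the common capacity definition for capacity-based MoE routing, ceil(num_tokens * topk * capacity_factor / num_experts); this page does not state the formula, so the expert_capacity helper and the example sizes are assumptions.

```python
import math


def expert_capacity(num_tokens: int, topk: int, num_experts: int,
                    capacity_factor: float) -> int:
    # Assumed definition: per-expert slots scale with the average load
    # (num_tokens * topk / num_experts) times the capacity factor.
    return math.ceil(num_tokens * topk * capacity_factor / num_experts)


num_moe_experts = 8
moe_router_topk = 2
num_tokens = 16

# No-drop choice from the description above.
capacity_factor = num_moe_experts / moe_router_topk

capacity = expert_capacity(num_tokens, moe_router_topk,
                           num_moe_experts, capacity_factor)
```

Under this assumed formula, the per-expert capacity equals num_tokens, so even if every token routes to the same expert no token is dropped, while the padded shape stays fixed.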
- core.inference.utils.check_flashinfer_jit_cache_installed(log_version: bool = False)#
Verify that the flashinfer-jit-cache package is installed.
The flashinfer-jit-cache package provides pre-compiled CUTLASS fused MoE kernels so they don’t need to be JIT-compiled at runtime. This avoids a multi-minute compilation step during CUDA graph warmup.
- Raises:
RuntimeError – If flashinfer-jit-cache is not installed and CUDA version is 12 or 13.
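One way to detect whether a distribution is installed is the standard-library importlib.metadata module. This is a sketch of that general technique only: the installed_version helper is hypothetical, and how the real function detects the package or determines the CUDA version is not shown on this page.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Look up the package named in the description above.
jit_cache_version = installed_version("flashinfer-jit-cache")
if jit_cache_version is None:
    status = "flashinfer-jit-cache is not installed"
else:
    status = f"flashinfer-jit-cache {jit_cache_version} is installed"
```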
- core.inference.utils.tensor_swap(x, src_idxs, dst_idxs)#
Swap x[src_idxs] and x[dst_idxs]
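The swap semantics can be shown on a plain Python list. This is an illustration only: the list_swap helper is hypothetical, it operates on lists rather than tensors, and it assumes the two index sets do not overlap.

```python
def list_swap(x, src_idxs, dst_idxs) -> None:
    """Swap x[src_idxs] and x[dst_idxs] in place (list-based illustration)."""
    # Read both sides before writing so neither read sees a partial swap.
    src_vals = [x[i] for i in src_idxs]
    dst_vals = [x[i] for i in dst_idxs]
    for i, value in zip(src_idxs, dst_vals):
        x[i] = value
    for i, value in zip(dst_idxs, src_vals):
        x[i] = value


values = [10, 11, 12, 13]
list_swap(values, [0, 1], [2, 3])
```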
- async core.inference.utils.await_process_call(call, process: multiprocessing.Process, timeout: float = 1.0)#
Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure.
Note that the timeout in this function is only for checking process liveness, so it should be set to a relatively high value. The only cost of a high timeout is that an error is raised slightly later. The timeout has no effect on the event-waiting itself, only on process failure detection.
- Parameters:
call – The multiprocessing callable to wait on.
process – The process to monitor for failure.
timeout – The timeout for each wait iteration in seconds.
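The polling pattern described above can be sketched with asyncio. This is not the library's implementation: the name await_call_sketch is hypothetical, and the sketch assumes that call(timeout) blocks for at most timeout seconds and returns a truthy value once resolved, as multiprocessing.Event.wait does.

```python
import asyncio


async def await_call_sketch(call, process, timeout: float = 1.0) -> None:
    """Poll `call` until it resolves, aborting if `process` has died.

    Assumption: `call(timeout)` blocks for at most `timeout` seconds and
    returns a truthy value once resolved (like multiprocessing.Event.wait).
    """
    while True:
        # Run the blocking wait in a worker thread so the event loop stays free.
        resolved = await asyncio.to_thread(call, timeout)
        if resolved:
            return
        # The wait timed out: check liveness before waiting again.
        if not process.is_alive():
            raise RuntimeError("process exited before the call resolved")
```

In this sketch the timeout only bounds how long one wait iteration lasts, so it controls how quickly a dead process is noticed and does not limit the total time spent waiting for the call.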