core.inference.utils#

Module Contents#

Classes#

InferenceMode

Process-wide flag indicating whether an inference engine is currently using the model.

Counter

A simple counter class.

Functions#

device_memory_summary

One-line GPU memory summary for torch_memory_saver logging.

get_attention_mask

Constructs an attention mask given the input sequence length.

_init_moe_expert_cache

Initialize the cache of MoE layers once.

set_moe_metadata_sync

Set _runs_metadata_sync on inference dispatchers.

set_decode_expert_padding

Toggle MoE drop-and-pad for decode.

check_flashinfer_jit_cache_installed

Verify that the flashinfer-jit-cache package is installed.

tensor_swap

Swap x[src_idxs] and x[dst_idxs].

await_process_call

Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure.

Data#

moe_layer_cache

_moe_metadata_sync_initialized

API#

class core.inference.utils.InferenceMode#

Process-wide flag indicating whether an inference engine is currently using the model.

Modules that need to distinguish between inference and non-inference (e.g. training, RL logprobs) paths should read InferenceMode.is_active() rather than relying on self.training, torch.is_grad_enabled(), or inference_context is not None.

_is_active: bool#

False

classmethod is_active() bool#

Return True while an inference engine is currently using the model.

classmethod set_active() None#

Mark the inference engine as active. Idempotent.

classmethod unset_active() None#

Mark the inference engine as inactive. Idempotent.

classmethod active()#

Context manager: set the flag for the duration of the with block.
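The interface above can be illustrated with a small stand-in class. This is only a sketch of the documented behavior, not the library's code: InferenceModeSketch and forward_path are hypothetical names, and clearing the flag on exit from the with block (rather than restoring a previous value) is an assumption.

```python
from contextlib import contextmanager


class InferenceModeSketch:
    """Hypothetical stand-in that mirrors the documented InferenceMode interface."""

    _is_active: bool = False

    @classmethod
    def is_active(cls) -> bool:
        return cls._is_active

    @classmethod
    def set_active(cls) -> None:
        cls._is_active = True  # setting twice has the same effect as once

    @classmethod
    def unset_active(cls) -> None:
        cls._is_active = False  # clearing twice has the same effect as once

    @classmethod
    @contextmanager
    def active(cls):
        # Assumption: the flag is cleared when the block exits, even on error.
        cls.set_active()
        try:
            yield
        finally:
            cls.unset_active()


def forward_path() -> str:
    """Hypothetical module code that branches on the process-wide flag."""
    return "inference" if InferenceModeSketch.is_active() else "training"


if __name__ == "__main__":
    print(forward_path())
    with InferenceModeSketch.active():
        print(forward_path())
    print(forward_path())
```

Reading a single process-wide flag keeps the branch independent of self.training and of gradient state, which is the distinction the description above draws.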

core.inference.utils.device_memory_summary() str#

One-line GPU memory summary for torch_memory_saver logging.

class core.inference.utils.Counter(start: int = 0)#

A simple counter class.

This class is responsible for assigning request IDs to incoming requests.

Initialization

__next__() int#

reset() None#

Reset the counter.
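A counter with this interface might look like the sketch below. CounterSketch is a hypothetical name, and two behaviors are assumptions because the reference does not state them: next() returns the current value before incrementing, and reset() returns to the start value.

```python
class CounterSketch:
    """Hypothetical request-id counter with the documented __next__/reset shape."""

    def __init__(self, start: int = 0) -> None:
        self.start = start
        self.counter = start

    def __next__(self) -> int:
        # Assumption: hand out the current value, then advance.
        value = self.counter
        self.counter += 1
        return value

    def reset(self) -> None:
        # Assumption: go back to the initial value.
        self.counter = self.start


if __name__ == "__main__":
    request_ids = CounterSketch()
    print(next(request_ids), next(request_ids))
    request_ids.reset()
    print(next(request_ids))
```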

core.inference.utils.get_attention_mask(seq_length: int) torch.Tensor#

Constructs an attention mask given the input sequence length.

core.inference.utils.moe_layer_cache#

None

core.inference.utils._moe_metadata_sync_initialized#

False

core.inference.utils._init_moe_expert_cache(model)#

Initialize the cache of MoE layers once.

core.inference.utils.set_moe_metadata_sync(model) None#

Set _runs_metadata_sync on inference dispatchers.

Exactly one dispatcher per model — the first MoE layer — fires update_metadata each step. All subsequent layers skip it to avoid redundant collective calls. Must be called once after the model is built and put into eval mode.
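The "first layer only" rule can be sketched with plain Python objects. DispatcherSketch and mark_first_dispatcher are hypothetical; only the attribute name _runs_metadata_sync comes from the description above, and the counter stands in for the update_metadata collective.

```python
class DispatcherSketch:
    """Hypothetical dispatcher; sync_calls counts simulated metadata syncs."""

    def __init__(self) -> None:
        self._runs_metadata_sync = False
        self.sync_calls = 0

    def step(self) -> None:
        if self._runs_metadata_sync:
            self.sync_calls += 1  # stands in for update_metadata


def mark_first_dispatcher(dispatchers) -> None:
    """Enable the sync flag on the first dispatcher and disable it on the rest."""
    for index, dispatcher in enumerate(dispatchers):
        dispatcher._runs_metadata_sync = index == 0


if __name__ == "__main__":
    layers = [DispatcherSketch() for _ in range(3)]
    mark_first_dispatcher(layers)
    for layer in layers:
        layer.step()
    print([layer.sync_calls for layer in layers])
```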

core.inference.utils.set_decode_expert_padding(
model,
set_to: bool = False,
capacity_factor: int = None,
)#

Toggle MoE drop-and-pad for decode.

Applies capacity_factor to the router and all token dispatchers so decode runs with fixed shapes (CUDA graph-safe). When enabling (set_to=True), clears variable-size dispatcher metadata from prefill. For no-drop decode, use capacity_factor = num_moe_experts / moe_router_topk.

Args:

  • model: Module containing MoE layers.

  • set_to: Enable (True) or disable (False) padding.

  • capacity_factor: Capacity scaling shared by router and dispatchers.
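The no-drop setting can be checked with a little arithmetic. The sketch below assumes a common drop-and-pad capacity rule, ceil(num_tokens * topk * capacity_factor / num_experts) slots per expert; that formula and the helper name expert_capacity are assumptions, not taken from this reference. Under that rule, capacity_factor = num_moe_experts / moe_router_topk gives every expert room for all tokens, so nothing is dropped.

```python
import math


def expert_capacity(num_tokens: int, num_experts: int, topk: int,
                    capacity_factor: float) -> int:
    """Per-expert slot count under the assumed drop-and-pad capacity rule."""
    return math.ceil(num_tokens * topk * capacity_factor / num_experts)


if __name__ == "__main__":
    num_tokens, num_experts, topk = 16, 8, 2

    # capacity_factor = 1.0: each expert gets only its average share of slots.
    print(expert_capacity(num_tokens, num_experts, topk, 1.0))

    # capacity_factor = num_experts / topk: each expert can hold every token.
    print(expert_capacity(num_tokens, num_experts, topk, num_experts / topk))
```

Because the capacity depends only on the token count and fixed model settings, tensor shapes stay constant across decode steps, which is what makes the path safe for CUDA graphs.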

core.inference.utils.check_flashinfer_jit_cache_installed(log_version: bool = False)#

Verify that the flashinfer-jit-cache package is installed.

The flashinfer-jit-cache package provides pre-compiled CUTLASS fused MoE kernels so they don’t need to be JIT-compiled at runtime. This avoids a multi-minute compilation step during CUDA graph warmup.

Raises:

RuntimeError – If flashinfer-jit-cache is not installed and CUDA version is 12 or 13.
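One way to detect an installed distribution is the standard library's importlib.metadata. The sketch below is not the function's actual implementation: installed_version and require_package are hypothetical helpers, and the CUDA version condition described above is left out.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def require_package(package: str) -> str:
    """Raise RuntimeError when the distribution is missing (CUDA check omitted)."""
    version = installed_version(package)
    if version is None:
        raise RuntimeError(f"{package} is not installed")
    return version


if __name__ == "__main__":
    print(installed_version("flashinfer-jit-cache"))
```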

core.inference.utils.tensor_swap(x, src_idxs, dst_idxs)#

Swap x[src_idxs] and x[dst_idxs].
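The swap semantics can be shown with a plain Python list standing in for the tensor. swap_entries is a hypothetical helper, and it assumes the two index lists have equal length and do not overlap. Both sides are read before either is written, so the second write does not see already-overwritten values.

```python
def swap_entries(x: list, src_idxs: list, dst_idxs: list) -> None:
    """Exchange x[src_idxs] and x[dst_idxs] in place (disjoint index lists assumed)."""
    # Snapshot both sides first so neither write clobbers values still needed.
    src_vals = [x[i] for i in src_idxs]
    dst_vals = [x[i] for i in dst_idxs]
    for i, value in zip(src_idxs, dst_vals):
        x[i] = value
    for i, value in zip(dst_idxs, src_vals):
        x[i] = value


if __name__ == "__main__":
    rows = [10, 20, 30, 40]
    swap_entries(rows, [0, 1], [2, 3])
    print(rows)
```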

async core.inference.utils.await_process_call(
call,
process: multiprocessing.Process,
timeout: float = 1.0,
)#

Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure.

The timeout in this function only controls how often process liveness is checked, so it should be set relatively high. The only cost of a high timeout is that a process failure is reported slightly later. The timeout has no effect on the event-waiting itself, only on process failure detection.

Parameters:
  • call – The multiprocessing callable to wait on.

  • process – The process to monitor for failure.

  • timeout – The timeout for each wait iteration in seconds.
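The wait-and-check loop described above can be sketched with asyncio. This is an illustration under assumptions, not the function's implementation: wait_for_call_sketch is a hypothetical name, the blocking call is assumed to take no arguments and is run in a worker thread, and process only needs an is_alive() method as multiprocessing.Process provides.

```python
import asyncio


async def wait_for_call_sketch(call, process, timeout: float = 1.0):
    """Await a blocking call, checking process liveness every `timeout` seconds."""
    # Run the blocking callable in a worker thread so the event loop stays free.
    task = asyncio.create_task(asyncio.to_thread(call))
    while True:
        done, _pending = await asyncio.wait({task}, timeout=timeout)
        if done:
            # Returns the call's result, or re-raises its exception.
            return task.result()
        if not process.is_alive():
            # Stop awaiting; the worker thread itself cannot be interrupted.
            task.cancel()
            raise RuntimeError("monitored process exited before the call resolved")


if __name__ == "__main__":
    class AlwaysAlive:
        """Hypothetical stand-in for a healthy multiprocessing.Process."""

        def is_alive(self) -> bool:
            return True

    print(asyncio.run(wait_for_call_sketch(lambda: "ready", AlwaysAlive(), 0.1)))
```

In this sketch the timeout only bounds how long a dead process can go unnoticed. A call that resolves early returns immediately, regardless of the timeout value.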