core.inference.utils
Module Contents
Classes
Counter | A simple counter class.
Functions
get_attention_mask | Constructs an attention mask given the input sequence length.
_init_moe_expert_cache | Initialize the cache of MoE layers once.
set_decode_expert_padding | Toggle MoE drop-and-pad for decode.
tensor_swap | Swap x[src_idxs] and x[dst_idxs].
await_process_event | Repeatedly wait for a multiprocessing event to be set, aborting upon process failure.
Data
moe_layer_cache
API
- class core.inference.utils.Counter(start: int = 0)
A simple counter class.
This class is responsible for assigning request IDs to incoming requests.
- __next__() -> int
- reset() -> None
Reset the counter.
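A minimal sketch of how such a request-ID counter might behave. It assumes that next() returns the current value and then advances, and that reset() rewinds to the start value; neither detail is stated above, and CounterSketch is a hypothetical name rather than this module's class.

```python
class CounterSketch:
    """Hypothetical stand-in for Counter: hands out consecutive integer IDs."""

    def __init__(self, start: int = 0) -> None:
        self._start = start
        self._value = start

    def __next__(self) -> int:
        # Assumed semantics: return the current value, then advance by one.
        current = self._value
        self._value += 1
        return current

    def reset(self) -> None:
        # Assumption: rewind to the start value (the real class might rewind to 0).
        self._value = self._start


request_ids = CounterSketch()
first, second = next(request_ids), next(request_ids)
print(first, second)  # 0 1
```

Defining __next__ is enough for the built-in next() to work, which matches the two members the API lists.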
- core.inference.utils.get_attention_mask(seq_length: int) -> torch.Tensor
Constructs an attention mask given the input sequence length.
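The docstring does not say what kind of mask is built. A minimal sketch, assuming a causal (lower-triangular) mask of shape [1, 1, seq_length, seq_length] in which True marks positions that must be masked out; the shape, the boolean convention and the name causal_attention_mask_sketch are all assumptions.

```python
import torch


def causal_attention_mask_sketch(seq_length: int) -> torch.Tensor:
    # Lower-triangular ones: query position i may attend to key positions j <= i.
    allowed = torch.tril(torch.ones(seq_length, seq_length))
    # Assumed convention: True = masked out (future positions).
    mask = allowed < 0.5
    # Add broadcastable batch and head dimensions.
    return mask.view(1, 1, seq_length, seq_length)


mask = causal_attention_mask_sketch(3)
print(mask[0, 0].int().tolist())  # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
```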
- core.inference.utils.moe_layer_cache
Value: None
- core.inference.utils._init_moe_expert_cache(model)
Initialize the cache of MoE layers once.
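A module-level moe_layer_cache that starts as None, paired with an initializer that runs "once", suggests a lazily populated, memoized lookup. A minimal sketch of that pattern follows; ToyMoELayer, _moe_layer_cache_sketch and init_moe_cache_sketch are hypothetical, and how the real function recognizes MoE layers is not documented here.

```python
import torch.nn as nn

# Hypothetical stand-in for the module-level cache; None means "not built yet".
_moe_layer_cache_sketch = None


class ToyMoELayer(nn.Module):
    """Hypothetical placeholder for a real MoE layer class."""


def init_moe_cache_sketch(model: nn.Module) -> list:
    global _moe_layer_cache_sketch
    if _moe_layer_cache_sketch is None:
        # Walk the module tree only on the first call; later calls reuse the list.
        _moe_layer_cache_sketch = [
            m for m in model.modules() if isinstance(m, ToyMoELayer)
        ]
    return _moe_layer_cache_sketch


model = nn.Sequential(nn.Linear(4, 4), ToyMoELayer(), ToyMoELayer())
print(len(init_moe_cache_sketch(model)))  # 2
```

Because the cache in this sketch is global, it stays bound to the first model passed in; a later call with a different model returns the original list.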
- core.inference.utils.set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int = None)
Toggle MoE drop-and-pad for decode.
Applies capacity_factor to the router and all token dispatchers so that decode runs with fixed shapes (CUDA graph-safe). When enabling (set_to=True), also clears the variable-size dispatcher metadata left over from prefill. For no-drop decode, use capacity_factor = num_moe_experts / moe_router_topk.
Args:
model: Module containing MoE layers.
set_to: Enable (True) or disable (False) padding.
capacity_factor: Capacity scaling factor shared by the router and the dispatchers.
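The no-drop rule can be checked with a small worked example. This sketch assumes the common Switch/GShard-style capacity formula, capacity = ceil(num_tokens * topk / num_experts * capacity_factor); the helper names are hypothetical, not functions of this module. Substituting capacity_factor = num_experts / topk makes each expert's capacity at least num_tokens, and since a token is routed to any given expert at most once, no expert can overflow.

```python
import math


def expert_capacity(num_tokens: int, num_experts: int, topk: int,
                    capacity_factor: float) -> int:
    """Hypothetical helper: slots per expert under the assumed formula."""
    return math.ceil(num_tokens * topk / num_experts * capacity_factor)


def no_drop_capacity_factor(num_experts: int, topk: int) -> float:
    """capacity_factor = num_moe_experts / moe_router_topk."""
    return num_experts / topk


# Example: 8 experts, top-2 routing, 16 tokens in the decode batch.
cf = no_drop_capacity_factor(num_experts=8, topk=2)
cap = expert_capacity(num_tokens=16, num_experts=8, topk=2, capacity_factor=cf)
print(cf, cap)  # 4.0 16
```

The cost of this guarantee is padding: every expert reserves num_tokens slots even when few tokens are routed to it, which is what keeps the shapes fixed across decode steps.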
- core.inference.utils.tensor_swap(x, src_idxs, dst_idxs)
Swap x[src_idxs] and x[dst_idxs].
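A minimal sketch of an in-place swap in PyTorch, assuming src_idxs and dst_idxs are equal-length, disjoint index tensors along the first dimension (overlapping indices would make the result ill-defined). It relies on advanced indexing returning copies, so both gathers on the right-hand side finish before either scatter writes. tensor_swap_sketch is a hypothetical name; whether the real function works in place is not stated above.

```python
import torch


def tensor_swap_sketch(x: torch.Tensor, src_idxs: torch.Tensor,
                       dst_idxs: torch.Tensor) -> None:
    # The right-hand side gathers x[dst_idxs] and x[src_idxs] as copies first;
    # then the two scatters write them back crosswise, in place.
    x[src_idxs], x[dst_idxs] = x[dst_idxs], x[src_idxs]


x = torch.arange(6)
tensor_swap_sketch(x, torch.tensor([0, 1]), torch.tensor([4, 5]))
print(x.tolist())  # [4, 5, 2, 3, 0, 1]
```

The same call swaps whole rows when x is multi-dimensional, for example rows of a per-request state buffer.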
- async core.inference.utils.await_process_event(event: multiprocessing.Event, process: multiprocessing.Process, timeout: float = 1.0)
Repeatedly wait for a multiprocessing event to be set, aborting upon process failure.
Note that the timeout in this function is used only for checking process liveness, and its value should be set relatively high. The only drawback of a high timeout is that an error is raised slightly later. The timeout has no effect on waiting for the event itself, only on how quickly a process failure is detected.
- Parameters:
event – The multiprocessing event to wait on.
process – The process to monitor for failure.
timeout – The timeout for each wait iteration in seconds.
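The waiting loop described above can be sketched with asyncio, assuming the blocking Event.wait is offloaded to a worker thread via asyncio.to_thread (Python 3.9+). This is a hypothetical re-creation, not the module's code; the function name, the RuntimeError type and its message are assumptions. It shows why the timeout only affects failure detection: Event.wait returns as soon as the event is set, and liveness is checked only after a wait times out.

```python
import asyncio
import multiprocessing as mp


async def await_process_event_sketch(event, process, timeout: float = 1.0) -> None:
    """Hypothetical re-creation: wait for `event`, raising if `process` dies first."""
    while True:
        # Event.wait blocks, so run it in a worker thread to keep the event loop free.
        # It returns True as soon as the event is set, regardless of `timeout`.
        if await asyncio.to_thread(event.wait, timeout):
            return
        # The wait timed out: only now is process liveness checked.
        if not process.is_alive():
            # Re-check to avoid a race where the event was set just before exit.
            if event.is_set():
                return
            raise RuntimeError(
                f"Process {process.pid} exited with code {process.exitcode} "
                "before the event was set"
            )


def _worker(event) -> None:
    event.set()


if __name__ == "__main__":
    ready = mp.Event()
    proc = mp.Process(target=_worker, args=(ready,))
    proc.start()
    asyncio.run(await_process_event_sketch(ready, proc, timeout=5.0))
    proc.join()
    print("event received")
```

With timeout=5.0, a successful set is observed immediately, while a crashed child would be reported after at most about one extra timeout interval.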