aiq.profiler.inference_optimization.experimental.concurrency_spike_analysis#
An enhanced script that:

- Groups workflow events by example_number to build a per-example call tree (no cross-example nesting).
- Tracks concurrency globally across all examples.
- Identifies concurrency "spikes" (concurrency >= a threshold).
- Correlates concurrency spikes with token usage and call metadata.
- Computes average call latency by concurrency level, using midpoint concurrency as an approximation.
- Returns a Pydantic result containing the concurrency distribution, spike intervals, correlation stats, etc., along with a textual report covering the real call count, the active calls in each spike, and more.

Changes from the previous version:

- Shows the actual total number of calls in the dataset.
- Displays the real number of active calls for each spike interval.
- Computes and reports average latency by concurrency level (no visualization).
Functions#
| Function | Description |
|---|---|
| build_call_tree_for_example | Sort events by time, push on *_START, pop on *_END, build stack-based calls for a single example. |
| build_call_tree_per_example | Groups by example_number, builds separate call trees, returns combined list of top-level calls. |
| flatten_calls | DFS to produce a flat list of all calls (including nested). |
| compute_concurrency_distribution | Flatten calls, produce (start, +1)/(end, -1), accumulate total time at each concurrency level. |
| build_concurrency_segments | Return piecewise segments of (start, end, concurrency) across all calls. |
| find_percentile_concurrency | Given concurrency => total_time, find the concurrency level at a given percentile of total time. |
| detect_concurrency_spikes | If concurrency >= threshold, label that segment a 'spike'. |
| find_calls_active_in_interval | Return all calls overlapping [start_t, end_t). |
| correlate_spike_calls | For each spike, gather calls that overlap, compute average prompt_tokens, total_tokens across them. |
| compute_midpoint_concurrency | Approximate the concurrency at the midpoint of a call. |
| average_latency_by_midpoint_concurrency | For each call, find concurrency at midpoint, then bucket durations by concurrency, compute averages. |
| concurrency_spike_analysis | Build per-example call trees, compute the concurrency distribution, detect spikes, correlate with token usage, and return a Pydantic result plus a textual report. |
Module Contents#
- build_call_tree_for_example(example_df: pandas.DataFrame)
Sort events by time, push on *_START, pop on *_END, build stack-based calls for a single example.
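The pairing logic is easiest to see in miniature. Below is a minimal, self-contained sketch of the stack-based approach, using a simplified `CallNode` stand-in for `ConcurrencyCallNode` and plain `(timestamp, event_type)` tuples instead of DataFrame rows; the real function's column names and node fields may differ:

```python
from dataclasses import dataclass, field

@dataclass
class CallNode:
    # Simplified stand-in for ConcurrencyCallNode.
    op: str
    start_time: float
    end_time: float = 0.0
    children: list["CallNode"] = field(default_factory=list)

def build_call_tree(events: list[tuple[float, str]]) -> list[CallNode]:
    """Pair *_START / *_END events with a stack; return top-level calls."""
    stack: list[CallNode] = []
    roots: list[CallNode] = []
    for ts, event_type in sorted(events):          # sort by timestamp
        if event_type.endswith("_START"):
            node = CallNode(op=event_type[: -len("_START")], start_time=ts)
            if stack:
                stack[-1].children.append(node)    # nested inside the open call
            else:
                roots.append(node)                 # no open call => top-level
            stack.append(node)                     # push: call is now open
        elif event_type.endswith("_END") and stack:
            stack.pop().end_time = ts              # pop: close innermost call
    return roots

# An LLM call nested inside a tool call:
roots = build_call_tree([(0.0, "TOOL_START"), (0.1, "LLM_START"),
                         (0.9, "LLM_END"), (1.0, "TOOL_END")])
assert roots[0].children[0].op == "LLM"
```

Because one such stack is built per example_number group, calls from different examples can never nest inside each other.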
- build_call_tree_per_example(df: pandas.DataFrame)
Groups by example_number, builds separate call trees, returns combined list of top-level calls.
- flatten_calls() -> list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode] #
DFS to produce a flat list of all calls (including nested).
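A depth-first flatten over such a tree is only a few lines; this sketch assumes nodes expose a `children` list, as in the stand-in above:

```python
def flatten(roots):
    """Depth-first traversal returning every call, nested ones included."""
    flat = []
    for node in roots:
        flat.append(node)
        flat.extend(flatten(node.children))   # recurse into nested calls
    return flat
```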
- compute_concurrency_distribution() -> dict[int, float] #
Flatten calls, produce (start, +1)/(end, -1), accumulate total time at each concurrency level.
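The (start, +1)/(end, -1) trick is a classic sweep line. A minimal sketch over plain (start, end) intervals, with a hypothetical helper name:

```python
from collections import defaultdict

def concurrency_distribution(intervals: list[tuple[float, float]]) -> dict[int, float]:
    """Map each concurrency level to the total wall-clock time spent at it."""
    # Each call contributes a +1 event at its start and a -1 event at its end.
    # Sorting on (time, delta) processes ends before starts at equal timestamps.
    events = sorted([(s, +1) for s, _ in intervals] +
                    [(e, -1) for _, e in intervals])
    dist: dict[int, float] = defaultdict(float)
    level, prev_t = 0, None
    for t, delta in events:
        if prev_t is not None and level > 0:
            dist[level] += t - prev_t   # time spent at the current level
        level += delta
        prev_t = t
    return dict(dist)

# Two calls overlapping for 0.5s:
print(concurrency_distribution([(0.0, 1.0), (0.5, 1.5)]))  # {1: 1.0, 2: 0.5}
```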
- build_concurrency_segments() -> list[tuple[float, float, int]] #
Return piecewise segments of (start, end, concurrency) across all calls.
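The same sweep can emit explicit segments instead of aggregated times; a sketch:

```python
def concurrency_segments(intervals: list[tuple[float, float]]) -> list[tuple[float, float, int]]:
    """Piecewise-constant (start, end, concurrency) segments across all calls."""
    events = sorted([(s, +1) for s, _ in intervals] +
                    [(e, -1) for _, e in intervals])
    segments, level, prev_t = [], 0, None
    for t, delta in events:
        if prev_t is not None and t > prev_t:
            segments.append((prev_t, t, level))   # close the current segment
        level += delta
        prev_t = t
    return segments

print(concurrency_segments([(0.0, 1.0), (0.5, 1.5)]))
# [(0.0, 0.5, 1), (0.5, 1.0, 2), (1.0, 1.5, 1)]
```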
- find_percentile_concurrency() -> float #
Given a mapping of concurrency => total_time, find the concurrency level at a given percentile of total time.
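Because the distribution is time-weighted, the percentile is taken over time rather than over call counts. A sketch (the real function's interpolation rules may differ):

```python
def percentile_concurrency(dist: dict[int, float], pct: float) -> float:
    """Smallest concurrency level at which the cumulative time share reaches pct%."""
    total = sum(dist.values())
    cumulative = 0.0
    for level in sorted(dist):
        cumulative += dist[level]
        if cumulative >= total * pct / 100.0:
            return float(level)
    return float(max(dist))

print(percentile_concurrency({1: 1.0, 2: 0.5}, 90))  # 2.0
```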
- detect_concurrency_spikes() -> list[aiq.profiler.inference_optimization.data_models.ConcurrencySpikeInfo] #
If concurrency >= threshold, label that segment a 'spike'.
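Given the segments from above, spike detection is a filter; merging contiguous over-threshold segments into one interval is an assumption of this sketch, not necessarily what the real function does:

```python
def detect_spikes(segments: list[tuple[float, float, int]],
                  threshold: int) -> list[tuple[float, float, int]]:
    """Intervals where concurrency >= threshold, contiguous ones merged."""
    spikes: list[tuple[float, float, int]] = []
    for start, end, level in segments:
        if level < threshold:
            continue
        if spikes and spikes[-1][1] == start:          # touches previous spike
            s, _, c = spikes[-1]
            spikes[-1] = (s, end, max(c, level))       # extend it
        else:
            spikes.append((start, end, level))
    return spikes

print(detect_spikes([(0.0, 0.5, 1), (0.5, 1.0, 2), (1.0, 1.5, 1)], threshold=2))
# [(0.5, 1.0, 2)]
```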
- find_calls_active_in_interval(roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode], start_t: float, end_t: float)
Return all calls overlapping [start_t, end_t). Overlap => not (call.end_time <= start_t or call.start_time >= end_t).
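The half-open convention means a call that ends exactly when the window starts does not count as active; for instance:

```python
def overlaps(call_start: float, call_end: float,
             start_t: float, end_t: float) -> bool:
    # Overlap with the half-open window [start_t, end_t).
    return not (call_end <= start_t or call_start >= end_t)

assert overlaps(0.0, 1.0, 0.5, 2.0)       # partial overlap counts
assert not overlaps(0.0, 1.0, 1.0, 2.0)   # touching endpoints do not
```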
- correlate_spike_calls(spikes: list[aiq.profiler.inference_optimization.data_models.ConcurrencySpikeInfo], roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode])
For each spike, gather calls that overlap, compute average prompt_tokens, total_tokens across them.
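A sketch of the per-spike correlation, assuming each call node carries prompt_tokens and total_tokens counters and a spike reduces to a (start, end) window:

```python
def correlate_spike(start_t: float, end_t: float, calls) -> dict:
    """Average token usage across calls active during one spike window."""
    active = [c for c in calls
              if not (c.end_time <= start_t or c.start_time >= end_t)]
    n = max(len(active), 1)                    # avoid division by zero
    return {
        "active_calls": len(active),
        "avg_prompt_tokens": sum(c.prompt_tokens for c in active) / n,
        "avg_total_tokens": sum(c.total_tokens for c in active) / n,
    }
```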
- compute_midpoint_concurrency(n: aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode, segments: list[tuple[float, float, int]])
Approximate the concurrency at the midpoint of this call.
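A sketch of the midpoint lookup over the piecewise segments (a linear scan here; a bisect over segment start times would also work):

```python
def midpoint_concurrency(start_time: float, end_time: float,
                         segments: list[tuple[float, float, int]]) -> int:
    """Concurrency of the segment containing the call's midpoint."""
    mid = (start_time + end_time) / 2.0
    for seg_start, seg_end, level in segments:
        if seg_start <= mid < seg_end:
            return level
    return 0   # midpoint falls outside every segment
```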
- average_latency_by_midpoint_concurrency() -> dict[int, float] #
For each call, find the concurrency at its midpoint, then bucket durations by concurrency and compute the average.
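Bucketing then reduces to a dictionary of durations keyed by midpoint concurrency; a sketch reusing midpoint_concurrency from the previous example:

```python
from collections import defaultdict

def average_latency_by_level(calls, segments) -> dict[int, float]:
    """Average call duration per midpoint-concurrency bucket."""
    buckets: dict[int, list[float]] = defaultdict(list)
    for c in calls:
        level = midpoint_concurrency(c.start_time, c.end_time, segments)
        buckets[level].append(c.end_time - c.start_time)   # call duration
    return {lvl: sum(ds) / len(ds) for lvl, ds in buckets.items()}
```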
- concurrency_spike_analysis(all_steps: list[list[aiq.data_models.intermediate_step.IntermediateStep]], concurrency_spike_threshold: int | None = None)
1. Build per-example call trees (no cross-example nesting).
2. Compute the concurrency distribution and concurrency segments across all calls.
3. Derive concurrency percentiles (p50, p90, p95, p99).
4. If no threshold is provided, pick one automatically (e.g. the ceiling of the p90 concurrency).
5. Detect spikes and gather the calls in those intervals => correlation stats.
6. Compute average latency by concurrency and add it to the report.
7. Return a Pydantic object with everything, plus a textual report (see the usage sketch below).
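Putting it together, a typical invocation might look like the following; the field used to print the report is an assumption, since only the presence of a textual report is documented here:

```python
from aiq.profiler.inference_optimization.experimental.concurrency_spike_analysis import (
    concurrency_spike_analysis,
)

# all_steps: one list of IntermediateStep events per profiled example.
result = concurrency_spike_analysis(all_steps, concurrency_spike_threshold=8)

# Inspect the returned Pydantic object; `textual_report` is a hypothetical
# field name standing in for the report described above.
print(result.textual_report)
```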