aiq.profiler.inference_optimization.experimental.concurrency_spike_analysis#

An enhanced script that:

  1. Groups workflow events by example_number to build a per-example call tree (no cross-example nesting).

  2. Tracks concurrency globally across all examples.

  3. Identifies concurrency “spikes” (concurrency >= a threshold).

  4. Correlates concurrency spikes with token usage and call metadata.

  5. Computes average call latency by concurrency level, using midpoint concurrency as an approximation.

  6. Returns a Pydantic result containing the concurrency distribution, spike intervals, correlation stats, and related data, along with a textual report that includes the actual call count, the calls active during each spike, and more.

Changes from previous version:

  • Now shows the actual total number of calls in the dataset.

  • Displays the real number of active calls for each spike interval.

  • Computes and reports average latency by concurrency (no visualization).

Functions#

build_call_tree_for_example(...)

Sort events by time, push on *_START, pop on *_END, build stack-based calls for a single example.

build_call_tree_per_example(...)

Groups by example_number, builds separate call trees, returns combined list of top-level calls.

flatten_calls(...)

DFS to produce a flat list of all calls (including nested).

compute_concurrency_distribution(→ dict[int, float])

Flatten calls, produce (start, +1)/(end, -1) edge events, and accumulate the total time spent at each concurrency level.

build_concurrency_segments(→ list[tuple[float, float, int]])

Return piecewise segments of (start, end, concurrency) across all calls.

find_percentile_concurrency(→ float)

Given a map of concurrency level => total time, find the concurrency level at the given percentile of total time.

detect_concurrency_spikes(...)

If concurrency >= threshold, label that segment a 'spike'.

find_calls_active_in_interval(...)

Return all calls overlapping [start_t, end_t).

correlate_spike_calls(...)

For each spike, gather the calls that overlap it and compute the average prompt_tokens and total_tokens across them.

compute_midpoint_concurrency(→ float)

Approx concurrency at the midpoint of this call.

average_latency_by_midpoint_concurrency(→ dict[int, float])

For each call, find the concurrency at its midpoint, bucket call durations by that concurrency level, and compute the average duration per bucket.

concurrency_spike_analysis(...)

End-to-end analysis: build per-example call trees, compute the concurrency distribution and percentiles, detect spikes, correlate them with token usage, and return a Pydantic result plus a textual report.

Module Contents#

build_call_tree_for_example(
example_df: pandas.DataFrame,
) → list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode]#

Sort events by time, push on *_START, pop on *_END, build stack-based calls for a single example.
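
A minimal sketch of this stack-based pairing, using a simplified CallNode dataclass and plain event dicts in place of the toolkit's ConcurrencyCallNode and event-row models; the field and column names here are illustrative assumptions, not the real schema:

    from dataclasses import dataclass, field

    @dataclass
    class CallNode:
        """Simplified stand-in for ConcurrencyCallNode."""
        operation: str
        start_time: float
        end_time: float = 0.0
        children: list["CallNode"] = field(default_factory=list)

    def build_call_tree(events: list[dict]) -> list[CallNode]:
        """Pair *_START / *_END events of a single example into a call tree."""
        roots: list[CallNode] = []
        stack: list[CallNode] = []
        for ev in sorted(events, key=lambda e: e["timestamp"]):
            if ev["event_type"].endswith("_START"):
                node = CallNode(ev["event_type"][: -len("_START")], ev["timestamp"])
                # Nest under the innermost open call, if any; otherwise it is a root.
                (stack[-1].children if stack else roots).append(node)
                stack.append(node)
            elif ev["event_type"].endswith("_END") and stack:
                # Close the innermost open call.
                stack[-1].end_time = ev["timestamp"]
                stack.pop()
        return roots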

build_call_tree_per_example(
df: pandas.DataFrame,
) → list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode]#

Groups by example_number, builds separate call trees, returns combined list of top-level calls.

flatten_calls(
roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode],
) → list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode]#

DFS to produce a flat list of all calls (including nested).
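
Flattening is a plain depth-first traversal; a short sketch, assuming each node exposes a children list as in the CallNode stand-in above:

    def flatten_calls(roots) -> list:
        """Depth-first list of every call, including nested children."""
        flat, stack = [], list(roots)
        while stack:
            node = stack.pop()
            flat.append(node)
            stack.extend(node.children)
        return flat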

compute_concurrency_distribution(
roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode],
) → dict[int, float]#

Flatten calls, produce (start, +1)/(end, -1) edge events, and accumulate the total time spent at each concurrency level.
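
The distribution is essentially a sweep over +1/-1 edge events; a rough sketch over the flattened calls, assuming numeric start_time/end_time attributes:

    from collections import defaultdict

    def compute_concurrency_distribution(calls) -> dict[int, float]:
        """Total wall-clock time spent at each concurrency level."""
        # Ends sort before starts at equal timestamps, so back-to-back calls
        # are not counted as overlapping.
        edges = sorted(
            [(c.start_time, +1) for c in calls] + [(c.end_time, -1) for c in calls]
        )
        dist: dict[int, float] = defaultdict(float)
        level, prev_t = 0, None
        for t, delta in edges:
            if prev_t is not None and level > 0:
                dist[level] += t - prev_t  # time spent at the current level
            level += delta
            prev_t = t
        return dict(dist)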

build_concurrency_segments(
roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode],
) → list[tuple[float, float, int]]#

Return piecewise segments of (start, end, concurrency) across all calls.
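
The segments come from the same sweep, emitting one piecewise-constant interval between consecutive edge timestamps; a sketch:

    def build_concurrency_segments(calls) -> list[tuple[float, float, int]]:
        """Piecewise-constant (start, end, concurrency) segments across all calls."""
        edges = sorted(
            [(c.start_time, +1) for c in calls] + [(c.end_time, -1) for c in calls]
        )
        segments, level, prev_t = [], 0, None
        for t, delta in edges:
            if prev_t is not None and t > prev_t:
                segments.append((prev_t, t, level))
            level += delta
            prev_t = t
        return segments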

find_percentile_concurrency(
dist_map: dict[int, float],
percentile: float,
) → float#

Given a map of concurrency level => total time, find the concurrency level at the given percentile of total time.
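
One way to read a percentile off that weighted distribution is to walk the levels in ascending order until the cumulative time reaches the requested share; a sketch (the real implementation may interpolate differently):

    def find_percentile_concurrency(dist_map: dict[int, float], percentile: float) -> float:
        """Smallest concurrency level covering `percentile` percent of total time."""
        total = sum(dist_map.values())
        if total <= 0:
            return 0.0
        target = total * (percentile / 100.0)
        cumulative = 0.0
        for level in sorted(dist_map):
            cumulative += dist_map[level]
            if cumulative >= target:
                return float(level)
        return float(max(dist_map))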

detect_concurrency_spikes(
segments: list[tuple[float, float, int]],
threshold: int,
) → list[aiq.profiler.inference_optimization.data_models.ConcurrencySpikeInfo]#

If concurrency >= threshold, label that segment a ‘spike’.
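
At its core this is a filter over the segments; the real function wraps each match in a ConcurrencySpikeInfo model, but the predicate is simply:

    def detect_concurrency_spikes(segments, threshold: int):
        """Segments whose concurrency is at or above the threshold."""
        return [(start, end, conc) for start, end, conc in segments if conc >= threshold]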

find_calls_active_in_interval(
roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode],
start_t: float,
end_t: float,
) → list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode]#

Return all calls overlapping [start_t, end_t). Overlap => not (call.end_time <= start_t or call.start_time >= end_t).

correlate_spike_calls(
spikes: list[aiq.profiler.inference_optimization.data_models.ConcurrencySpikeInfo],
roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode],
) → aiq.profiler.inference_optimization.data_models.ConcurrencyCorrelationStats#

For each spike, gather the calls that overlap it and compute the average prompt_tokens and total_tokens across them.
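
A sketch of the correlation step, returning a plain dict rather than ConcurrencyCorrelationStats; spikes are treated as (start, end, concurrency) tuples, and each call is assumed to carry prompt_tokens and total_tokens attributes:

    def correlate_spike_calls(spikes, calls) -> dict:
        """Average token usage across calls that overlap any spike interval."""
        # Deduplicate by identity so a call spanning several spikes counts once.
        overlapping = list({
            id(c): c
            for start, end, _ in spikes
            for c in calls
            if not (c.end_time <= start or c.start_time >= end)
        }.values())
        if not overlapping:
            return {"num_calls": 0, "avg_prompt_tokens": 0.0, "avg_total_tokens": 0.0}
        return {
            "num_calls": len(overlapping),
            "avg_prompt_tokens": sum(c.prompt_tokens for c in overlapping) / len(overlapping),
            "avg_total_tokens": sum(c.total_tokens for c in overlapping) / len(overlapping),
        }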

compute_midpoint_concurrency(
n: aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode,
segments: list[tuple[float, float, int]],
) → float#

Approx concurrency at the midpoint of this call.
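
A sketch: locate the concurrency segment containing the call's temporal midpoint and return its level (0.0 if the midpoint falls outside every segment):

    def compute_midpoint_concurrency(call, segments) -> float:
        """Concurrency of the segment containing the call's midpoint."""
        mid = (call.start_time + call.end_time) / 2.0
        for start, end, conc in segments:
            if start <= mid < end:
                return float(conc)
        return 0.0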

average_latency_by_midpoint_concurrency(
roots: list[aiq.profiler.inference_optimization.data_models.ConcurrencyCallNode],
) → dict[int, float]#

For each call, find the concurrency at its midpoint, bucket call durations by that concurrency level, and compute the average duration per bucket.
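
The real function takes only the call-tree roots and derives the segments itself; this sketch makes both inputs explicit and buckets call durations by midpoint concurrency:

    from collections import defaultdict

    def average_latency_by_midpoint_concurrency(calls, segments) -> dict[int, float]:
        """Average call duration, bucketed by the concurrency at each call's midpoint."""
        buckets: dict[int, list[float]] = defaultdict(list)
        for c in calls:
            mid = (c.start_time + c.end_time) / 2.0
            # Concurrency of the segment containing the midpoint (0 if none).
            level = next((conc for s, e, conc in segments if s <= mid < e), 0)
            buckets[level].append(c.end_time - c.start_time)
        return {level: sum(durs) / len(durs) for level, durs in buckets.items()}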

concurrency_spike_analysis(
all_steps: list[list[aiq.data_models.intermediate_step.IntermediateStep]],
concurrency_spike_threshold: int | None = None,
) → aiq.profiler.inference_optimization.data_models.ConcurrencyAnalysisResult#
  1. Build per-example call trees (no cross-example nesting).

  2. Compute concurrency distribution & concurrency segments across all calls.

  3. Derive concurrency percentiles (p50, p90, p95, p99).

  4. If no threshold is provided, pick a default such as the ceiling of the p90 concurrency.

  5. Detect spikes, gather calls in those intervals => correlation stats.

  6. Also compute average latency by concurrency and add to report.

  7. Return a Pydantic object with everything, plus a textual report.
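
A hedged usage sketch: given the per-example intermediate steps collected by the profiler (the collection helper below is hypothetical), run the analysis and inspect the returned Pydantic model:

    from aiq.profiler.inference_optimization.experimental.concurrency_spike_analysis import (
        concurrency_spike_analysis,
    )

    # `all_steps` is a list of per-example IntermediateStep lists; how you gather
    # it depends on your profiling setup (placeholder helper, not a real API).
    all_steps = gather_intermediate_steps_per_example()

    # Explicit threshold of 8 concurrent calls; pass None to let the function
    # derive one from the p90 concurrency, as described in step 4 above.
    result = concurrency_spike_analysis(all_steps, concurrency_spike_threshold=8)

    # The exact attribute layout of ConcurrencyAnalysisResult is not reproduced
    # here, so simply print the model; it includes the textual report described above.
    print(result)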