Source code for nemo_rl.utils.timer

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

import numpy as np


[docs] class Timer: """A utility for timing code execution. Supports two usage patterns: 1. Explicit start/stop: timer.start("label"), timer.stop("label") 2. Context manager: with timer.time("label"): ... The timer keeps track of multiple timing measurements for each label, and supports different reductions on these measurements (mean, median, min, max, std dev). Example usage: ``` timer = Timer() # Method 1: start/stop timer.start("load_data") data = load_data() timer.stop("load_data") # Method 2: context manager with timer.time("model_forward"): model_outputs = model(inputs) # Multiple timing measurements for the same operation for batch in dataloader: with timer.time("model_forward_multiple"): outputs = model(batch) # Get all times for one label model_forward_times = timer.get_elapsed("model_forward_multiple") # Get reductions for one label mean_forward_time = timer.reduce("model_forward_multiple") max_forward_time = timer.reduce("model_forward_multiple", "max") ``` """ # Define valid reduction types and their corresponding NumPy functions _REDUCTION_FUNCTIONS = { "mean": np.mean, "median": np.median, "min": np.min, "max": np.max, "std": np.std, "sum": np.sum, "count": len, } def __init__(self): # Dictionary mapping labels to lists of elapsed times # We store a list of times for each label rather than a single value # to support multiple timing runs with the same label (e.g., in loops) # This allows calculating reductions like mean, min, max, and std dev self._timers: Dict[str, List[float]] = {} self._start_times: Dict[str, float] = {}
[docs] def start(self, label: str) -> None: """Start timing for the given label.""" if label in self._start_times: raise ValueError(f"Timer '{label}' is already running") self._start_times[label] = time.perf_counter()
[docs] def stop(self, label: str) -> float: """Stop timing for the given label and return the elapsed time. Args: label: The label to stop timing for Returns: The elapsed time in seconds Raises: ValueError: If the timer for the given label is not running """ if label not in self._start_times: raise ValueError( f"Timer '{label}' is not running. Running times: {self._start_times.keys()}" ) elapsed = time.perf_counter() - self._start_times[label] if label not in self._timers: self._timers[label] = [] self._timers[label].append(elapsed) del self._start_times[label] return elapsed
[docs] @contextmanager def time(self, label: str): """Context manager for timing a block of code. Args: label: The label to use for this timing Yields: None """ self.start(label) try: yield finally: self.stop(label)
[docs] def get_elapsed(self, label: str) -> List[float]: """Get all elapsed time measurements for a specific label. Args: label: The timing label to get elapsed times for Returns: A list of all elapsed time measurements in seconds Raises: KeyError: If the label doesn't exist """ if label not in self._timers: raise KeyError(f"No timings recorded for '{label}'") return self._timers[label]
[docs] def get_latest_elapsed(self, label: str) -> float: """Get the most recent elapsed time measurement for a specific label. Args: label: The timing label to get the latest elapsed time for Returns: The most recent elapsed time measurement in seconds Raises: KeyError: If the label doesn't exist IndexError: If the label exists but has no measurements """ if label not in self._timers: raise KeyError(f"No timings recorded for '{label}'") if not self._timers[label]: raise IndexError(f"No measurements recorded for '{label}'") return self._timers[label][-1]
[docs] def reduce(self, label: str, operation: str = "mean") -> float: """Apply a reduction function to timing measurements for the specified label. Args: label: The timing label to get reduction for operation: The type of reduction to apply. Valid options are: - "mean": Average time (default) - "median": Median time - "min": Minimum time - "max": Maximum time - "std": Standard deviation - "sum": Total time - "count": Number of measurements Returns: A single float with the reduction result Raises: KeyError: If the label doesn't exist ValueError: If an invalid operation is provided """ if operation not in self._REDUCTION_FUNCTIONS: valid_reductions = ", ".join(self._REDUCTION_FUNCTIONS.keys()) raise ValueError( f"Invalid operation '{operation}'. Valid options are: {valid_reductions}" ) if label not in self._timers: raise KeyError(f"No timings recorded for '{label}'") reduction_func = self._REDUCTION_FUNCTIONS[operation] return reduction_func(self._timers[label])
[docs] def get_timing_metrics( self, reduction_op: Union[str, Dict[str, str]] = "mean" ) -> Dict[str, List[float]]: """Get all timing measurements with optional reduction. Args: reduction_op: Either a string specifying a reduction operation to apply to all labels, or a dictionary mapping specific labels to reduction operations. Valid reduction operations are: "mean", "median", "min", "max", "std", "sum", "count". If a label is not in the dictionary, no reduction is applied and all measurements are returned. Returns: A dictionary mapping labels to either: - A list of all timing measurements for that label (if no reduction specified) - A single float with the reduction result (if reduction specified) Raises: ValueError: If an invalid reduction operation is provided """ if isinstance(reduction_op, str): reduction_op = {label: reduction_op for label in self._timers} results = {} for label, op in reduction_op.items(): if label not in self._timers: continue if op in self._REDUCTION_FUNCTIONS: results[label] = self.reduce(label, op) else: results[label] = self._timers[label] # Add any labels not in the reduction_op dictionary for label in self._timers: if label not in reduction_op: results[label] = self._timers[label] return results
[docs] def reset(self, label: Optional[str] = None) -> None: """Reset timings for the specified label or all labels. Args: label: Optional label to reset. If None, resets all timers. """ if label: if label in self._timers: del self._timers[label] if label in self._start_times: del self._start_times[label] else: self._timers = {} self._start_times = {}