# Source code for nemo_gym.profiling

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import StringIO
from pathlib import Path
from subprocess import run
from typing import Optional

import yappi
from gprof2dot import main as gprof2dot_main
from pydantic import BaseModel
from pydot import graph_from_dot_file


class Profiler(BaseModel):
    """Collect yappi CPU profiles and write them to disk.

    Use ``start()`` to begin profiling and ``stop()`` to end it. ``stop()`` calls ``dump()``,
    which writes ``<name>.callgrind``, ``<name>.dot``, ``<name>.png`` and ``<name>.log`` into
    ``base_profile_dir``.
    """

    # File stem shared by every artifact `dump` writes; must not contain spaces.
    name: str
    # Output directory for the artifacts; `dump` creates it (and any parents) if missing.
    base_profile_dir: Path
    # Used to clean up and filter out unnecessary information in the yappi log: when set,
    # only stat rows containing this substring are written to `<name>.log`; when None
    # (or empty), every row is written.
    required_str: Optional[str] = None

    def model_post_init(self, context) -> None:
        """Pydantic post-init hook: reject names containing spaces, then defer to the base hook.

        Raises:
            ValueError: if ``name`` contains a space.
        """
        # An explicit raise (not `assert`) keeps this check active under `python -O`.
        # NOTE(review): pydantic v2 may re-raise this as a ValidationError (a ValueError
        # subclass) — confirm against the pinned pydantic version.
        if " " in self.name:
            raise ValueError(f"Spaces are not allowed in profiler name, but got `{self.name!r}`")
        return super().model_post_init(context)

    def _check_for_dot_installation(self) -> None:  # pragma: no cover
        """Raise ``RuntimeError`` unless the Graphviz ``dot`` executable can be run.

        ``dump`` renders the call graph to PNG via pydot, which needs ``dot``.
        """
        try:
            # Argument list (no shell) plus captured output keeps a successful check quiet;
            # `-V` is Graphviz's version flag.
            res = run(["dot", "-V"], capture_output=True, check=False)
        except OSError:
            # Without a shell, a missing executable raises (FileNotFoundError) instead of
            # producing a non-zero return code.
            res = None
        if res is not None and res.returncode == 0:
            return
        raise RuntimeError(
            """You must install dot in order to use this profiling tool. Please install dot using:
- Mac: `brew install graphviz`
- Linux: `apt update && apt install -y graphviz`"""
        )

    def start(self) -> None:
        """Check that Graphviz is available, then start yappi using the CPU clock."""
        self._check_for_dot_installation()
        yappi.set_clock_type("CPU")
        yappi.start()
        print(f"🔍 Enabled profiling for {self.name}")

    def stop(self) -> None:
        """Stop yappi and write all profiling artifacts via ``dump``."""
        print(f"🛑 Stopping profiler for {self.name}. Check {self.base_profile_dir} for the metrics!")
        yappi.stop()
        self.dump()

    def dump(self) -> None:
        """Write the profiling artifacts for the stats yappi has collected.

        Creates ``base_profile_dir`` if needed. Then writes the callgrind export, the dot/PNG
        call graph and the ``<name>.log`` stats table filtered by ``required_str``.
        """
        self.base_profile_dir.mkdir(parents=True, exist_ok=True)
        self._write_call_graph()
        self._write_stats_log()

    def _write_call_graph(self) -> None:
        """Export yappi stats as callgrind, convert them to a dot file and render a PNG."""
        callgrind_path = self.base_profile_dir / f"{self.name}.callgrind"
        callgrind_dotfile_path = self.base_profile_dir / f"{self.name}.dot"
        callgrind_graph_path = self.base_profile_dir / f"{self.name}.png"

        yappi.get_func_stats().save(callgrind_path, type="CALLGRIND")
        # Build argv as a list: splitting one formatted string on whitespace would break any
        # path containing a space (only `name`, not `base_profile_dir`, is space-checked).
        gprof2dot_main(
            argv=[
                "--format=callgrind",
                f"--output={callgrind_dotfile_path}",
                "-e",
                "5",
                "-n",
                "5",
                str(callgrind_path),
            ]
        )
        (graph,) = graph_from_dot_file(callgrind_dotfile_path)
        graph.write_png(callgrind_graph_path)

    def _write_stats_log(self) -> None:
        """Write yappi's per-function stats table to ``<name>.log``, filtered by ``required_str``."""
        log_path = self.base_profile_dir / f"{self.name}.log"
        buffer = StringIO()
        yappi.get_func_stats().print_all(
            out=buffer,
            columns={
                0: ("name", 200),
                1: ("ncall", 10),
                2: ("tsub", 8),
                3: ("ttot", 8),
                4: ("tavg", 8),
            },
        )
        buffer.seek(0)
        # UTF-8 explicitly, so the log does not depend on the platform's locale encoding.
        log_path.write_text(self._filter_stats_lines(buffer), encoding="utf-8")

    def _filter_stats_lines(self, lines) -> str:
        """Return a yappi stats table reduced to its header plus the rows selected by ``required_str``.

        ``lines`` is an iterable of text lines (line endings included), as produced by iterating
        the ``print_all`` output buffer. Every line is kept up to and including the first line
        starting with ``"name"`` (the column header). After that, a row is kept only when
        ``required_str`` is None or is a substring of the row.
        """
        kept = []
        past_header = False
        for line in lines:
            # With no filter configured, keep every row instead of dropping them all.
            if not past_header or self.required_str is None or self.required_str in line:
                kept.append(line)
            if line.startswith("name"):
                past_header = True
        return "".join(kept)