Source code for nemo_rl.utils.logger

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob
import json
import logging
import os
import re
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, TypedDict

import ray
import requests
import torch
import wandb
from prometheus_client.parser import text_string_to_metric_families
from prometheus_client.samples import Sample
from rich.box import ROUNDED
from rich.console import Console
from rich.logging import RichHandler
from rich.panel import Panel
from torch.utils.tensorboard import SummaryWriter

from nemo_rl.data.interfaces import LLMMessageLogType
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

# Flag to track if rich logging has been configured
_rich_logging_configured = False


class WandbConfig(TypedDict):
    project: str
    name: str


class TensorboardConfig(TypedDict):
    log_dir: str


class GPUMonitoringConfig(TypedDict):
    collection_interval: int | float
    flush_interval: int | float


class LoggerConfig(TypedDict):
    log_dir: str
    wandb_enabled: bool
    tensorboard_enabled: bool
    wandb: WandbConfig
    tensorboard: TensorboardConfig
    monitor_gpus: bool
    gpu_monitoring: GPUMonitoringConfig


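# --- Illustrative example (not part of the original module) -----------------------------
# A minimal sketch of a LoggerConfig showing how the nested TypedDicts above compose.
# The variable name and every value below are hypothetical, chosen only for demonstration.
_EXAMPLE_LOGGER_CONFIG: LoggerConfig = {
    "log_dir": "logs/my_experiment",  # hypothetical path
    "wandb_enabled": False,
    "tensorboard_enabled": True,
    "wandb": {"project": "my-project", "name": "my-run"},  # hypothetical wandb settings
    "tensorboard": {"log_dir": "tb"},
    "monitor_gpus": False,
    "gpu_monitoring": {"collection_interval": 10.0, "flush_interval": 60.0},
}
# -----------------------------------------------------------------------------------------

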
class LoggerInterface(ABC):
    """Abstract base class for logger backends."""

    @abstractmethod
    def log_metrics(
        self,
        metrics: Dict[str, Any],
        step: int,
        prefix: Optional[str] = "",
        step_metric: Optional[str] = None,
    ) -> None:
        """Log a dictionary of metrics."""
        pass

    @abstractmethod
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """Log dictionary of hyperparameters."""
        pass


class TensorboardLogger(LoggerInterface):
    """Tensorboard logger backend."""

    def __init__(self, cfg: TensorboardConfig, log_dir: Optional[str] = None):
        self.writer = SummaryWriter(log_dir=log_dir)
        print(f"Initialized TensorboardLogger at {log_dir}")

    def log_metrics(
        self,
        metrics: Dict[str, Any],
        step: int,
        prefix: Optional[str] = "",
        step_metric: Optional[str] = None,  # ignored in TensorBoard
    ) -> None:
        """Log metrics to Tensorboard.

        Args:
            metrics: Dict of metrics to log
            step: Global step value
            prefix: Optional prefix for metric names
            step_metric: Optional step metric name (ignored in TensorBoard)
        """
        for name, value in metrics.items():
            if prefix:
                name = f"{prefix}/{name}"
            self.writer.add_scalar(name, value, step)

    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """Log hyperparameters to Tensorboard.

        Args:
            params: Dictionary of hyperparameters to log
        """
        # Flatten the params because add_hparams does not support nested dicts
        self.writer.add_hparams(flatten_dict(params), {})


class WandbLogger(LoggerInterface):
    """Weights & Biases logger backend."""

    def __init__(self, cfg: WandbConfig, log_dir: Optional[str] = None):
        self.run = wandb.init(**cfg, dir=log_dir)
        print(
            f"Initialized WandbLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}"
        )

    def define_metric(
        self,
        name: str,
        step_metric: Optional[str] = None,
    ) -> None:
        """Define a metric with a custom step metric.

        Args:
            name: Name of the metric or pattern (e.g. 'ray/*')
            step_metric: Optional name of the step metric to use
        """
        self.run.define_metric(name, step_metric=step_metric)

    def log_metrics(
        self,
        metrics: Dict[str, Any],
        step: int,
        prefix: Optional[str] = "",
        step_metric: Optional[str] = None,
    ) -> None:
        """Log metrics to wandb.

        Args:
            metrics: Dict of metrics to log
            step: Global step value
            prefix: Optional prefix for metric names
            step_metric: Optional name of a field in metrics to use as step instead of
                the provided step value
        """
        if prefix:
            metrics = {
                f"{prefix}/{k}" if k != step_metric else k: v
                for k, v in metrics.items()
            }

        # If step_metric is provided, use the corresponding value from metrics as step
        if step_metric and step_metric in metrics:
            # commit=False so the step does not get incremented
            self.run.log(metrics, commit=False)
        else:
            self.run.log(metrics, step=step)

    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """Log hyperparameters to wandb.

        Args:
            params: Dict of hyperparameters to log
        """
        self.run.config.update(params)


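# Illustrative note (added for clarity, not part of the original module): with
# prefix="ray" and step_metric="ray/ray_step", a call such as
#
#     wandb_logger.log_metrics(
#         {"node.0.gpu.0.gpu": 87.0, "ray/ray_step": 120},
#         step=5, prefix="ray", step_metric="ray/ray_step",
#     )
#
# renames the payload to {"ray/node.0.gpu.0.gpu": 87.0, "ray/ray_step": 120} (the
# step_metric key is left unprefixed) and logs it with commit=False, so the wandb step
# does not advance. When define_metric("ray/*", step_metric="ray/ray_step") has been
# called (as Logger.__init__ does), wandb plots these metrics against "ray/ray_step".

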
class GpuMetricSnapshot(TypedDict):
    step: int
    metrics: Dict[str, Any]


class RayGpuMonitorLogger:
    """Monitor GPU utilization across a Ray cluster and log metrics to a parent logger."""

    def __init__(
        self,
        collection_interval: int | float,
        flush_interval: int | float,
        metric_prefix: str,
        step_metric: str,
        parent_logger: Optional["Logger"] = None,
    ):
        """Initialize the GPU monitor.

        Args:
            collection_interval: Interval in seconds to collect GPU metrics
            flush_interval: Interval in seconds to flush metrics to parent logger
            metric_prefix: Prefix applied to metric names when logging
            step_metric: Name of the field to use as the step metric
            parent_logger: Logger to receive the collected metrics
        """
        self.collection_interval = collection_interval
        self.flush_interval = flush_interval
        self.metric_prefix = metric_prefix
        self.step_metric = step_metric
        self.parent_logger = parent_logger
        self.metrics_buffer: list[GpuMetricSnapshot] = []  # Store metrics with timestamps
        self.last_flush_time = time.time()
        self.is_running = False
        self.collection_thread = None
        self.lock = threading.Lock()

    def start(self):
        """Start the GPU monitoring thread."""
        if not ray.is_initialized():
            raise ValueError(
                "Ray must be initialized with nemo_rl.distributed.virtual_cluster.init_ray() before the GPU logging can begin."
            )

        if self.is_running:
            return

        self.start_time = time.time()
        self.is_running = True
        self.collection_thread = threading.Thread(
            target=self._collection_loop,
            daemon=True,  # Make this a daemon thread so it doesn't block program exit
        )
        self.collection_thread.start()
        print(
            f"GPU monitoring started with collection interval={self.collection_interval}s, flush interval={self.flush_interval}s"
        )

    def stop(self):
        """Stop the GPU monitoring thread."""
        self.is_running = False
        if self.collection_thread:
            self.collection_thread.join(timeout=self.collection_interval * 2)

        # Final flush
        self.flush()
        print("GPU monitoring stopped")

    def _collection_loop(self):
        """Main collection loop that runs in a separate thread."""
        while self.is_running:
            try:
                collection_time = time.time()
                relative_time = collection_time - self.start_time

                # Collect metrics with timing information
                metrics = self._collect_metrics()
                if metrics:
                    with self.lock:
                        self.metrics_buffer.append(
                            {
                                "step": int(relative_time),  # Store the relative time as step
                                "metrics": metrics,
                            }
                        )

                # Check if it's time to flush
                current_time = time.time()
                if current_time - self.last_flush_time >= self.flush_interval:
                    self.flush()
                    self.last_flush_time = current_time

                time.sleep(self.collection_interval)
            except Exception as e:
                print(
                    f"Error in GPU monitoring collection loop or stopped abruptly: {e}"
                )
                time.sleep(self.collection_interval)  # Continue despite errors

    def _parse_gpu_metric(self, sample: Sample, node_idx: int) -> Dict[str, Any]:
        """Parse a GPU metric sample into a standardized format.

        Args:
            sample: Prometheus metric sample
            node_idx: Index of the node

        Returns:
            Dictionary with metric name and value
        """
        # Expected labels for GPU metrics
        expected_labels = ["GpuIndex"]
        for label in expected_labels:
            if label not in sample.labels:
                # This is probably a CPU node
                return {}

        metric_name = sample.name
        # Rename known metrics to match wandb naming convention
        if metric_name == "ray_node_gpus_utilization":
            metric_name = "gpu"
        elif metric_name == "ray_node_gram_used":
            metric_name = "memory"
        else:
            # Skip unexpected metrics
            return {}

        labels = sample.labels
        index = labels["GpuIndex"]
        value = sample.value
        metric_name = f"node.{node_idx}.gpu.{index}.{metric_name}"
        return {metric_name: value}

    def _parse_gpu_sku(self, sample: Sample, node_idx: int) -> Dict[str, str]:
        """Parse a GPU metric sample into a GPU SKU (device name) entry.

        Args:
            sample: Prometheus metric sample
            node_idx: Index of the node

        Returns:
            Dictionary mapping the metric name to the GPU device name
        """
        # TODO: Consider plumbing {'GpuDeviceName': 'NVIDIA H100 80GB HBM3'}
        # Expected labels for GPU metrics
        expected_labels = ["GpuIndex", "GpuDeviceName"]
        for label in expected_labels:
            if label not in sample.labels:
                # This is probably a CPU node
                return {}

        metric_name = sample.name
        # Only return the SKU if the metric is one of those that publish it
        if (
            metric_name != "ray_node_gpus_utilization"
            and metric_name != "ray_node_gram_used"
        ):
            # Skip unexpected metrics
            return {}

        labels = sample.labels
        index = labels["GpuIndex"]
        value = labels["GpuDeviceName"]
        metric_name = f"node.{node_idx}.gpu.{index}.type"
        return {metric_name: value}

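    # Illustrative mapping (added for clarity, not part of the original module): given a
    # Prometheus sample exported by a Ray node such as
    #     Sample(name="ray_node_gpus_utilization", labels={"GpuIndex": "0", ...}, value=87.0)
    # _parse_gpu_metric(sample, node_idx=2) returns {"node.2.gpu.0.gpu": 87.0}, while
    # _parse_gpu_sku(...) on a sample that also carries GpuDeviceName="NVIDIA H100 80GB HBM3"
    # returns {"node.2.gpu.0.type": "NVIDIA H100 80GB HBM3"}. Samples without a GpuIndex
    # label (CPU-only nodes) yield an empty dict.
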
    def _collect_gpu_sku(self) -> Dict[str, str]:
        """Collect GPU SKUs from all Ray nodes.

        Note: This is an internal API and users are not expected to call this.

        Returns:
            Dictionary of SKU types on all Ray nodes
        """
        # TODO: We can re-use the same path as metrics collection: even though both the
        # utilization and memory metrics duplicate the GPU metadata, the metadata is the
        # same for each node, so we can overwrite it and expect the values to match.
        return self._collect(sku=True)

    def _collect_metrics(self) -> Dict[str, Any]:
        """Collect GPU metrics from all Ray nodes.

        Returns:
            Dictionary of collected metrics
        """
        return self._collect(metrics=True)

    def _collect(self, metrics: bool = False, sku: bool = False) -> Dict[str, Any]:
        """Collect GPU metrics or GPU SKUs from all Ray nodes.

        Returns:
            Dictionary of collected metrics
        """
        assert metrics ^ sku, (
            f"Must collect either metrics or sku, not both: {metrics=}, {sku=}"
        )
        parser_fn = self._parse_gpu_metric if metrics else self._parse_gpu_sku

        if not ray.is_initialized():
            print("Ray is not initialized. Cannot collect GPU metrics.")
            return {}

        try:
            nodes = ray.nodes()
            if not nodes:
                print("No Ray nodes found.")
                return {}

            # Use a dictionary to keep unique metric endpoints and maintain order
            unique_metric_addresses = {}
            for node in nodes:
                node_ip = node["NodeManagerAddress"]
                metrics_port = node.get("MetricsExportPort")
                if not metrics_port:
                    continue
                metrics_address = f"{node_ip}:{metrics_port}"
                unique_metric_addresses[metrics_address] = True

            # Process each node's metrics
            collected_metrics = {}
            for node_idx, metric_address in enumerate(unique_metric_addresses):
                gpu_metrics = self._fetch_and_parse_metrics(
                    node_idx, metric_address, parser_fn
                )
                collected_metrics.update(gpu_metrics)

            return collected_metrics
        except Exception as e:
            print(f"Error collecting GPU metrics: {e}")
            return {}

    def _fetch_and_parse_metrics(self, node_idx, metric_address, parser_fn):
        """Fetch metrics from a node and parse GPU metrics.

        Args:
            node_idx: Index of the node
            metric_address: Address of the metrics endpoint
            parser_fn: Callable used to parse each Prometheus sample

        Returns:
            Dictionary of GPU metrics
        """
        url = f"http://{metric_address}/metrics"
        try:
            response = requests.get(url, timeout=5.0)
            if response.status_code != 200:
                print(f"Error: Status code {response.status_code}")
                return {}

            metrics_text = response.text
            gpu_metrics = {}

            # Parse the Prometheus format
            for family in text_string_to_metric_families(metrics_text):
                # Skip non-GPU metrics
                if family.name not in (
                    "ray_node_gram_used",
                    "ray_node_gpus_utilization",
                ):
                    continue

                for sample in family.samples:
                    metrics = parser_fn(sample, node_idx)
                    gpu_metrics.update(metrics)

            return gpu_metrics
        except Exception as e:
            print(f"Error fetching metrics from {metric_address}: {e}")
            return {}

    def flush(self):
        """Flush collected metrics to the parent logger."""
        with self.lock:
            if not self.metrics_buffer:
                return

            if self.parent_logger:
                # Log each set of metrics with its original step
                for entry in self.metrics_buffer:
                    step = entry["step"]
                    metrics = entry["metrics"]
                    # Add the step metric directly to metrics for use as step_metric
                    metrics[self.step_metric] = step
                    # Pass step_metric so it is used as the step value in wandb
                    self.parent_logger.log_metrics(
                        metrics,
                        step=step,
                        prefix=self.metric_prefix,
                        step_metric=self.step_metric,
                    )

            # Clear buffer after logging
            self.metrics_buffer = []


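# Minimal usage sketch (an assumption, not part of the original module): driving the GPU
# monitor directly instead of through Logger. Ray must already be initialized, and all
# argument values below are hypothetical.
#
#     monitor = RayGpuMonitorLogger(
#         collection_interval=10.0,
#         flush_interval=60.0,
#         metric_prefix="ray",
#         step_metric="ray/ray_step",
#         parent_logger=None,  # with no parent logger, flush() only clears the buffer
#     )
#     monitor.start()
#     ...                      # run the workload
#     monitor.stop()           # joins the thread and performs a final flush

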
class Logger(LoggerInterface):
    """Main logger class that delegates to multiple backend loggers."""

    def __init__(self, cfg: LoggerConfig):
        """Initialize the logger.

        Args:
            cfg: Config dict with the following keys:

                - log_dir
                - wandb_enabled
                - tensorboard_enabled
                - wandb
                - tensorboard
                - monitor_gpus
                - gpu_monitoring (collection_interval, flush_interval)
        """
        self.loggers = []
        self.wandb_logger = None

        self.base_log_dir = cfg["log_dir"]
        os.makedirs(self.base_log_dir, exist_ok=True)

        if cfg["wandb_enabled"]:
            wandb_log_dir = os.path.join(self.base_log_dir, "wandb")
            os.makedirs(wandb_log_dir, exist_ok=True)
            self.wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir)
            self.loggers.append(self.wandb_logger)

        if cfg["tensorboard_enabled"]:
            tensorboard_log_dir = os.path.join(self.base_log_dir, "tensorboard")
            os.makedirs(tensorboard_log_dir, exist_ok=True)
            tensorboard_logger = TensorboardLogger(
                cfg["tensorboard"], log_dir=tensorboard_log_dir
            )
            self.loggers.append(tensorboard_logger)

        # Initialize GPU monitoring if requested
        self.gpu_monitor = None
        if cfg["monitor_gpus"]:
            metric_prefix = "ray"
            step_metric = f"{metric_prefix}/ray_step"
            if cfg["wandb_enabled"] and self.wandb_logger:
                self.wandb_logger.define_metric(
                    f"{metric_prefix}/*", step_metric=step_metric
                )

            self.gpu_monitor = RayGpuMonitorLogger(
                collection_interval=cfg["gpu_monitoring"]["collection_interval"],
                flush_interval=cfg["gpu_monitoring"]["flush_interval"],
                metric_prefix=metric_prefix,
                step_metric=step_metric,
                parent_logger=self,
            )
            self.gpu_monitor.start()

        if not self.loggers:
            print("No loggers initialized")

    def log_metrics(
        self,
        metrics: Dict[str, Any],
        step: int,
        prefix: Optional[str] = "",
        step_metric: Optional[str] = None,
    ) -> None:
        """Log metrics to all enabled backends.

        Args:
            metrics: Dict of metrics to log
            step: Global step value
            prefix: Optional prefix for metric names
            step_metric: Optional name of a field in metrics to use as step instead of
                the provided step value (currently only needed for wandb)
        """
        for logger in self.loggers:
            logger.log_metrics(metrics, step, prefix, step_metric)

    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """Log hyperparameters to all enabled backends.

        Args:
            params: Dict of hyperparameters to log
        """
        for logger in self.loggers:
            logger.log_hyperparams(params)

    def log_batched_dict_as_jsonl(
        self, to_log: BatchedDataDict | Dict[str, Any], filename: str
    ) -> None:
        """Log a batched dict of data to a JSONL file, one JSON object per sample.

        Args:
            to_log: BatchedDataDict (or plain dict convertible to one) to log
            filename: Filename to log to (within the log directory)
        """
        if not isinstance(to_log, BatchedDataDict):
            to_log = BatchedDataDict(to_log)

        # Create full path within log directory
        filepath = os.path.join(self.base_log_dir, filename)
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # Write to JSONL file
        with open(filepath, "w") as f:
            for i, sample in enumerate(to_log.make_microbatch_iterator(1)):
                for key, value in sample.items():
                    if isinstance(value, torch.Tensor):
                        sample[key] = value.tolist()
                f.write(json.dumps({**sample, "idx": i}) + "\n")

        print(f"Logged data to {filepath}")

    def __del__(self):
        """Clean up resources when the logger is destroyed."""
        if self.gpu_monitor:
            self.gpu_monitor.stop()


def flatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Flatten a nested dictionary.

    Handles nested dictionaries and lists by creating keys with separators.
    For lists, the index is used as part of the key.

    Args:
        d: Dictionary to flatten
        sep: Separator to use between nested keys

    Returns:
        Flattened dictionary with compound keys

    Examples:
    ```{doctest}
    >>> from nemo_rl.utils.logger import flatten_dict
    >>> flatten_dict({"a": 1, "b": {"c": 2}})
    {'a': 1, 'b.c': 2}
    >>> flatten_dict({"a": [1, 2], "b": {"c": [3, 4]}})
    {'a.0': 1, 'a.1': 2, 'b.c.0': 3, 'b.c.1': 4}
    >>> flatten_dict({"a": [{"b": 1}, {"c": 2}]})
    {'a.0.b': 1, 'a.1.c': 2}
    ```
    """
    result = {}

    def _flatten(d, parent_key=""):
        for key, value in d.items():
            new_key = f"{parent_key}{sep}{key}" if parent_key else key
            if isinstance(value, dict):
                _flatten(value, new_key)
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    list_key = f"{new_key}{sep}{i}"
                    if isinstance(item, dict):
                        _flatten(item, list_key)
                    else:
                        result[list_key] = item
            else:
                result[new_key] = value

    _flatten(d)
    return result


""" Rich Console Logging Functionality --------------------------------- Functions for setting up rich console logging and visualizing model outputs. """
def configure_rich_logging(
    level: str = "INFO", show_time: bool = True, show_path: bool = True
) -> None:
    """Configure rich logging for more visually appealing log output.

    Args:
        level: The logging level to use
        show_time: Whether to show timestamps in logs
        show_path: Whether to show file paths in logs
    """
    global _rich_logging_configured

    # Only configure if not already done
    if not _rich_logging_configured:
        # Configure logging with rich handler
        logging.basicConfig(
            level=level.upper(),
            format="%(message)s",
            datefmt="[%X]",
            handlers=[
                RichHandler(
                    rich_tracebacks=True,
                    show_time=show_time,
                    show_path=show_path,
                    markup=True,
                )
            ],
        )
        _rich_logging_configured = True


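# Example usage (illustrative, not part of the original module):
#
#     configure_rich_logging(level="DEBUG", show_time=True, show_path=False)
#     logging.getLogger(__name__).info("[bold green]starting run[/bold green]")
#
# Because the handler is created with markup=True, rich markup tags inside log messages
# are rendered. Subsequent calls are no-ops once _rich_logging_configured is set.

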
def get_next_experiment_dir(base_log_dir):
    """Create a new experiment directory with an incremented ID.

    Args:
        base_log_dir (str): The base log directory path

    Returns:
        str: Path to the new experiment directory with incremented ID
    """
    # Check if the log directory already contains an experiment ID pattern (e.g., /exp_001/)
    pattern = re.compile(r"exp_(\d+)")
    next_exp_id = 1

    # Check for existing experiment directories
    existing_dirs = glob.glob(os.path.join(base_log_dir, "exp_*"))
    if existing_dirs:
        # Extract experiment IDs and find the maximum
        exp_ids = []
        for dir_path in existing_dirs:
            match = pattern.search(dir_path)
            if match:
                exp_ids.append(int(match.group(1)))

        if exp_ids:
            # Increment the highest experiment ID
            next_exp_id = max(exp_ids) + 1

    # Format the new log directory with the incremented experiment ID
    new_log_dir = os.path.join(base_log_dir, f"exp_{next_exp_id:03d}")

    # Create the new log directory
    os.makedirs(new_log_dir, exist_ok=True)

    return new_log_dir
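

# Illustrative behavior (added for clarity, not part of the original module):
#
#     get_next_experiment_dir("logs")   # "logs" already contains exp_001 and exp_002
#     # -> creates and returns "logs/exp_003"
#
#     get_next_experiment_dir("fresh")  # "fresh" is empty or does not exist yet
#     # -> creates and returns "fresh/exp_001"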