Source code for ran.datasets.generate_cdl_channels

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset generation using Sionna CDL models.

This module synthesizes frequency-domain channel responses with 3GPP CDL
(Clustered Delay Line) models via Sionna and saves them directly to safetensors format.

Outputs:
  - train_data.safetensors: Training channels
  - val_data.safetensors: Validation channels
  - test_data.safetensors: Test channels
  - Each stores H__sc_sym_rxant as separate real/imag float32 tensors
    (keys H__sc_sym_rxant.real and H__sc_sym_rxant.imag), representing a
    complex64 array of shape (N, num_sc, 14 OFDM symbols, 4 RX antennas)

The CDL model supports configurable antenna arrays for both BS (Base Station)
and UE (User Equipment) with customizable polarization and antenna patterns.
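
Example (a minimal sketch; it assumes SionnaCDLConfig can be constructed with
its default field values, which is not verified here):

    from ran.datasets.generate_cdl_channels import gen_cdl_channels
    from ran.datasets.sionna_cdl_config import SionnaCDLConfig

    n_train, n_val, n_test = gen_cdl_channels(
        SionnaCDLConfig(), validation_frac=0.15, out_dir="./sionna_dataset"
    )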
"""

import logging
from pathlib import Path
from typing import Any

import numpy as np
from safetensors.numpy import save_file
from tqdm import tqdm

from .sionna_cdl_config import SionnaCDLConfig

logger = logging.getLogger(__name__)

# Note: Sionna (and therefore TensorFlow) is imported lazily inside
# gen_cdl_channels() so that importing this module does not load TensorFlow.


def _generate_channels(
    gen: Any,  # phy.channel.generate_ofdm_channel.GenerateOFDMChannel
    total: int,
    config: SionnaCDLConfig,
) -> np.ndarray:
    """Generate channel realizations from Sionna.

    Args:
        gen: Sionna channel generator
        total: Total number of channels to generate
        config: Configuration object

    Returns:
        Array of shape (total, num_sc, 14, 4) with complex64 dtype
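
    Example (illustrative only; gen and config come from the caller,
    e.g. gen_cdl_channels below):

        H = _generate_channels(gen, total=256, config=config)
        # H.shape == (256, config.num_sc, 14, 4); H.dtype == complex64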
    """
    H: np.ndarray = np.zeros((total, config.num_sc, 14, 4), dtype=np.complex64)
    filled = 0

    with tqdm(total=total, desc="Generating", unit="channels") as pbar:
        while filled < total:
            batch_size = min(config.batch_tf, total - filled)

            # Generate batch: [B, num_rx, num_rx_ant, num_tx, num_tx_ant,
            #                  num_ofdm_symbols, num_sc]
            H_tf = gen(batch_size)
            H_np = H_tf.numpy()  # Shape: [B, 1, 4, 1, 1, 14, num_sc]

            for k in range(batch_size):
                # Extract: [1, 4, 1, 1, 14, num_sc] → [4, 14, num_sc]
                h_channel = H_np[k, 0, :, 0, 0, :, : config.num_sc]

                # Transpose to (num_sc, 14, 4)
                h_channel = h_channel.transpose(2, 1, 0)

                H[filled] = h_channel
                filled += 1
                pbar.update(1)

                if filled >= total:
                    break

    return H


def _save_to_safetensors(H: np.ndarray, path: Path) -> None:
    """Save channel data to safetensors format.

    Handles complex arrays by saving real and imaginary parts separately.

    Args:
        H: Channel array of shape (N, num_sc, 14, 4) with complex64 dtype
        path: Output file path
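
    Example (a minimal sketch of reading the file back into a complex array):

        import numpy as np
        from safetensors.numpy import load_file

        tensors = load_file("train_data.safetensors")
        H = (tensors["H__sc_sym_rxant.real"]
             + 1j * tensors["H__sc_sym_rxant.imag"]).astype(np.complex64)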
    """
    tensors: dict[str, np.ndarray] = {
        "H__sc_sym_rxant.real": H.real.astype(np.float32),
        "H__sc_sym_rxant.imag": H.imag.astype(np.float32),
    }
    save_file(tensors, str(path))

    # Log size info
    size_mb = path.stat().st_size / (1024 * 1024)
    logger.info("  %s: %d channels, %.1f MB", path.name, H.shape[0], size_mb)


def gen_cdl_channels(
    config: SionnaCDLConfig,
    validation_frac: float = 0.15,
    out_dir: str = "/opt/nvidia/aerial-framework/out/sionna_dataset",
) -> tuple[int, int, int]:
    """Generate CDL channel dataset using Sionna and save to safetensors.

    This function generates channel realizations, splits them into train/val/test,
    and saves directly to safetensors format, eliminating intermediate files.

    Args:
        config: Configuration object containing all channel generation parameters.
        validation_frac: Fraction of training data for validation.
        out_dir: Output directory for generated files.

    Returns:
        Tuple of (num_train, num_val, num_test) samples.
    """
    # Lazy import Sionna/TensorFlow only when actually generating
    from sionna import phy  # type: ignore[import-untyped]

    # Create output directory
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Setup Sionna components
    rg = phy.ofdm.ResourceGrid(
        num_ofdm_symbols=14,
        fft_size=config.num_sc,
        subcarrier_spacing=30e3,
        num_guard_carriers=(0, 0),
        dc_null=False,
        pilot_pattern="empty",
        num_tx=1,
        num_streams_per_tx=1,
    )

    # Create antenna arrays
    bs_array = phy.channel.tr38901.antenna.AntennaArray(
        num_rows=config.bs_num_rows,
        num_cols=config.bs_num_cols,
        polarization=config.bs_polarization,
        polarization_type=config.bs_polarization_type,
        antenna_pattern=config.bs_antenna_pattern,
        carrier_frequency=config.fc_ghz * 1e9,
    )
    ue_array = phy.channel.tr38901.antenna.AntennaArray(
        num_rows=config.ue_num_rows,
        num_cols=config.ue_num_cols,
        polarization=config.ue_polarization,
        polarization_type=config.ue_polarization_type,
        antenna_pattern=config.ue_antenna_pattern,
        carrier_frequency=config.fc_ghz * 1e9,
    )

    # Create CDL channel model
    channel_model = phy.channel.tr38901.cdl.CDL(
        model=config.tdl_model,
        delay_spread=config.delay_spread_ns * 1e-9,
        carrier_frequency=config.fc_ghz * 1e9,
        ut_array=ue_array,
        bs_array=bs_array,
        direction=config.direction,
        min_speed=config.speed_min,
        max_speed=config.speed_max,
    )

    gen = phy.channel.generate_ofdm_channel.GenerateOFDMChannel(
        channel_model=channel_model, resource_grid=rg
    )

    # Generate training data
    logger.info("Generating %d training channels...", config.train_total)
    H_train_all = _generate_channels(gen, config.train_total, config)

    # Generate test data
    logger.info("Generating %d test channels...", config.test_total)
    H_test = _generate_channels(gen, config.test_total, config)

    # Split training data into train/val
    if config.train_total < 2:
        raise ValueError(
            f"train_total must be at least 2 to allow train/val split, got {config.train_total}"
        )

    rng = np.random.default_rng(config.prng_seed)
    n_val = max(1, int(validation_frac * config.train_total))
    # Ensure at least 1 training sample remains
    n_val = min(n_val, config.train_total - 1)
    n_train = config.train_total - n_val

    indices = rng.permutation(config.train_total)
    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    H_train = H_train_all[train_indices]
    H_val = H_train_all[val_indices]

    logger.info("  Split: %d train, %d val, %d test", n_train, n_val, config.test_total)

    # Save to safetensors
    logger.info("Saving datasets to safetensors...")
    _save_to_safetensors(H_train, out_path / "train_data.safetensors")
    _save_to_safetensors(H_val, out_path / "val_data.safetensors")
    _save_to_safetensors(H_test, out_path / "test_data.safetensors")

    logger.info("Datasets saved to %s", out_dir)

    return n_train, n_val, config.test_total