Using pyAerial to evaluate a PUSCH neural receiver#

This example shows how to use the pyAerial cuPHY Python bindings to evaluate a trained neural network-based PUSCH receiver. Here, the neural network replaces channel estimation, noise and interference estimation, and channel equalization, and thus outputs log-likelihood ratios directly. The model is a variant of the one proposed in

  1. S. Cammerer, F. Aït Aoudia, J. Hoydis, A. Oeldemann, A. Roessler, T. Mayer and A. Keller, “A Neural Receiver for 5G NR Multi-user MIMO”, IEEE Globecom Workshops (GC Wkshps), Dec. 2023.

The rest of the PUSCH receiver pipeline following the neural receiver, i.e. the LDPC decoding chain, is modeled using pyAerial. The neural receiver also takes LS channel estimates as inputs in addition to the received PUSCH slot; these are likewise obtained using pyAerial. The neural receiver-based PUSCH receiver is compared against the conventional PUSCH receiver, which is built using pyAerial’s (fully fused) PUSCH pipeline.

The PUSCH transmitter is emulated by a PDSCH transmission with properly chosen parameters, which makes it a 5G NR-compliant PUSCH transmission.

Imports#

[1]:
%matplotlib widget
from collections import defaultdict
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

import cupy as cp
import numpy as np

from aerial.phy5g.pdsch import PdschTx
from aerial.phy5g.pusch import PuschRx
from aerial.phy5g.algorithms import ChannelEstimator
from aerial.phy5g.algorithms import TrtEngine
from aerial.phy5g.algorithms import TrtTensorPrms
from aerial.phy5g.ldpc import get_mcs
from aerial.phy5g.ldpc import random_tb
from aerial.phy5g.ldpc import get_tb_size
from aerial.phy5g.ldpc import LdpcDeRateMatch
from aerial.phy5g.ldpc import LdpcDecoder
from aerial.phy5g.ldpc import CrcChecker
from aerial.pycuphy.types import PuschLdpcKernelLaunch
from aerial.phy5g.config import PuschConfig
from aerial.phy5g.config import PuschUeConfig
from aerial.phy5g.channel_models import FadingChannel, TdlChannelConfig, CdlChannelConfig
from aerial.util.cuda import CudaStream
from simulation_monitor import SimulationMonitor

Parameters#

Set the simulation parameters, numerology, PUSCH parameters, and channel parameters here.

[2]:
# Simulation parameters.
esno_db_range = np.arange(-8.0, 6.0, 2.0)
num_slots = 10000
min_num_tb_errors = 250

# Numerology and frame structure. See TS 38.211.
num_ofdm_symbols = 14
fft_size = 4096
num_guard_subcarriers = (410, 410)
num_slots_per_frame = 20

# System/gNB configuration
num_tx_ant = 1             # UE antennas
num_rx_ant = 4             # gNB antennas
cell_id = 41               # Physical cell ID
enable_pusch_tdi = 0       # Enable time interpolation for equalizer coefficients
eq_coeff_algo = 1          # Equalizer algorithm

# PUSCH parameters
rnti = 1234                # UE RNTI
scid = 0                   # DMRS scrambling ID
data_scid = 0              # Data scrambling ID
layers = 1                 # Number of layers
mcs_index = 7              # MCS index as per TS 38.214 table.
mcs_table = 0              # MCS table index
dmrs_ports = 1             # Used DMRS port.
start_prb = 0              # Start PRB index.
num_prbs = 273             # Number of allocated PRBs.
start_sym = 0              # Start symbol index.
num_symbols = 12           # Number of symbols.
dmrs_scrm_id = 41          # DMRS scrambling ID
dmrs_syms = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]  # Indicates which symbols are used for DMRS.
dmrs_max_len = 1
dmrs_add_ln_pos = 2
num_dmrs_cdm_grps_no_data = 2
mod_order, code_rate = get_mcs(mcs_index, mcs_table+1)  # Different indexing for MCS table.
tb_size = get_tb_size(  # TB size in bits
    mod_order=mod_order,
    code_rate=code_rate,
    dmrs_syms=dmrs_syms,
    num_prbs=num_prbs,
    start_sym=start_sym,
    num_symbols=num_symbols,
    num_layers=layers)

# Channel parameters
carrier_frequency = 3.5e9  # Carrier frequency in Hz.
channel_type = "tdl"       # Channel type: "tdl" or "cdl"
delay_profile = "A"        # Delay profile: 'A', 'B', 'C', 'D', 'E' as per 3GPP TR 38.901
speed = 0.8333             # UE speed [m/s]. Used to calculate Doppler shift.

Create the model file for the TRT engine#

The TRT engine is built from a TensorRT plan file. Plan files are not portable across platforms, so the plan file is created here from a supplied ONNX file.

[3]:
MODEL_DIR = "../models"
nrx_onnx_file = f"{MODEL_DIR}/neural_rx.onnx"
nrx_trt_file = f"{MODEL_DIR}/neural_rx.trt"
command = "trtexec " + \
    f"--onnx={nrx_onnx_file} " + \
    f"--saveEngine={nrx_trt_file} " + \
    "--skipInference " + \
    "--inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,int32:chw,int32:chw " + \
    "--outputIOFormats=fp32:chw,fp32:chw " + \
    "--shapes=rx_slot_real:1x3276x12x4,rx_slot_imag:1x3276x12x4,h_hat_real:1x4914x1x4,h_hat_imag:1x4914x1x4 " + \
    "> /dev/null"
return_val = os.system(command)
if return_val == 0:
    print("TRT engine model created.")
else:
    raise SystemExit("Failed to create the TRT engine file!")
TRT engine model created.
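The tensor shapes passed via `--shapes` follow directly from the allocation configured above: 3276 = 273 PRBs × 12 subcarriers per PRB for the received slot, and 4914 = 273 PRBs × 6 DMRS subcarriers per PRB × 3 DMRS symbols for the LS channel estimates. A quick sanity check:

```python
num_prbs = 273
num_sc_per_prb = 12       # Subcarriers per PRB.
num_dmrs_sc_per_prb = 6   # DMRS subcarrier positions within a PRB (see dmrs_subcarrier_pos below).
num_dmrs_syms = 3         # dmrs_syms above contains three ones.

rx_slot_sc = num_prbs * num_sc_per_prb              # First dimension of rx_slot_real/imag.
h_hat_sc = num_prbs * num_dmrs_sc_per_prb * num_dmrs_syms  # First dimension of h_hat_real/imag.
print(rx_slot_sc, h_hat_sc)  # 3276 4914
```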

Create the PUSCH pipelines#

As mentioned, the PUSCH transmission is emulated here by the PDSCH transmission chain. Note that the static cell parameters and static PUSCH parameters are given when creating the PUSCH transmission/reception objects, while dynamically (per-slot) changing parameters are set when actually running the transmission/reception; see further below.

[4]:
cuda_stream = CudaStream()

pusch_tx = PdschTx(
    cell_id=cell_id,
    num_rx_ant=num_tx_ant,
    num_tx_ant=num_tx_ant,
    cuda_stream=cuda_stream
)

# This is the fully fused PUSCH receiver chain.
pusch_rx = PuschRx(
    cell_id=cell_id,
    num_rx_ant=num_rx_ant,
    num_tx_ant=num_rx_ant,
    enable_pusch_tdi=enable_pusch_tdi,
    eq_coeff_algo=eq_coeff_algo,
    # To match the configuration of the separate PUSCH Rx components:
    ldpc_kernel_launch=PuschLdpcKernelLaunch.PUSCH_RX_LDPC_STREAM_SEQUENTIAL,
    cuda_stream=cuda_stream
)

# PUSCH configuration object. Note that default values are used for some parameters
# not given here.
pusch_ue_config = PuschUeConfig(
    scid=scid,
    layers=layers,
    dmrs_ports=dmrs_ports,
    rnti=rnti,
    data_scid=data_scid,
    mcs_table=mcs_table,
    mcs_index=mcs_index,
    code_rate=int(code_rate * 10),
    mod_order=mod_order,
    tb_size=tb_size // 8
)
# Note that this is a list. One UE group only in this case.
pusch_configs = [PuschConfig(
    ue_configs=[pusch_ue_config],
    num_dmrs_cdm_grps_no_data=num_dmrs_cdm_grps_no_data,
    dmrs_scrm_id=dmrs_scrm_id,
    start_prb=start_prb,
    num_prbs=num_prbs,
    dmrs_syms=dmrs_syms,
    dmrs_max_len=dmrs_max_len,
    dmrs_add_ln_pos=dmrs_add_ln_pos,
    start_sym=start_sym,
    num_symbols=num_symbols
)]


class NeuralRx:
    """PUSCH neural receiver class.

    This class encapsulates the PUSCH neural receiver chain built using
    pyAerial components.
    """

    def __init__(self,
                 num_rx_ant,
                 enable_pusch_tdi,
                 eq_coeff_algo,
                 cuda_stream):
        """Initialize the neural receiver."""
        self.cuda_stream = cuda_stream

        # Build the components of the receiver. The channel estimator outputs just the LS
        # channel estimates.
        self.channel_estimator = ChannelEstimator(
            num_rx_ant=num_rx_ant,
            ch_est_algo=3,  # This is LS channel estimation.
            cuda_stream=self.cuda_stream
        )

        # Create the pyAerial TRT engine object. This wraps TensorRT and links it together
        # with the rest of cuPHY. Here pyAerial's Python bindings to the engine are used
        # to run inference with the neural receiver model.
        # The inputs of the neural receiver are:
        # - LS channel estimates
        # - The Rx slot
        # - Active DMRS ports (active layers out of the layers that the neural receiver supports)
        # - DMRS OFDM symbol locations (indices)
        # - DMRS subcarrier positions within a PRB (indices)
        # Note that the shapes are given without batch size.
        self.trt_engine = TrtEngine(
            trt_model_file="../models/neural_rx.trt",
            max_batch_size=1,
            input_tensors=[TrtTensorPrms('rx_slot_real', (3276, 12, 4), np.float32),
                           TrtTensorPrms('rx_slot_imag', (3276, 12, 4), np.float32),
                           TrtTensorPrms('h_hat_real', (4914, 1, 4), np.float32),
                           TrtTensorPrms('h_hat_imag', (4914, 1, 4), np.float32),
                           TrtTensorPrms('active_dmrs_ports', (1,), np.float32),
                           TrtTensorPrms('dmrs_ofdm_pos', (3,), np.int32),
                           TrtTensorPrms('dmrs_subcarrier_pos', (6,), np.int32)],
            output_tensors=[TrtTensorPrms('output_1', (8, 1, 3276, 12), np.float32),
                            TrtTensorPrms('output_2', (1, 3276, 12, 8), np.float32)],
            cuda_stream=self.cuda_stream
        )

        # LDPC (de)rate matching and decoding implemented using pyAerial.
        self.derate_match = LdpcDeRateMatch(
            enable_scrambling=True,
            cuda_stream=self.cuda_stream
        )
        self.decoder = LdpcDecoder(cuda_stream=self.cuda_stream)
        self.crc_checker = CrcChecker(cuda_stream=self.cuda_stream)

        # Pre-allocate constant tensors for TRT inference to avoid runtime allocation.
        self.dmrs_subcarrier_pos = cp.array([[0, 2, 4, 6, 8, 10]], dtype=cp.int32)
        self.active_dmrs_ports = cp.ones((1, 1), dtype=cp.float32)

    def run(
        self,
        rx_slot,
        slot,
        pusch_configs=pusch_configs
    ):
        """Run the receiver."""
        # Channel estimation.
        ch_est = self.channel_estimator.estimate(
            rx_slot=rx_slot,
            slot=slot,
            pusch_configs=pusch_configs
        )

        # This is the neural receiver part.
        # It outputs the LLRs for all symbols.
        with self.cuda_stream:
            dmrs_ofdm_pos = cp.where(cp.array(pusch_configs[0].dmrs_syms))[0].astype(cp.int32)
            dmrs_ofdm_pos = dmrs_ofdm_pos[None, ...]

            rx_slot_in = rx_slot[None, :, pusch_configs[0].start_sym:pusch_configs[0].start_sym+pusch_configs[0].num_symbols, :]
            ch_est_in = cp.transpose(ch_est[0], (0, 3, 1, 2)).reshape(ch_est[0].shape[0] * ch_est[0].shape[3], ch_est[0].shape[1], ch_est[0].shape[2])
            ch_est_in = ch_est_in[None, ...]

            input_tensors = {
                "rx_slot_real": cp.real(rx_slot_in),
                "rx_slot_imag": cp.imag(rx_slot_in),
                "h_hat_real": cp.real(ch_est_in),
                "h_hat_imag": cp.imag(ch_est_in),
                "active_dmrs_ports": self.active_dmrs_ports,
                "dmrs_ofdm_pos": dmrs_ofdm_pos,
                "dmrs_subcarrier_pos": self.dmrs_subcarrier_pos
            }

        outputs = self.trt_engine.run(input_tensors)

        with self.cuda_stream:
            # The neural receiver also outputs values for DMRS symbols; remove
            # those from the output.
            data_syms = cp.array(pusch_configs[0].dmrs_syms[pusch_configs[0].start_sym:pusch_configs[0].start_sym + pusch_configs[0].num_symbols]) == 0
            llrs = cp.take(outputs["output_1"][0, ...], cp.where(data_syms)[0], axis=3)

        coded_blocks = self.derate_match.derate_match(
            input_llrs=[llrs],
            pusch_configs=pusch_configs
        )

        code_blocks = self.decoder.decode(
            input_llrs=coded_blocks,
            pusch_configs=pusch_configs
        )

        decoded_tbs, _ = self.crc_checker.check_crc(
            input_bits=code_blocks,
            pusch_configs=pusch_configs
        )

        decoded_tbs = [tb.get(order='F') for tb in decoded_tbs]
        return decoded_tbs

neural_rx = NeuralRx(
    num_rx_ant=num_rx_ant,
    enable_pusch_tdi=enable_pusch_tdi,
    eq_coeff_algo=eq_coeff_algo,
    cuda_stream=cuda_stream
)
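The DMRS-symbol masking done inside `NeuralRx.run` can be illustrated in isolation. With the allocation above (DMRS on symbols 0, 5 and 10; data allocated on symbols 0–11), the mask keeps 9 of the 12 allocated symbols. A minimal NumPy sketch of the same logic:

```python
import numpy as np

dmrs_syms = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]
start_sym, num_symbols = 0, 12

# Restrict the DMRS indicator to the allocated symbols and invert it:
# True marks a data symbol, False marks a DMRS symbol.
data_syms = np.array(dmrs_syms[start_sym:start_sym + num_symbols]) == 0
data_sym_indices = np.where(data_syms)[0]
print(data_sym_indices)  # [ 1  2  3  4  6  7  8  9 11]
```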

Channel model setup#

The pyAerial FadingChannel class provides GPU-accelerated TDL (Tapped Delay Line) and CDL (Clustered Delay Line) channel models based on 3GPP TR 38.901. The channel operates in the frequency domain and includes built-in OFDM modulation/demodulation and AWGN noise addition.

The channel accepts both CuPy and NumPy arrays, enabling seamless integration with the rest of the pyAerial pipeline.

[5]:
# Calculate Doppler shift from speed and carrier frequency.
max_doppler_shift = speed * carrier_frequency / 3e8

# Create channel configuration based on channel type.
# Note: n_bs_ant = gNB antennas, n_ue_ant = UE antennas
# With enable_swap_tx_rx=True, TX=UE, RX=gNB (uplink)
if channel_type == "tdl":
    channel_config = TdlChannelConfig(
        delay_profile=delay_profile,
        delay_spread=0.0,  # Neural Rx trained for frequency-flat
        max_doppler_shift=max_doppler_shift,
        n_bs_ant=num_rx_ant,  # gNB antennas (receiver in uplink)
        n_ue_ant=num_tx_ant   # UE antennas (transmitter in uplink)
    )
elif channel_type == "cdl":
    channel_config = CdlChannelConfig(
        delay_profile=delay_profile,
        delay_spread=0.0,  # Neural Rx trained for frequency-flat
        max_doppler_shift=max_doppler_shift,
        bs_ant_size=(1, num_rx_ant // 2, 2),  # gNB antenna array
        ue_ant_size=(1, num_tx_ant, 1)        # UE antenna array
    )
else:
    raise ValueError(f"Invalid channel type: {channel_type}. Use 'tdl' or 'cdl'.")

# Create FadingChannel with OFDM parameters.
n_sc = fft_size - sum(num_guard_subcarriers)
channel = FadingChannel(
    channel_config=channel_config,
    n_sc=n_sc,
    numerology=1,  # 30 kHz subcarrier spacing
    n_fft=fft_size,
    n_symbol_slot=num_ofdm_symbols,
    cuda_stream=cuda_stream
)


def apply_channel(tx_signal, snr_db, slot_idx):
    """Apply fading channel with AWGN noise.

    Args:
        tx_signal: Transmitted signal, shape (n_sc, n_symbol, n_tx_ant).
        snr_db: Signal-to-noise ratio in dB.
        slot_idx: Slot index for time-varying channel.

    Returns:
        Received signal after channel and noise, shape (n_sc, n_symbol, n_rx_ant).
    """
    # Reshape for FadingChannel: (n_cell, n_ue, n_tx_ant, n_symbol, n_sc)
    tx_reshaped = tx_signal.transpose((2, 1, 0))[None, None, ...]

    # Run channel (uplink: swap tx/rx to apply channel in uplink direction)
    rx_signal = channel(
        freq_in=tx_reshaped,
        tti_idx=slot_idx,
        snr_db=snr_db,
        enable_swap_tx_rx=True)

    # Reshape back to (n_sc, n_symbol, n_rx_ant)
    return rx_signal[0, 0].transpose((2, 1, 0))
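As a sanity check on the Doppler computation above: 0.8333 m/s (about 3 km/h) at a 3.5 GHz carrier gives a maximum Doppler shift of roughly 9.7 Hz. Using the common 0.423/f_d rule of thumb for coherence time (an approximation based on Clarke's model, not part of this example), the channel stays coherent for tens of milliseconds, i.e. many slots:

```python
speed = 0.8333             # UE speed [m/s], ~3 km/h.
carrier_frequency = 3.5e9  # Carrier frequency [Hz].
c = 3e8                    # Speed of light [m/s].

max_doppler_shift = speed * carrier_frequency / c
coherence_time = 0.423 / max_doppler_shift  # Rule-of-thumb approximation.
print(f"{max_doppler_shift:.2f} Hz, {coherence_time * 1e3:.1f} ms")
```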

Run the actual simulation#

Here we loop over the Es/No range and simulate a number of slots for each Es/No value. A single transport block is transmitted per slot. The simulation moves on to the next Es/No value once a minimum number of transport block errors has been reached.
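Stopping at a fixed error count, rather than a fixed number of slots, keeps the relative accuracy of the BLER estimate roughly constant across Es/No points: for a binomially distributed error count, the relative standard deviation of the estimate is approximately 1/sqrt(num_errors). A short illustration (not part of the simulation itself):

```python
import math

min_num_tb_errors = 250

# Relative standard deviation of the BLER estimate when stopping at a
# fixed error count (binomial approximation, accurate for small BLER).
rel_std = 1.0 / math.sqrt(min_num_tb_errors)
print(f"~{100 * rel_std:.1f}% relative standard deviation")  # ~6.3%
```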

[6]:
cases = ["PUSCH Rx", "Neural Rx"]
monitor = SimulationMonitor(cases, esno_db_range)

# Loop the Es/No range.
for esno_db in esno_db_range:
    monitor.step(esno_db)
    num_tb_errors = defaultdict(int)

    # Run multiple slots and compute BLER.
    # We reset the channel for every slot to simulate different channel realizations.
    for slot_idx in range(num_slots):
        channel.reset()

        slot_number = slot_idx % num_slots_per_frame

        # Generate a random transport block.
        tb_input_np = random_tb(
            mod_order=mod_order,
            code_rate=code_rate,
            dmrs_syms=dmrs_syms,
            num_prbs=num_prbs,
            start_sym=start_sym,
            num_symbols=num_symbols,
            num_layers=layers)
        tb_input = cp.array(tb_input_np, dtype=cp.uint8, order='F')

        # Transmit PUSCH. This is where we set the dynamically changing parameters.
        # Input parameters are given as lists as the interface supports multiple UEs.
        tx_tensor = pusch_tx.run(
            tb_inputs=[tb_input],          # Input transport block in bytes.
            num_ues=1,                     # We simulate only one UE here.
            slot=slot_number,              # Slot number.
            num_dmrs_cdm_grps_no_data=num_dmrs_cdm_grps_no_data,
            dmrs_scrm_ids=[dmrs_scrm_id],  # DMRS scrambling ID.
            start_prb=start_prb,           # Start PRB index.
            num_prbs=num_prbs,             # Number of allocated PRBs.
            dmrs_syms=dmrs_syms,           # List of binary numbers indicating which symbols are DMRS.
            start_sym=start_sym,           # Start symbol index.
            num_symbols=num_symbols,       # Number of symbols.
            scids=[scid],                  # DMRS sequence SCID.
            layers=[layers],               # Number of layers (transmission rank).
            dmrs_ports=[dmrs_ports],       # DMRS port(s) to be used.
            rntis=[rnti],                  # UE RNTI.
            data_scids=[data_scid],        # Data scrambling ID.
            code_rates=[code_rate * 10],   # Code rate x 1024 x 10.
            mod_orders=[mod_order]         # Modulation order.
        )

        # Apply fading channel with AWGN noise.
        rx_tensor = apply_channel(tx_tensor, esno_db, slot_idx)

        # Run the fused PUSCH receiver.
        # Note that this is where we set the dynamically changing parameters.
        tb_crcs, tbs = pusch_rx.run(
            rx_slot=rx_tensor,
            slot=slot_number,
            pusch_configs=pusch_configs
        )
        num_tb_errors["PUSCH Rx"] += int(not np.array_equal(tbs[0], tb_input_np))

        # Run the neural receiver.
        tbs = neural_rx.run(
            rx_slot=rx_tensor,
            slot=slot_number,
            pusch_configs=pusch_configs
        )
        num_tb_errors["Neural Rx"] += int(not np.array_equal(tbs[0], tb_input_np))

        monitor.update(num_tbs=slot_idx + 1, num_tb_errors=num_tb_errors)
        if (np.array(list(num_tb_errors.values())) >= min_num_tb_errors).all():
            break  # Next Es/No value.

    monitor.finish_step(num_tbs=slot_idx + 1, num_tb_errors=num_tb_errors)
monitor.finish()
                           PUSCH Rx            Neural Rx
                     -------------------- --------------------
  Es/No (dB)     TBs    TB Errors    BLER    TB Errors    BLER    ms/TB
==================== ==================== ==================== ========
       -8.00     250          250  1.0000          250  1.0000    113.3
       -6.00     280          258  0.9214          250  0.8929    113.0
       -4.00     389          260  0.6684          250  0.6427    113.0
       -2.00     925          272  0.2941          250  0.2703    112.9
        0.00    2726          288  0.1056          250  0.0917    112.9
        2.00    9841          286  0.0291          250  0.0254    112.9
        4.00   10000           66  0.0066           55  0.0055    112.9