Using pyAerial to run CSI-RS transmission and reception
The pyAerial library supports transmission of 5G NR-compliant CSI reference signals (CSI-RS), as well as UE-side channel estimation based on CSI-RS. This example shows how to use the pyAerial cuPHY Python bindings to run CSI-RS transmission and reception with the pyAerial CSI-RS transmitter and receiver pipelines. The notebook runs CSI-RS transmission and reception and plots the estimated channel against the actual channel realization at the given signal-to-noise ratio.
Imports
[1]:
%matplotlib widget
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
# pyAerial imports
from aerial.phy5g.csirs import CsiRsConfig
from aerial.phy5g.csirs import CsiRsTxConfig
from aerial.phy5g.csirs import CsiRsRxConfig
from aerial.phy5g.csirs import CsiRsTx
from aerial.phy5g.csirs import CsiRsRx
from aerial.phy5g.channel_models import FadingChannel
from aerial.phy5g.channel_models import TdlChannelConfig
from aerial.phy5g.channel_models import CdlChannelConfig
from aerial.util.cuda import CudaStream
Parameters
Set channel and numerology parameters.
[2]:
# ---- Link / channel parameters ----
esno_db = 20.0             # Es/No [dB]
num_tx_ant = 8             # transmit antennas at the gNB
num_rx_ant = 2             # receive antennas at the UE
carrier_frequency = 3.5e9  # carrier frequency [Hz]
channel_type = "cdl"       # either "tdl" or "cdl"
delay_profile = "C"        # NLOS profiles 'A', 'B', 'C'; LOS 'D', 'E' are not supported yet
delay_spread = 100.0       # RMS delay spread [ns]
speed = 0.8333             # UE speed [m/s]; drives the Doppler shift computation

# ---- Numerology and frame structure (3GPP TS 38.211) ----
num_symb_per_slot, fft_size, num_prb = 14, 4096, 273
CSI-RS configuration
Set the CSI-RS resource mapping and sequence configuration. Refer to 3GPP TS 38.211 section 7.4.1.5.3, in particular table 7.4.1.5.3-1, for the exact definitions of the fields. The parameterization here closely follows the 3GPP specification.
Note: The CSI-RS type defaults to non-zero-power (NZP) CSI-RS, which is the only type currently supported by cuPHY.
[3]:
# A single CSI-RS resource covering PRBs [0, num_prb). Field meanings follow
# 3GPP TS 38.211 Table 7.4.1.5.3-1 and the CSI-RS RRC resource-mapping fields.
csirs_configs = [
    CsiRsConfig(
        # --- Frequency domain ---
        start_prb=0,          # first PRB of the resource
        num_prb=num_prb,      # number of PRBs spanned
        # Bitmap for the `frequencyDomainAllocation` RRC field.
        freq_alloc=[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        freq_density=2,       # the `density` RRC field
        # --- Pattern and time domain ---
        row=6,                # row number in the CSI-RS location table
        symb_L0=0,            # L0, i.e. `firstOFDMSymbolInTimeDomain`
        symb_L1=8,            # L1, i.e. `firstOFDMSymbolInTimeDomain2`
        cdm_type=1,           # the `cdm-Type` RRC field
        # --- Sequence, slot and power ---
        scramb_id=0,          # scrambling ID
        idx_slot_in_frame=0,  # slot index within the frame
        beta=1.0,             # CSI-RS power control
    )
]
Create CSI-RS transmitter and receiver objects
The CSI-RS transmitter and receiver objects are created here. The dynamically changing slot configurations for both the transmitter and the receiver are also instantiated. These are passed as arguments later, when the transmitter and receiver are actually called.
[4]:
# Stream shared by the CSI-RS transmitter and receiver objects created below.
cuda_stream = CudaStream()

# Slot configurations, handed to the Tx/Rx when they are called. The outer
# list has one entry per cell; the single cell carries all of `csirs_configs`.
# Rx: UE 0 is associated with cell 0.
csirs_rx_config = CsiRsRxConfig(csirs_configs=[csirs_configs], ue_cell_association=[0])
# Tx: an empty precoding list means no precoding is applied.
csirs_tx_config = CsiRsTxConfig(csirs_configs=[csirs_configs], precoding_matrices=[])

# Pipeline objects; `num_prb_dl_bwp` has one entry, matching the single cell above.
csirs_rx = CsiRsRx(num_prb_dl_bwp=[num_prb], cuda_stream=cuda_stream)
csirs_tx = CsiRsTx(num_prb_dl_bwp=[num_prb], cuda_stream=cuda_stream)
Channel modeling
The radio channel is simulated using the pyAerial GPU-accelerated FadingChannel class from aerial.phy5g.channel_models. It supports both TDL (Tapped Delay Line) and CDL (Clustered Delay Line) channel models as defined in 3GPP TR 38.901. The channel operates in the frequency domain: it applies the fading channel to the transmitted signal and adds white Gaussian noise (AWGN) at the specified SNR. Reference signal patterns and data-carrying resource elements are defined elsewhere in the pyAerial code, so here we only pass the number of used subcarriers to the channel model.
[5]:
# Calculate Doppler shift from speed and carrier frequency.
# Maximum Doppler shift f_D = v * f_c / c, with c approximated as 3e8 m/s.
max_doppler_shift = speed * carrier_frequency / 3e8

# Downlink: the gNB transmits (BS side = Tx antennas) and the UE receives
# (UE side = Rx antennas), so no Tx/Rx swap is needed when running the channel.
if channel_type == "tdl":
    channel_config = TdlChannelConfig(
        delay_profile=delay_profile,
        delay_spread=delay_spread,
        max_doppler_shift=max_doppler_shift,
        n_bs_ant=num_tx_ant,  # gNB antennas (transmitter in downlink)
        n_ue_ant=num_rx_ant,  # UE antennas (receiver in downlink)
    )
elif channel_type == "cdl":
    # Both arrays are built as (1, n // 2, 2), i.e. dual-polarized (the
    # trailing 2 is presumably the polarization axis). An odd antenna count,
    # or one below 2, would be silently truncated by `// 2`. The array would
    # then no longer match the num_tx_ant / num_rx_ant used for the signal
    # buffers, so reject such counts up front.
    if num_tx_ant < 2 or num_tx_ant % 2 or num_rx_ant < 2 or num_rx_ant % 2:
        raise ValueError(
            "CDL dual-polarized antenna arrays need an even number (>= 2) of antennas, "
            f"got num_tx_ant={num_tx_ant}, num_rx_ant={num_rx_ant}."
        )
    channel_config = CdlChannelConfig(
        delay_profile=delay_profile,
        delay_spread=delay_spread,
        max_doppler_shift=max_doppler_shift,
        bs_ant_size=(1, num_tx_ant // 2, 2),  # gNB antenna array (dual-pol)
        ue_ant_size=(1, num_rx_ant // 2, 2),  # UE antenna array (dual-pol)
    )
else:
    raise ValueError(f"Invalid channel type: {channel_type}. Use 'tdl' or 'cdl'.")
# Create FadingChannel with OFDM parameters.
# Active subcarriers handed to the channel: 12 per PRB.
n_sc = 12 * num_prb

# OFDM-level fading channel. Numerology 1 corresponds to 30 kHz subcarrier spacing.
channel = FadingChannel(
    channel_config=channel_config,
    numerology=1,
    n_fft=fft_size,
    n_sc=n_sc,
    n_symbol_slot=num_symb_per_slot,
)
def apply_channel(tx_signal, snr_db, slot_idx):
    """Run a frequency-domain slot through the fading channel and add noise.

    Args:
        tx_signal: Transmitted slot, shape (n_sc, n_symbol, n_tx_ant).
        snr_db: SNR in dB at which the channel adds noise.
        slot_idx: Slot (TTI) index passed to the channel for time variation.

    Returns:
        Tuple ``(rx_signal, channel_freq_response)``:
        - rx_signal: received slot after channel and noise,
          shape (n_sc, n_symbol, n_rx_ant).
        - channel_freq_response: CFR, intended shape
          (n_sc, n_symbol, n_tx_ant, n_rx_ant).
    """
    # Input layout for FadingChannel: (n_cell, n_ue, n_tx_ant, n_symbol, n_sc).
    # Reverse the three axes, then prepend singleton cell and UE axes.
    channel_input = tx_signal.transpose((2, 1, 0))[None, None, ...]

    # Downlink: the gNB is the transmitter, so Tx/Rx are not swapped.
    channel_output = channel(
        freq_in=channel_input,
        tti_idx=slot_idx,
        snr_db=snr_db,
        enable_swap_tx_rx=False,
    )

    # Channel frequency response, originally documented as
    # (n_cell, n_ue, n_rx_ant, n_tx_ant, n_symbol, n_sc).
    # NOTE(review): that layout disagrees with the (3, 0, 2, 1) permutation
    # below. Applied to cfr[0, 0] = (n_rx_ant, n_tx_ant, n_symbol, n_sc), the
    # permutation yields (n_sc, n_rx_ant, n_symbol, n_tx_ant), not the documented
    # (n_sc, n_symbol, n_tx_ant, n_rx_ant). It only produces the documented
    # layout if cfr[0, 0] is (n_symbol, n_rx_ant, n_tx_ant, n_sc).
    # Confirm against the FadingChannel API; if the layout above is right,
    # the permutation should be (3, 2, 1, 0).
    full_cfr = channel.get_channel_frequency_response()

    # Drop the cell/UE axes and restore subcarrier-major ordering.
    rx_out = channel_output[0, 0].transpose((2, 1, 0))
    cfr_out = full_cfr[0, 0].transpose((3, 0, 2, 1))
    return rx_out, cfr_out
Run the CSI-RS transmission and reception
Run the CSI-RS generation at the transmitter side, pass the frequency-domain slot signal through the radio channel, and run the UE-side CSI-RS channel estimation.
[6]:
# Transmitter: write the CSI-RS into an all-zero frequency-domain slot grid of
# shape (subcarriers, OFDM symbols, Tx antennas). Keep the first returned buffer.
tx_buffer = csirs_tx(
    config=csirs_tx_config,
    tx_buffers=[cp.zeros((num_prb * 12, num_symb_per_slot, num_tx_ant), dtype=cp.complex64)],
)[0]

# Channel: fading plus noise at `esno_db` for slot 0; also returns the CFR.
rx_data, cfr = apply_channel(tx_buffer, esno_db, slot_idx=0)

# Host copy of the CFR for plotting: `.get()` when available, else np.array.
cfr = cfr.get() if hasattr(cfr, 'get') else np.array(cfr)

# Receiver: CSI-RS channel estimation; keep entry [0][0] and move it to host.
ch_est = csirs_rx(rx_data=[rx_data], config=csirs_rx_config)[0][0].get()
Plot channel estimation results
[7]:
# Number of first PRBs to plot (for better visualization)
# Plot only the first PRBs so individual estimates stay distinguishable.
num_prb_to_plot = 20
# The reference channel is sampled at the first subcarrier of each plotted PRB.
subc_idx = np.arange(0, num_prb_to_plot * 12, 12)

# y-limits shared by all antenna pairs, taken from the plotted estimates.
# Computed once, outside the loops. The imaginary-part limits use the
# imaginary part for BOTH ends; the original used the real-part maximum
# as the upper limit.
ch_est_plotted = ch_est[:num_prb_to_plot, ...]
real_ylim = (np.real(ch_est_plotted).min(), np.real(ch_est_plotted).max())
imag_ylim = (np.imag(ch_est_plotted).min(), np.imag(ch_est_plotted).max())

for tx_ant in range(num_tx_ant):
    for rx_ant in range(num_rx_ant):
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        fig.suptitle(f"CSI-RS channel estimates for Tx antenna {tx_ant} / Rx antenna {rx_ant}")
        # ch_est is indexed here as (PRB, Tx antenna, Rx antenna). The reference
        # channel is taken at OFDM symbol 0, where the CSI-RS is configured (symb_L0=0).
        est = ch_est[:num_prb_to_plot, tx_ant, rx_ant]
        ref = cfr[subc_idx, 0, tx_ant, rx_ant]
        panels = (
            (axs[0], np.real, "Real part", real_ylim),
            (axs[1], np.imag, "Imaginary part", imag_ylim),
        )
        for ax, part, title, ylim in panels:
            ax.plot(part(est), 'bo', label='Channel estimates')
            ax.plot(part(ref), 'k:', label='Channel')
            ax.set_title(title)
            ax.set_ylim(*ylim)
            ax.set_xlim(0, num_prb_to_plot)
            ax.set_xlabel('PRB index')
            ax.grid(True)
            ax.legend()
        plt.show()