PUSCH
JAX implementations of Physical Uplink Shared Channel algorithms.
Overview
The PUSCH JAX module provides differentiable implementations of the inner receiver signal processing chain. These implementations can be lowered to TensorRT for high-performance execution and are organized into signal processing stages:
Inner Receiver (Signal Processing)
Channel Estimation - Traditional and neural network-based channel estimators that can be trained end-to-end
AI Tukey Filter - ML-based channel estimation filter with pretrained models for improved performance in challenging channel conditions
Noise Estimation - Noise covariance estimation for MMSE equalization
Delay Compensation - Time-domain delay correction
Equalization - MMSE channel equalization
Free Energy Filter - Advanced filtering for channel estimation refinement
Soft Demapping - LLR generation from equalized symbols
Signal Quality Metrics - Noise variance, RSRP, and SINR computation
End-to-End Processing
Complete Inner Receiver - Full signal processing pipeline that can be lowered to TensorRT for real-time execution
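To illustrate what "differentiable" and "trained end-to-end" mean in practice, here is a minimal sketch that is not taken from the module. It tunes a single smoothing weight of a toy channel-estimate filter by gradient descent using jax.grad. The names smooth_estimate and mse_loss, the toy channel, and the learning rate are all hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def smooth_estimate(alpha, h_ls):
    """Blend each subcarrier's estimate with the mean of its two neighbors."""
    neighbors = 0.5 * (jnp.roll(h_ls, 1) + jnp.roll(h_ls, -1))
    return (1.0 - alpha) * h_ls + alpha * neighbors


def mse_loss(alpha, h_ls, h_true):
    """Mean squared error between the filtered estimate and the true channel."""
    return jnp.mean(jnp.abs(smooth_estimate(alpha, h_ls) - h_true) ** 2)


# Toy data: a slowly rotating channel plus complex Gaussian estimation noise.
n_sc = 256
k = jnp.arange(n_sc)
h_true = jnp.exp(2j * jnp.pi * 2 * k / n_sc)
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
noise = 0.3 * (
    jax.random.normal(key_re, (n_sc,)) + 1j * jax.random.normal(key_im, (n_sc,))
)
h_ls = h_true + noise

# Gradient of the loss with respect to alpha (the first argument).
grad_fn = jax.jit(jax.grad(mse_loss))

alpha = jnp.float32(0.0)  # start with no smoothing
initial_loss = mse_loss(alpha, h_ls, h_true)
for _ in range(20):
    alpha = alpha - 1.0 * grad_fn(alpha, h_ls, h_true)  # plain gradient descent
final_loss = mse_loss(alpha, h_ls, h_true)
```

The same pattern extends to neural estimators: replace the scalar alpha with a parameter pytree and the toy loss with a link-level metric.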
API Reference
PUSCH optimized package.
- ran.phy.jax.pusch.awgn(rng: jax.Array, H: jax.Array, snr_db: float) -> jax.Array
Add additive white Gaussian noise (AWGN) to a channel (JAX version).
- Parameters:
rng – JAX PRNG key
H – Complex-valued channel with shape (n_sc, n_sym, n_ant)
snr_db – SNR in dB
- Returns:
Noisy channel
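The following is a minimal sketch of how AWGN at a target SNR can be added to a complex channel tensor in JAX. It is not the implementation behind ran.phy.jax.pusch.awgn. The function name add_awgn_sketch is hypothetical, and defining the SNR relative to the average power of H is an assumption.

```python
import jax
import jax.numpy as jnp


def add_awgn_sketch(rng: jax.Array, H: jax.Array, snr_db: float) -> jax.Array:
    """Illustrative AWGN model (assumed behavior, not the library's code)."""
    # Average signal power over all elements of the channel tensor.
    signal_power = jnp.mean(jnp.abs(H) ** 2)
    # Noise power that yields the requested SNR.
    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    # Circularly symmetric complex Gaussian: half the power in each of I and Q.
    rng_re, rng_im = jax.random.split(rng)
    scale = jnp.sqrt(noise_power / 2.0)
    noise = scale * (
        jax.random.normal(rng_re, H.shape) + 1j * jax.random.normal(rng_im, H.shape)
    )
    return H + noise.astype(H.dtype)


key = jax.random.PRNGKey(0)
H = jnp.ones((48, 14, 4), dtype=jnp.complex64)  # (n_sc, n_sym, n_ant)
H_noisy = add_awgn_sketch(key, H, snr_db=10.0)

# Empirical SNR of the realized noise; it should be close to the 10 dB target.
measured_snr_db = 10.0 * jnp.log10(
    jnp.mean(jnp.abs(H) ** 2) / jnp.mean(jnp.abs(H_noisy - H) ** 2)
)
```

Passing the PRNG key explicitly keeps the function pure, so the same key always reproduces the same noise realization.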
- ran.phy.jax.pusch.pusch_inner_rx(
    xtf__rxant_sym_sc_ri: jax.Array,
    slot_number: jax.numpy.int32,
    n_dmrs_id: jax.numpy.int32,
    rww_regularizer_val: jax.numpy.float32,
    start_prb: jax.numpy.int32,
    nl_offset: jax.numpy.int32,
    scids: tuple,
    apply_cov_shrinkage: bool,
    channel_filter_method: str,
    qam_bits: jax.numpy.int32,
    dmrs_sym_idxs: tuple,
    data_sym_idxs: tuple,
    dmrs_port_nums: tuple,
    layer2ue: tuple,
    n_prb: jax.numpy.int32,
    n_ue: jax.numpy.int32,
    n_f: jax.numpy.int32,
    n_t: jax.numpy.int32,
    energy: jax.numpy.float32,
    channel_filter_config: AITukeyFilterConfig | FreeEnergyFilterConfig | IdentityFilterConfig | WeightedThresholdFilterConfig | None = None,
)
PUSCH Inner Receiver function.
The PUSCH inner receiver performs the following steps:
DMRS-based channel estimation and covariance estimation
MMSE-IRC equalization
Soft demapping and LLR generation
The function returns LLRs and post-equalization noise variance and SINR estimates.
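The equalization and demapping steps can be sketched for a single resource element as follows. This is a textbook formulation under stated assumptions, not the code inside pusch_inner_rx. The function names mmse_irc_equalize and qpsk_llr, the unit-power symbol assumption, the Gray mapping, and the LLR sign convention (positive means bit 0) are all assumptions made for this example.

```python
import jax.numpy as jnp


def mmse_irc_equalize(y, H, R):
    """y: (n_rx,), H: (n_rx, n_layer), R: (n_rx, n_rx) Hermitian
    noise-plus-interference covariance. Returns (n_layer,) symbol estimates."""
    # R^{-1} H via a linear solve instead of an explicit inverse.
    Rinv_H = jnp.linalg.solve(R, H)
    gram = H.conj().T @ Rinv_H  # H^H R^{-1} H
    n_layer = H.shape[1]
    # W = (H^H R^{-1} H + I)^{-1} H^H R^{-1}, assuming unit-power symbols.
    W = jnp.linalg.solve(gram + jnp.eye(n_layer, dtype=H.dtype), Rinv_H.conj().T)
    return W @ y


def qpsk_llr(x_hat, noise_var):
    """Exact per-bit LLRs for Gray-mapped QPSK in complex Gaussian noise.
    Assumed mapping: x = ((1 - 2*b0) + 1j*(1 - 2*b1)) / sqrt(2)."""
    scale = 2.0 * jnp.sqrt(2.0) / noise_var
    return jnp.stack([scale * x_hat.real, scale * x_hat.imag], axis=-1)


# Deterministic 4x2 channel with orthogonal columns (H^H H = 3 I).
H = jnp.array(
    [[1, 0], [0, 1], [1, 1j], [1j, 1]], dtype=jnp.complex64
)
bits = jnp.array([[0, 1], [1, 0]])  # (n_layer, qam_bits)
x = (((1 - 2 * bits[:, 0]) + 1j * (1 - 2 * bits[:, 1])) / jnp.sqrt(2.0)).astype(
    jnp.complex64
)
noise_var = 1e-3
R = noise_var * jnp.eye(4, dtype=jnp.complex64)  # white noise, no interference

y = H @ x  # noise-free received vector for a clean check
x_hat = mmse_irc_equalize(y, H, R)
llr = qpsk_llr(x_hat, noise_var)  # noise_var used as a stand-in post-eq variance
hard_bits = (llr < 0).astype(jnp.int32)  # negative LLR -> bit 1
```

With a non-diagonal R, the same equalizer suppresses spatially colored interference, which is what distinguishes MMSE-IRC from plain MMSE.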
The function can be compiled with MLIR-TensorRT to a single TensorRT engine for inclusion in higher-performance C++ pipelines.
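The function takes both dynamic and static arguments. Static arguments are fixed at compile time, so changing one requires recompilation; in the parameter list below, every argument except xtf__rxant_sym_sc_ri is marked as a compile-time constant.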
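As a rough illustration of the static/dynamic split, the sketch below marks arguments as static with jax.jit. The helper select_data_symbols is hypothetical and is not part of this package; it only shows why static arguments are passed as hashable Python values such as tuples and ints, and how they can determine output shapes at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("data_sym_idxs", "n_prb"))
def select_data_symbols(xtf, data_sym_idxs, n_prb):
    """xtf: (n_rxant, n_sym, n_sc, 2). Static args are plain Python values
    while tracing, so they can fix index sets and output shapes."""
    n_allocsc = 12 * n_prb  # 12 subcarriers per PRB
    picked = jnp.take(xtf, jnp.array(data_sym_idxs), axis=1)
    return picked[:, :, :n_allocsc, :]


xtf = jnp.zeros((4, 14, 48, 2), dtype=jnp.float32)  # dynamic input
out = select_data_symbols(xtf, data_sym_idxs=(0, 1, 3, 4), n_prb=2)
# out.shape is (4, 4, 24, 2)

# A different static value triggers a new trace and compilation.
out_wide = select_data_symbols(xtf, data_sym_idxs=(0, 1, 3, 4), n_prb=4)
```

Static arguments must be hashable, which is one reason index lists are given as tuples rather than arrays.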
- Parameters:
xtf__rxant_sym_sc_ri (Array) – Received resource grid with shape (n_rxant, n_sym, n_sc, 2).
slot_number (jnp.int32) – Slot number (compile-time constant).
n_dmrs_id (jnp.int32) – DMRS identity (compile-time constant).
rww_regularizer_val (jnp.float32) – Regularization value for covariance matrix (compile-time constant).
start_prb (jnp.int32) – 0-based starting PRB index for the allocation (compile-time constant).
nl_offset (jnp.int32) – Layer offset for multi-layer processing (compile-time constant).
scids (tuple) – SCID selection (0 or 1) per layer as tuple (compile-time constant).
apply_cov_shrinkage (bool) – Whether to apply RBLW shrinkage to covariance (compile-time constant).
channel_filter_method (str) – Channel filter method: ‘free_energy_filter’ or ‘ai_tukey_filter’ (compile-time constant).
qam_bits (jnp.int32) – QAM modulation order (bits per symbol): 1, 2, 4, 6, or 8 (compile-time constant).
dmrs_sym_idxs (tuple) – DMRS symbol indices as tuple (compile-time constant).
data_sym_idxs (tuple) – Data symbol indices as tuple (compile-time constant).
dmrs_port_nums (tuple) – Per-layer DMRS port bitfield as tuple (compile-time constant).
layer2ue (tuple) – Mapping from layer index to UE index as tuple (compile-time constant).
n_prb (jnp.int32) – Number of PRBs in the allocation (compile-time constant).
n_ue (jnp.int32) – Number of UEs (compile-time constant).
n_f (jnp.int32) – Number of subcarriers in the full resource grid (compile-time constant).
n_t (jnp.int32) – Number of OFDM symbols per slot (compile-time constant).
energy (jnp.float32) – Energy scaling factor for DMRS transmission (compile-time constant).
channel_filter_config (ChannelFilterConfig | None, optional) – Configuration for the channel filter. Required when channel_filter_method is ‘ai_tukey_filter’ (compile-time constant).
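To give a sense of what covariance regularization and shrinkage do, the sketch below applies diagonal loading and a shrinkage toward a scaled identity to a rank-deficient sample covariance. Treating rww_regularizer_val as a diagonal loading term is an assumption, and the shrinkage coefficient here is a fixed, hand-picked value rather than the RBLW-derived one. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_covariance(noise):
    """noise: (n_rx, n_samples) complex residuals -> (n_rx, n_rx) covariance."""
    return noise @ noise.conj().T / noise.shape[1]


def regularize(R, reg_val):
    """Diagonal loading: R + reg_val * I (assumed role of the regularizer)."""
    return R + reg_val * jnp.eye(R.shape[0], dtype=R.dtype)


def shrink_to_identity(R, rho):
    """Convex blend of R with a scaled identity of equal trace, 0 <= rho <= 1."""
    n = R.shape[0]
    target = (jnp.trace(R).real / n) * jnp.eye(n, dtype=R.dtype)
    return (1.0 - rho) * R + rho * target


# Two samples for four antennas: the sample covariance has rank at most 2.
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
noise = (
    jax.random.normal(key_re, (4, 2)) + 1j * jax.random.normal(key_im, (4, 2))
).astype(jnp.complex64)

R = sample_covariance(noise)
R_reg = regularize(R, 1e-3)
R_shrunk = shrink_to_identity(R, rho=0.2)  # rho chosen by hand for illustration

min_eig_reg = jnp.linalg.eigvalsh(R_reg).min()
min_eig_shrunk = jnp.linalg.eigvalsh(R_shrunk).min()
```

Both operations make the matrix safely invertible, which matters when it is used inside an MMSE-IRC equalizer with few DMRS samples.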
- Returns:
llr__time_allocfreq_layer_qambit (Array) – LLRs with shape (n_datasym, n_allocsc, n_layer, qam_bits).
post_eq_noise_var_db__ue (Array) – Post-equalization noise variance per UE with shape (n_ue,).
post_eq_sinr_db__ue (Array) – Post-equalization SINR per UE with shape (n_ue,).
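The input layout and the dB-valued outputs can be handled with small helpers like the ones below. This is a sketch under assumptions: the trailing axis of size 2 is taken to hold real and imaginary parts in that order (suggested by the "ri" suffix but not stated above), and the helper names are hypothetical.

```python
import jax.numpy as jnp


def complex_to_ri(x):
    """(...,) complex -> (..., 2) real, assumed order [real, imag]."""
    return jnp.stack([x.real, x.imag], axis=-1)


def ri_to_complex(x_ri):
    """(..., 2) real -> (...,) complex, inverse of complex_to_ri."""
    return x_ri[..., 0] + 1j * x_ri[..., 1]


def linear_to_db(p):
    """Power ratio -> decibels."""
    return 10.0 * jnp.log10(p)


n_rxant, n_sym, n_sc = 4, 14, 48
grid = jnp.arange(n_rxant * n_sym * n_sc, dtype=jnp.float32).reshape(
    n_rxant, n_sym, n_sc
)
grid = (grid * (1.0 + 0.5j)).astype(jnp.complex64)

xtf_ri = complex_to_ri(grid)       # shape (4, 14, 48, 2)
roundtrip = ri_to_complex(xtf_ri)  # back to shape (4, 14, 48), complex

# Linear SINRs of 100 and 10 correspond to 20 dB and 10 dB.
sinr_db = linear_to_db(jnp.array([100.0, 10.0]))
```

Splitting real and imaginary parts keeps the input real-valued, which can simplify hand-off to runtimes that lack native complex tensor support.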