PUSCH

JAX implementations of Physical Uplink Shared Channel algorithms.

Overview

The PUSCH JAX module provides differentiable implementations of the inner receiver signal processing chain. These implementations can be lowered to TensorRT for high-performance execution and are organized into signal processing stages:

Inner Receiver (Signal Processing)

  • Channel Estimation - Traditional and neural network-based channel estimators that can be trained end-to-end

  • AI Tukey Filter - ML-based channel estimation filter with pretrained models for improved performance in challenging channel conditions

  • Noise Estimation - Noise covariance estimation for MMSE equalization

  • Delay Compensation - Time-domain delay correction

  • Equalization - MMSE channel equalization

  • Free Energy Filter - Advanced filtering for channel estimation refinement

  • Soft Demapping - LLR generation from equalized symbols

  • Signal Quality Metrics - Noise variance, RSRP, and SINR computation

End-to-End Processing

  • Complete Inner Receiver - Full signal processing pipeline that can be lowered to TensorRT for real-time execution

API Reference

PUSCH optimized package.

ran.phy.jax.pusch.awgn(rng: jax.Array, H: jax.Array, snr_db: float) → jax.Array

Add additive white Gaussian noise (AWGN) to a channel tensor (JAX version).

Parameters:
  • rng – JAX PRNG key

  • H – Complex channel with shape (n_sc, n_sym, n_ant)

  • snr_db – SNR in dB

Returns:

Noisy channel
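The following is a minimal sketch of how noise at a target SNR can be added to a complex channel tensor in JAX. It is not the library's implementation: `awgn_sketch` is a hypothetical stand-in, and it assumes that the SNR is defined relative to the measured mean power of `H`, which the documentation above does not specify.

```python
import jax
import jax.numpy as jnp


def awgn_sketch(rng: jax.Array, H: jax.Array, snr_db: float) -> jax.Array:
    """Hypothetical stand-in: add complex Gaussian noise at a target SNR."""
    # Assumption: SNR is measured against the mean power of H itself.
    signal_power = jnp.mean(jnp.abs(H) ** 2)
    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    # Independent keys for the real and imaginary noise components.
    rng_re, rng_im = jax.random.split(rng)
    # Real and imaginary parts each carry half of the noise power.
    scale = jnp.sqrt(noise_power / 2.0)
    noise = scale * (
        jax.random.normal(rng_re, H.shape) + 1j * jax.random.normal(rng_im, H.shape)
    )
    return H + noise.astype(H.dtype)


key = jax.random.PRNGKey(0)
H = jnp.ones((48, 14, 4), dtype=jnp.complex64)  # (n_sc, n_sym, n_ant)
H_noisy = awgn_sketch(key, H, snr_db=10.0)

# Empirical SNR of the realization; it should land near the 10 dB target.
measured_snr_db = 10.0 * jnp.log10(
    jnp.mean(jnp.abs(H) ** 2) / jnp.mean(jnp.abs(H_noisy - H) ** 2)
)
```

Because the PRNG key is passed explicitly, the same key always reproduces the same noise realization, which is convenient when generating training data.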

ran.phy.jax.pusch.pusch_inner_rx(
xtf__rxant_sym_sc_ri: jax.Array,
slot_number: jax.numpy.int32,
n_dmrs_id: jax.numpy.int32,
rww_regularizer_val: jax.numpy.float32,
start_prb: jax.numpy.int32,
nl_offset: jax.numpy.int32,
scids: tuple,
apply_cov_shrinkage: bool,
channel_filter_method: str,
qam_bits: jax.numpy.int32,
dmrs_sym_idxs: tuple,
data_sym_idxs: tuple,
dmrs_port_nums: tuple,
layer2ue: tuple,
n_prb: jax.numpy.int32,
n_ue: jax.numpy.int32,
n_f: jax.numpy.int32,
n_t: jax.numpy.int32,
energy: jax.numpy.float32,
channel_filter_config: AITukeyFilterConfig | FreeEnergyFilterConfig | IdentityFilterConfig | WeightedThresholdFilterConfig | None = None,
) → tuple[jax.Array, jax.Array, jax.Array]

PUSCH Inner Receiver function.

The PUSCH inner receiver performs the following steps:

  1. DMRS-based channel estimation and covariance estimation

  2. MMSE-IRC equalization

  3. Soft demapping and LLR generation
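Steps 2 and 3 above can be sketched for a single resource element as follows. This is an illustrative textbook formulation, not the code inside `pusch_inner_rx`: the helper names `mmse_irc_equalize` and `qpsk_llr` are hypothetical, the transmit layers are assumed to have unit power, and the LLR sign convention (positive favors bit 0) is an assumption.

```python
import jax.numpy as jnp


def mmse_irc_equalize(y, H, R_ww):
    """Equalize one resource element with an MMSE-IRC filter (sketch).

    y:    (n_rxant,) received samples
    H:    (n_rxant, n_layer) channel estimate
    R_ww: (n_rxant, n_rxant) Hermitian noise-plus-interference covariance
    """
    n_layer = H.shape[1]
    # H^H R_ww^{-1}, computed with a linear solve instead of an explicit inverse.
    HhRinv = jnp.linalg.solve(R_ww, H).conj().T
    # Assumes unit-power, uncorrelated transmit layers.
    G = HhRinv @ H + jnp.eye(n_layer, dtype=H.dtype)
    x_hat = jnp.linalg.solve(G, HhRinv @ y)
    # The diagonal of G^{-1} is the per-layer post-equalization error variance.
    err_var = jnp.real(jnp.diagonal(jnp.linalg.inv(G)))
    return x_hat, err_var


def qpsk_llr(x_hat, err_var):
    """LLRs for Gray-mapped QPSK; positive values favor bit 0 (assumed convention)."""
    scale = 2.0 * jnp.sqrt(2.0) / err_var
    return jnp.stack([scale * x_hat.real, scale * x_hat.imag], axis=-1)


# Toy setup: 4 receive antennas, 2 layers, white noise covariance.
H = jnp.array(
    [[1.0, 0.2j], [0.1, 1.0], [0.3, -0.2], [0.5j, 0.4]], dtype=jnp.complex64
)
R_ww = 0.01 * jnp.eye(4, dtype=jnp.complex64)
bits = jnp.array([[0, 1], [1, 0]])  # (n_layer, qam_bits)
x = ((1 - 2 * bits[:, 0]) + 1j * (1 - 2 * bits[:, 1])) / jnp.sqrt(2.0)
y = H @ x  # noise-free observation keeps the example deterministic

x_hat, err_var = mmse_irc_equalize(y, H, R_ww)
llr = qpsk_llr(x_hat, err_var)  # (n_layer, qam_bits)
hard_bits = (llr < 0).astype(jnp.int32)
```

With an interference-free, white `R_ww` as in this toy setup, the filter reduces to a conventional MMSE equalizer; the IRC behavior comes from off-diagonal covariance terms.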

The function returns LLRs along with per-UE post-equalization noise variance and SINR estimates.

The function can be compiled with MLIR-TensorRT to a single TensorRT engine for inclusion in higher-performance C++ pipelines.

The function takes both dynamic and static arguments. Static arguments are fixed at compile time and are marked as compile-time constants in the parameter list below.
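The split between dynamic and static arguments can be illustrated with `jax.jit` and its `static_argnames` parameter. The function below is a hypothetical toy, not part of the package. It only shows the general JAX mechanism: static arguments must be hashable, which is one reason index lists are passed as tuples, and each new combination of static values triggers a fresh compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: only `xtf` is a traced (dynamic) argument.
@partial(jax.jit, static_argnames=("data_sym_idxs", "start_prb", "n_prb"))
def extract_data_res(xtf, data_sym_idxs, start_prb, n_prb):
    """Select the data symbols and allocated subcarriers from a grid.

    xtf: (n_rxant, n_sym, n_sc) array; the remaining arguments are Python
    values that are baked into the compiled program.
    """
    sc0 = 12 * start_prb  # 12 subcarriers per PRB
    sym = jnp.asarray(data_sym_idxs)  # constant inside the compiled function
    return xtf[:, sym, sc0 : sc0 + 12 * n_prb]


xtf = jnp.arange(4 * 14 * 48, dtype=jnp.float32).reshape(4, 14, 48)
out = extract_data_res(xtf, data_sym_idxs=(0, 1, 3), start_prb=1, n_prb=2)
```

Calling the function again with a different `xtf` of the same shape and dtype reuses the compiled program, while changing `n_prb` or `data_sym_idxs` compiles a new one.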

Parameters:
  • xtf__rxant_sym_sc_ri (Array) – Received resource grid with shape (n_rxant, n_sym, n_sc, 2).

  • slot_number (jnp.int32) – Slot number (compile-time constant).

  • n_dmrs_id (jnp.int32) – DMRS identity (compile-time constant).

  • rww_regularizer_val (jnp.float32) – Regularization value for covariance matrix (compile-time constant).

  • start_prb (jnp.int32) – 0-based starting PRB index for the allocation (compile-time constant).

  • nl_offset (jnp.int32) – Layer offset for multi-layer processing (compile-time constant).

  • scids (tuple) – SCID selection (0 or 1) per layer as tuple (compile-time constant).

  • apply_cov_shrinkage (bool) – Whether to apply RBLW shrinkage to covariance (compile-time constant).

  • channel_filter_method (str) – Channel filter method: ‘free_energy_filter’ or ‘ai_tukey_filter’ (compile-time constant).

  • qam_bits (jnp.int32) – QAM modulation order (bits per symbol): 1, 2, 4, 6, or 8 (compile-time constant).

  • dmrs_sym_idxs (tuple) – DMRS symbol indices as tuple (compile-time constant).

  • data_sym_idxs (tuple) – Data symbol indices as tuple (compile-time constant).

  • dmrs_port_nums (tuple) – Per-layer DMRS port bitfield as tuple (compile-time constant).

  • layer2ue (tuple) – Mapping from layer index to UE index as tuple (compile-time constant).

  • n_prb (jnp.int32) – Number of PRBs in the allocation (compile-time constant).

  • n_ue (jnp.int32) – Number of UEs (compile-time constant).

  • n_f (jnp.int32) – Number of subcarriers in the full resource grid (compile-time constant).

  • n_t (jnp.int32) – Number of OFDM symbols per slot (compile-time constant).

  • energy (jnp.float32) – Energy scaling factor for DMRS transmission (compile-time constant).

  • channel_filter_config (ChannelFilterConfig | None, optional) – Configuration for channel filter. Required when channel_filter_method is ‘ai_tukey_filter’ (compile-time constant).
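The received grid `xtf__rxant_sym_sc_ri` carries a trailing axis of size 2. The sketch below assumes that this axis holds the real and imaginary parts, in that order, and shows how a complex grid could be packed into that layout and unpacked again. The helper names are hypothetical and the grid dimensions are arbitrary toy values.

```python
import jax
import jax.numpy as jnp


def complex_to_ri(x):
    """Pack a complex array (...,) into a real array (..., 2) as [real, imag]."""
    return jnp.stack([x.real, x.imag], axis=-1)


def ri_to_complex(x_ri):
    """Inverse of complex_to_ri."""
    return x_ri[..., 0] + 1j * x_ri[..., 1]


n_rxant, n_sym, n_sc = 4, 14, 48  # toy dimensions
k_re, k_im = jax.random.split(jax.random.PRNGKey(0))
grid = (
    jax.random.normal(k_re, (n_rxant, n_sym, n_sc))
    + 1j * jax.random.normal(k_im, (n_rxant, n_sym, n_sc))
).astype(jnp.complex64)

xtf_ri = complex_to_ri(grid)  # (n_rxant, n_sym, n_sc, 2), float32
grid_back = ri_to_complex(xtf_ri)
```

Keeping the input purely real-valued avoids relying on complex tensor support in the compiled engine; whether that is the motivation here is not stated in the source.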

Returns:

  • llr__time_allocfreq_layer_qambit (Array) – LLRs with shape (n_datasym, n_allocsc, n_layer, qam_bits).

  • post_eq_noise_var_db__ue (Array) – Post-equalization noise variance per UE with shape (n_ue,).

  • post_eq_sinr_db__ue (Array) – Post-equalization SINR per UE with shape (n_ue,).
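As a rough illustration of how these outputs might be consumed, the sketch below uses synthetic arrays with the documented shapes rather than real receiver output. It selects the layers that belong to one UE via a `layer2ue` tuple, takes hard decisions from the LLRs, and converts the dB-valued SINR to linear scale. The helper `llrs_for_ue` is hypothetical, and the rule that a negative LLR means bit 1 is an assumed sign convention.

```python
import jax.numpy as jnp

# Synthetic stand-ins with the documented output shapes.
n_datasym, n_allocsc, n_layer, qam_bits = 3, 24, 2, 2
layer2ue = (0, 1)  # hypothetical mapping: layer 0 -> UE 0, layer 1 -> UE 1
llr = jnp.linspace(
    -4.0, 4.0, n_datasym * n_allocsc * n_layer * qam_bits
).reshape(n_datasym, n_allocsc, n_layer, qam_bits)
post_eq_sinr_db = jnp.array([20.0, 10.0])  # shape (n_ue,)


def llrs_for_ue(llr, layer2ue, ue):
    """Keep only the layers that layer2ue assigns to the given UE."""
    layers = jnp.array(
        [l for l, u in enumerate(layer2ue) if u == ue], dtype=jnp.int32
    )
    return llr[:, :, layers, :]


llr_ue0 = llrs_for_ue(llr, layer2ue, ue=0)
# Assumed convention: negative LLR -> bit 1, otherwise bit 0.
hard_bits_ue0 = (llr_ue0 < 0).astype(jnp.uint8)
# dB to linear scale.
post_eq_sinr_lin = 10.0 ** (post_eq_sinr_db / 10.0)
```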