Source code for ran.phy.jax.pusch.pusch_inner_receiver

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimized PUSCH Rx implementation using JAX and MLIR-TensorRT."""

from dataclasses import dataclass, fields

import numpy as np
from jax import Array
from jax import numpy as jnp

from ran.phy.jax.pusch.channel_estimation import (
    ChannelFilterConfig,
    channel_estimator,
)
from ran.phy.jax.pusch.equalizer import equalizer
from ran.phy.jax.pusch.soft_demapper import soft_demapper

# -----------------------------------------------------------------------------
# Data classes for static and dynamic inputs
# -----------------------------------------------------------------------------


@dataclass
class PuschInnerRxDynamicInputs:
    """Dynamic inputs passed at runtime to TensorRT.

    Attributes
    ----------
    xtf__rxant_sym_sc_ri : np.ndarray
        Received resource grid.
    """

    xtf__rxant_sym_sc_ri: np.ndarray

    def to_tuple(self) -> tuple:
        """Convert to ordered tuple for positional function call.

        Returns
        -------
        tuple
            Ordered tuple of field values.
        """
        return tuple(getattr(self, field.name) for field in fields(self))

    def __str__(self) -> str:
        """Return information about inputs."""
        lines = [f"\nDynamic inputs ({len(fields(self))} total):"]
        for i, field in enumerate(fields(self)):
            value = getattr(self, field.name)
            if hasattr(value, "shape"):
                shape_str = f"shape={value.shape}"
            else:
                shape_str = f"value={value}"
            dtype_str = ""
            if hasattr(value, "dtype"):
                dtype_str = f", dtype={value.dtype}"
            lines.append(f"  [{i}] {field.name}: {shape_str}{dtype_str}")
        return "\n".join(lines)


@dataclass
class PuschInnerRxStaticInputs:
    """Static inputs compiled as constants in TensorRT engine.

    These inputs are not passed at runtime.

    Attributes
    ----------
    slot_number : int
        Slot number.
    n_dmrs_id : int
        DMRS identity.
    rww_regularizer_val : float
        Regularization value for covariance matrix.
    start_prb : int
        Starting PRB index (0-based).
    nl_offset : int
        Layer offset for multi-layer processing.
    scids : tuple
        SCID selection per layer.
    apply_cov_shrinkage : bool
        Whether to apply RBLW shrinkage to covariance.
    channel_filter_method : str
        Channel filter method: 'free_energy_filter' or 'ai_tukey_filter'.
    qam_bits : int
        QAM modulation order (bits per symbol).
    dmrs_sym_idxs : tuple
        DMRS symbol indices.
    data_sym_idxs : tuple
        Data symbol indices.
    dmrs_port_nums : tuple
        Per-layer DMRS port bitfield.
    layer2ue : tuple
        Mapping from layer index to UE index.
    n_prb : int
        Number of PRBs in the allocation.
    n_ue : int
        Number of UEs.
    n_f : int
        Number of subcarriers in the full resource grid.
    n_t : int
        Number of OFDM symbols per slot.
    energy : float
        Energy scaling factor for DMRS transmission.
    channel_filter_config : ChannelFilterConfig | None, optional
        Configuration for channel filter.
    """

    slot_number: int
    n_dmrs_id: int
    rww_regularizer_val: float
    start_prb: int
    nl_offset: int
    scids: tuple
    apply_cov_shrinkage: bool
    channel_filter_method: str
    qam_bits: int
    dmrs_sym_idxs: tuple
    data_sym_idxs: tuple
    dmrs_port_nums: tuple
    layer2ue: tuple
    n_prb: int
    n_ue: int
    n_f: int
    n_t: int
    energy: float
    channel_filter_config: ChannelFilterConfig | None = None

    def to_tuple(self) -> tuple:
        """Convert to ordered tuple for positional function call.

        Returns
        -------
        tuple
            Ordered tuple of field values.
        """
        return tuple(getattr(self, field.name) for field in fields(self))

    def __str__(self) -> str:
        """Return information about inputs."""
        total = len(fields(self))
        lines = [f"\nStatic inputs ({total} total, compiled as constants):"]
        for i, field in enumerate(fields(self)):
            value = getattr(self, field.name)
            if hasattr(value, "shape"):
                shape_str = f"shape={value.shape}"
            else:
                shape_str = f"value={value}"
            dtype_str = ""
            if hasattr(value, "dtype"):
                dtype_str = f", dtype={value.dtype}"
            lines.append(f"  [{i}] {field.name}: {shape_str}{dtype_str}")
        return "\n".join(lines)


# -----------------------------------------------------------------------------
# Data classes for outputs
# -----------------------------------------------------------------------------


@dataclass
class PuschInnerRxOutputs:
    """PUSCH Inner Receiver outputs from the pusch_inner_receiver function.

    Attributes
    ----------
    llr__time_allocfreq_layer_qambit : np.ndarray
        Log-likelihood ratios with shape (n_datasym, n_allocsc, n_layer, qam_bits).
    post_eq_noise_var_db__ue_sym : np.ndarray
        Post-equalization noise variance per UE with shape (n_ue,).
    post_eq_sinr_db__ue_sym : np.ndarray
        Post-equalization SINR per UE with shape (n_ue,).
    """

    llr__time_allocfreq_layer_qambit: np.ndarray
    post_eq_noise_var_db__ue_sym: np.ndarray
    post_eq_sinr_db__ue_sym: np.ndarray

    @classmethod
    def from_tuple(cls, outputs: tuple) -> "PuschInnerRxOutputs":
        """Create PuschInnerRxOutputs from tuple.

        Parameters
        ----------
        outputs : tuple
            Tuple returned by pusch_inner_receiver function.

        Returns
        -------
        PuschInnerRxOutputs
            Instance with named fields.
        """
        return cls(
            llr__time_allocfreq_layer_qambit=outputs[0],
            post_eq_noise_var_db__ue_sym=outputs[1],
            post_eq_sinr_db__ue_sym=outputs[2],
        )

    def to_tuple(self) -> tuple:
        """Convert to ordered tuple.

        Returns
        -------
        tuple
            Ordered tuple of field values.
        """
        return tuple(getattr(self, field.name) for field in fields(self))

    def __str__(self) -> str:
        """Return debug information about outputs."""
        lines = [f"\nOutputs ({len(fields(self))} total):"]
        for i, field in enumerate(fields(self)):
            value = getattr(self, field.name)
            if hasattr(value, "shape"):
                shape_str = f"shape={value.shape}"
            else:
                shape_str = f"value={value}"
            dtype_str = ""
            if hasattr(value, "dtype"):
                dtype_str = f", dtype={value.dtype}"
            lines.append(f"  [{i}] {field.name}: {shape_str}{dtype_str}")
        return "\n".join(lines)


# -----------------------------------------------------------------------------
# PUSCH Inner Receiver function
# -----------------------------------------------------------------------------


[docs] def pusch_inner_rx( # Dynamic inputs (runtime) xtf__rxant_sym_sc_ri: Array, # Static inputs (compile-time) slot_number: jnp.int32, n_dmrs_id: jnp.int32, rww_regularizer_val: jnp.float32, start_prb: jnp.int32, nl_offset: jnp.int32, scids: tuple, apply_cov_shrinkage: bool, channel_filter_method: str, qam_bits: jnp.int32, dmrs_sym_idxs: tuple, data_sym_idxs: tuple, dmrs_port_nums: tuple, layer2ue: tuple, n_prb: jnp.int32, n_ue: jnp.int32, n_f: jnp.int32, n_t: jnp.int32, energy: jnp.float32, channel_filter_config: ChannelFilterConfig | None = None, ) -> tuple[Array, Array, Array]: """PUSCH Inner Receiver function. The PUSCH inner receiver performs the following steps: 1. DMRS-based channel estimation and covariance estimation 2. MMSE-IRC equalization 3. Soft demapping and LLR generation The function returns LLRs and post-equalization noise variance and SINR estimates. The function can be compiled with MLIR-TensorRT to a single TensorRT engine for inclusion in higher-performance C++ pipelines. The function has both dynamic and static arguments (static arguments are fixed at compile time). Parameters ---------- xtf__rxant_sym_sc_ri : Array Received resource grid with shape (n_rxant, n_sym, n_sc, 2). slot_number : jnp.int32 Slot number (compile-time constant). n_dmrs_id : jnp.int32 DMRS identity (compile-time constant). rww_regularizer_val : jnp.float32 Regularization value for covariance matrix (compile-time constant). start_prb : jnp.int32 0-based starting PRB index for the allocation (compile-time constant). nl_offset : jnp.int32 Layer offset for multi-layer processing (compile-time constant). scids : tuple SCID selection (0 or 1) per layer as tuple (compile-time constant). apply_cov_shrinkage : bool Whether to apply RBLW shrinkage to covariance (compile-time constant). channel_filter_method : str Channel filter method: 'free_energy_filter' or 'ai_tukey_filter' (compile-time constant). qam_bits : jnp.int32 QAM modulation order (bits per symbol): 1, 2, 4, 6, or 8 (compile-time constant). dmrs_sym_idxs : tuple DMRS symbol indices as tuple (compile-time constant). data_sym_idxs : tuple Data symbol indices as tuple (compile-time constant). dmrs_port_nums : tuple Per-layer DMRS port bitfield as tuple (compile-time constant). layer2ue : tuple Mapping from layer index to UE index as tuple (compile-time constant). n_prb : jnp.int32 Number of PRBs in the allocation (compile-time constant). n_ue : jnp.int32 Number of UEs (compile-time constant). n_f : jnp.int32 Number of subcarriers in the full resource grid (compile-time constant). n_t : jnp.int32 Number of OFDM symbols per slot (compile-time constant). energy : jnp.float32 Energy scaling factor for DMRS transmission (compile-time constant). channel_filter_config : ChannelFilterConfig | None, optional Configuration for channel filter. Required when channel_filter_method is 'ai_tukey_filter' (compile-time constant). Returns ------- llr__time_allocfreq_layer_qambit : Array LLRs with shape (n_datasym, n_allocsc, n_layer, qam_bits). post_eq_noise_var_db__ue : Array Post-equalization noise variance per UE with shape (n_ue,). post_eq_sinr_db__ue : Array Post-equalization SINR per UE with shape (n_ue,). """ # Convert input format # C++ pipeline emits (column-major, real/imag interleaved). # Transpose for row-major layout needed by TensorRT (row-major, real/imag stacked) xtf__ri_sym_rxant_sc = jnp.einsum("abcd->dbac", xtf__rxant_sym_sc_ri) # Channel estimation ( h_interp__ri_port_rxant_sc, n_cov__ri_rxant_rxant_prb, _h_est__ri_port_dsym_rxant_dsc, # Unused in main pipeline _noise_var_db, _rsrp_db__ue_dsym, _sinr_db__ue_dsym, ) = channel_estimator( xtf__ri_sym_rxant_sc=xtf__ri_sym_rxant_sc, slot_number=slot_number, n_dmrs_id=n_dmrs_id, rww_regularizer_val=rww_regularizer_val, start_prb=start_prb, scids=scids, apply_cov_shrinkage=apply_cov_shrinkage, channel_filter_method=channel_filter_method, dmrs_sym_idxs=dmrs_sym_idxs, dmrs_port_nums=dmrs_port_nums, layer2ue=layer2ue, n_prb=n_prb, n_ue=n_ue, n_f=n_f, n_t=n_t, energy=energy, channel_filter_config=channel_filter_config, ) # MIMO equalizer (x_est__ri_port_datasym_sc, ree__port_sc, post_eq_noise_var_db__ue, post_eq_sinr_db__ue) = ( equalizer( xtf__ri_sym_rxant_sc=xtf__ri_sym_rxant_sc, h_interp__ri_port_rxant_sc=h_interp__ri_port_rxant_sc, n_cov__ri_rxant_rxant_prb=n_cov__ri_rxant_rxant_prb, data_sym_idxs=data_sym_idxs, layer2ue=layer2ue, n_ue=n_ue, start_prb=start_prb, n_prb=n_prb, ) ) # Soft demapping llr__qambit_layer_allocfreq_time = soft_demapper( x_est__ri_port_datasym_sc=x_est__ri_port_datasym_sc, ree__port_sc=ree__port_sc, start_prb=start_prb, nl_offset=nl_offset, qam_bits=qam_bits, n_prb=n_prb, ) # Output LLR format conversion. # C++ pipeline ingests column-major interleaved format (time, freq, layer, qam_bits) # Also cast to float16 to match C++ pipeline. llr__time_allocfreq_layer_qambit = jnp.copy( jnp.einsum("abcd->dcba", llr__qambit_layer_allocfreq_time) ).astype(jnp.float16) return ( llr__time_allocfreq_layer_qambit, post_eq_noise_var_db__ue, post_eq_sinr_db__ue, )