Source code for ran.phy.numpy.pusch.channel_estimation

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NumPy-based PUSCH LS and delay-domain channel estimation (5G NR).

Implements efficient, test-aligned block routines for:
- channel_est_ls: Minimal LS on DMRS tones only (shape: 6*n_prb x nl x n_ant)
- channel_est_dd: LS + delay-domain truncation + interpolation
  (shape: 12*n_prb x nl x n_ant, allocation slice only)

Indexing convention:
- start_prb, sym_idx_dmrs, port_idx: 0-based
- xtf: (n_f, n_t, n_ant)
- r_dmrs: (n_f, n_t, 2)
- sym_idx_dmrs: (n_dmrs_sym,)
"""

import numpy as np

from ran.types import ComplexArrayNP


[docs] def channel_est_ls(x_dmrs: ComplexArrayNP, y_dmrs: ComplexArrayNP) -> ComplexArrayNP: """ Core LS on precomputed PRB-band DMRS slices. Args ---- x_dmrs: (12*n_prb, n_sym, nl) y_dmrs: (12*n_prb, n_sym, nl, n_ant) Returns ------- h_compact : (6*n_prb, nl, n_ant) Compact LS per DMRS tone (averaged across DMRS symbols). """ # Matched filter, keep even bins, average across DMRS symbols y_mf = np.conj(x_dmrs)[..., None] * y_dmrs # (12*n_prb, n_sym, nl, n_ant) return y_mf[::2, :, :, :].mean(axis=1) # (6*n_prb, nl, n_ant)
def _delay_domain_truncation( h_compact: ComplexArrayNP, trunc_ratio: float = 0.05, ) -> ComplexArrayNP: """ IFFT/FFT refinement on compact LS tones WITH truncation. Args ---- h_compact: (lc, nl, n_ant) Compact LS channel estimates with subcarriers on the first axis. trunc_ratio: Fraction (0.0 to 1.0) of the delay-domain taps to keep (e.g., 0.05 keeps 5%). Returns ------- h_ref: (lc, nl, n_ant) Refined compact LS tones (after truncation in delay domain). """ lc = h_compact.shape[0] trunc_len = max(1, round(trunc_ratio * lc)) # Go to delay domain h_delay = np.fft.ifft(h_compact, n=lc, axis=0) # Zero out taps beyond trunc_len h_delay[trunc_len:, ...] = 0.0 # Back to frequency domain return np.fft.fft(h_delay, n=lc, axis=0) def _interpolate( h_compact: ComplexArrayNP, # (lc, nl, n_ant) ) -> ComplexArrayNP: """ Expand compact LS tones (every 2 subcarriers) to the allocation slice only. Args ---- h_compact: (lc, nl, n_ant) Compact LS channel estimates with subcarriers on the first axis. Returns ------- h_alloc: (2*lc, nl, n_ant) Allocation slice with each compact bin repeated into a 2-bin block. Caller is responsible for assigning this slice into the full grid. """ # Repeat along subcarrier axis (first axis) return np.repeat(h_compact, 2, axis=0)
[docs] def channel_est_dd(x_dmrs: ComplexArrayNP, y_dmrs: ComplexArrayNP) -> ComplexArrayNP: """ Delay-domain truncation + interpolation on precomputed DMRS slices. Args ---- x_dmrs: (12*n_prb, n_sym, nl) Transmitted DMRS reference symbols, post scrambling & OCC (subcarrier, symbol, layer). y_dmrs: (12*n_prb, n_sym, nl, n_ant) Received demapped DMRS symbols (subcarrier, symbol, layer, antenna). Returns ------- h_alloc : (12*n_prb, nl, n_ant, n_dmrs_sym) Steps: 1) Compact LS on DMRS tones from slices. 2) IFFT/FFT with delay-domain truncation. 3) Interpolate to allocation (repeat along frequency). Note: n_dmrs_sym is the number of DMRS symbols per slot, currently only supporting 1. """ h_compact = channel_est_ls(x_dmrs=x_dmrs, y_dmrs=y_dmrs) h_compact_ref = _delay_domain_truncation(h_compact, trunc_ratio=0.05) return _interpolate(h_compact_ref)[..., None]
__all__ = [ "channel_est_dd", "channel_est_ls", ]