# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DMRS helper utilities shared by LS and covariance paths.
This module centralizes small, well-tested building blocks to construct the
"transmitted DMRS at REs" reference and the indices/patterns needed by both
Least-Squares (LS) channel estimation and covariance estimation.
Design goals
------------
- Single source of truth for DMRS indices, OCC patterns, and per-layer grid
selection rules so LS and covariance remain in lockstep.
- Broadcast-friendly outputs to avoid unnecessary ``np.tile`` allocations.
- Explicit shapes in docstrings for easy integration and review.
"""
import numpy as np
from ran.phy.numpy.utils import focc_dmrs, tocc_dmrs
from ran.types import ComplexArrayNP, IntArrayNP, IntNP
def dmrs_even_rows_from_lc(lc: int) -> IntArrayNP:
    """Enumerate the even row indices ``0, 2, ..., 2*(lc - 1)``.

    Parameters
    ----------
    lc : int
        Number of indices to generate (``lc = 6*n_prb`` in this module's
        convention).

    Returns
    -------
    rows : ndarray of IntNP, shape (lc,)
        Even indices in ascending order, starting at 0.
    """
    # Stepping by two over [0, 2*lc) yields exactly lc even values.
    return np.arange(0, 2 * lc, 2, dtype=IntNP)
def dmrs_focc_pattern(nf: int, n_sym: int) -> ComplexArrayNP:
    """Replicate the frequency-domain OCC sequence across symbol columns.

    Parameters
    ----------
    nf : int
        Length argument forwarded to ``focc_dmrs``.
    n_sym : int
        Number of columns the sequence is copied into.

    Returns
    -------
    pattern : ndarray, shape (len(focc_dmrs(nf)), n_sym)
        Every column is a copy of ``focc_dmrs(nf)``. Presumably
        ``focc_dmrs(nf)`` is 1-D with ``nf`` entries, giving ``(nf, n_sym)``
        -- confirm against ``ran.phy.numpy.utils``.
    """
    # Turn the sequence into a single column, then copy it along axis 1.
    column = focc_dmrs(nf)[:, np.newaxis]
    return np.repeat(column, n_sym, axis=1)
def dmrs_tocc_pattern(nf: int, n_sym: int) -> ComplexArrayNP:
    """Replicate the time-domain OCC sequence across frequency rows.

    Parameters
    ----------
    nf : int
        Number of rows the sequence is copied into.
    n_sym : int
        Length argument forwarded to ``tocc_dmrs``.

    Returns
    -------
    pattern : ndarray, shape (nf, len(tocc_dmrs(n_sym)))
        Every row is a copy of ``tocc_dmrs(n_sym)``. Presumably
        ``tocc_dmrs(n_sym)`` is 1-D with ``n_sym`` entries, giving
        ``(nf, n_sym)`` -- confirm against ``ran.phy.numpy.utils``.
    """
    # Turn the sequence into a single row, then copy it along axis 0.
    row = tocc_dmrs(n_sym)[np.newaxis, :]
    return np.repeat(row, nf, axis=0)
def parse_port_cfg(port_idx: IntArrayNP) -> tuple[IntArrayNP, IntArrayNP, IntArrayNP]:
    """Split per-layer ``port_idx`` bitfields into three 0/1 flag arrays.

    Parameters
    ----------
    port_idx : ndarray of int, shape (nl,)
        Per-layer bitfields (LS convention). Only the three lowest bits are
        read:

        - bit0: fOCC enable (0/1)
        - bit1: grid select (0 even, 1 odd), i.e. the DMRS grid offset
        - bit2: tOCC enable (0/1)

    Returns
    -------
    focc_cfg : ndarray of IntNP, shape (nl,)
        Value of bit0 for each layer.
    grid_cfg : ndarray of IntNP, shape (nl,)
        Value of bit1 for each layer.
    tocc_cfg : ndarray of IntNP, shape (nl,)
        Value of bit2 for each layer.
    """
    # Move the bit of interest down to position 0, then mask everything else.
    bit0 = port_idx & 1
    bit1 = (port_idx >> 1) & 1
    bit2 = (port_idx >> 2) & 1
    return bit0.astype(IntNP), bit1.astype(IntNP), bit2.astype(IntNP)
def layer_grid_offsets(port_idx: IntArrayNP, grid_table: IntArrayNP) -> IntArrayNP:
    """Look up each layer's grid frequency offset ``delta`` in ``grid_table``.

    Parameters
    ----------
    port_idx : ndarray of int, shape (nl,)
        0-based per-layer values used directly as indices into
        ``grid_table``.
    grid_table : ndarray of int, shape (>= max(port_idx)+1,)
        Lookup table mapping a port index to a grid offset (typically 0 or 1).

    Returns
    -------
    deltas : ndarray, shape (nl,)
        ``grid_table`` entry selected for each layer; meant to be added to
        PRB-band indices.
    """
    # Cast first so the lookup always uses an integer index array.
    lookup = port_idx.astype(IntNP)
    deltas = grid_table[lookup]
    return deltas
def build_layer_freq_indices(
    freq_idx_dmrs_local: IntArrayNP,
    port_idx: IntArrayNP,
    grid_table: IntArrayNP,
) -> IntArrayNP:
    """Combine PRB-band DMRS rows with each layer's grid offset.

    Parameters
    ----------
    freq_idx_dmrs_local : ndarray of int, shape (lc,)
        Local PRB-band DMRS rows (even subcarriers only).
    port_idx : ndarray of int, shape (nl,)
        0-based per-layer values used as indices into ``grid_table``.
    grid_table : ndarray of int, shape (>= max(port_idx)+1,)
        Lookup table mapping a port index to a grid offset (0/1).

    Returns
    -------
    layer_freq : ndarray of IntNP, shape (lc, nl)
        Entry ``[i, l]`` equals ``freq_idx_dmrs_local[i]`` plus the offset
        looked up for layer ``l``. These are indices inside the PRB band; to
        get absolute subcarrier rows, add ``base = 12*start_prb``.
    """
    per_layer_delta = layer_grid_offsets(port_idx=port_idx, grid_table=grid_table)
    # Column of rows + row of offsets broadcasts to the (lc, nl) index table.
    rows_as_column = freq_idx_dmrs_local[:, np.newaxis]
    deltas_as_row = per_layer_delta[np.newaxis, :]
    layer_freq = rows_as_column + deltas_as_row
    return layer_freq.astype(IntNP)
# Public API: UL DMRS embedding.
def embed_dmrs_ul(
    r_dmrs: ComplexArrayNP,
    nl: int,
    port_idx: IntArrayNP,
    vec_scid: IntArrayNP,
    energy: float,
) -> ComplexArrayNP:
    """Construct PRB-band DMRS slice for UL (allocation-local embedding).

    For each layer, the reference selected by that layer's SCID is scaled by
    ``sqrt(energy)``, optionally multiplied by the fOCC and/or tOCC pattern,
    and written to every other row of an otherwise zero band grid. The row
    parity (even or odd) is chosen by the layer's grid-select bit.

    Parameters
    ----------
    r_dmrs : ndarray of complex, shape (6*n_prb, n_sym, 2)
        DMRS reference values, one row per DMRS tone and one column per DMRS
        symbol. The first axis has ``lc = 6*n_prb`` rows, which is half the
        height of the returned band grid. The last axis is indexed by SCID.
    nl : int
        Number of layers. Only the first ``nl`` entries of ``port_idx`` and
        ``vec_scid`` are read.
    port_idx : ndarray of int, shape (nl,)
        Per-layer bitfields (bit0: fOCC, bit1: grid select, bit2: tOCC).
    vec_scid : ndarray of int, shape (nl,)
        Per-layer SCID selector in {0, 1}; used as the index on the last
        axis of ``r_dmrs``.
    energy : float
        Energy scaling factor; the references are multiplied by
        ``sqrt(energy)``.

    Returns
    -------
    x_dmrs : ndarray, shape (12*n_prb, n_sym, nl)
        PRB-band DMRS slice with the same dtype as ``r_dmrs``. Only DMRS REs
        are non-zero; other tones in the PRB band are zero.
    """
    lc, n_sym = r_dmrs.shape[:2]
    # DMRS tones sit on every other row, so the band is twice as tall as r_dmrs.
    prb_len = 2 * lc

    # PRB-band grid (freq x dmrs_sym x layer), zero outside the DMRS REs.
    x_dmrs = np.zeros((prb_len, n_sym, nl), dtype=r_dmrs.dtype)

    # Even-row indices and OCC patterns shared by all layers.
    freq_idx_dmrs_local = dmrs_even_rows_from_lc(lc)
    focc = dmrs_focc_pattern(lc, n_sym)
    tocc = dmrs_tocc_pattern(lc, n_sym)

    # Energy scaling is layer-independent, so apply it once up front.
    r_scaled = r_dmrs * np.sqrt(energy)

    # Per-layer enables and grid offsets from the bitfields.
    focc_cfg, grid_cfg, tocc_cfg = parse_port_cfg(port_idx)

    for layer in range(nl):
        # Reference for this layer's SCID, shape (lc, n_sym).
        r_l = r_scaled[..., vec_scid[layer]]
        if focc_cfg[layer] != 0:
            r_l = focc * r_l
        if tocc_cfg[layer] != 0:
            r_l = tocc * r_l
        # Grid select shifts the even rows by 0 or 1 (even vs. odd tones).
        freq_ix = (freq_idx_dmrs_local + grid_cfg[layer]).astype(IntNP)
        x_dmrs[freq_ix, :, layer] = r_l
    return x_dmrs
# Every entry must resolve to a name in this module: ``from <module> import *``
# raises AttributeError on any entry that does not. "extract_raw_dmrs_type_1"
# is neither defined nor imported here, so it is not exported.
__all__ = [
    "build_layer_freq_indices",
    "embed_dmrs_ul",
    "layer_grid_offsets",
    "parse_port_cfg",
]