# Source code for ran.phy.jax.utils.complex_ops

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Complex arithmetic operations on stacked real/imag tensors.

Operations on tensors where dimension 0 contains [real, imag] components.
"""

import jax.numpy as jnp
from jax import Array


def complex_mul(
    a__ri: Array,
    b__ri: Array,
) -> Array:
    """Multiply two complex tensors stored as stacked real/imag parts: a * b.

    Each operand keeps its real component at index 0 and its imaginary
    component at index 1 of the leading dimension. The product is returned
    in the same stacked layout.

    Args:
        a__ri: Left complex operand with shape (2, ...).
        b__ri: Right complex operand with shape (2, ...).

    Returns:
        result__ri: Complex product with shape (2, ...); index 0 holds the
        real part and index 1 the imaginary part.

    Notes:
        With a = ar + i*ai and b = br + i*bi:
            a * b = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    """
    # Unpack the operands by index, not by iteration, so any entries
    # past index 1 on the leading axis are ignored, as before.
    ar, ai = a__ri[0], a__ri[1]
    br, bi = b__ri[0], b__ri[1]

    real_out = ar * br - ai * bi
    imag_out = ar * bi + ai * br
    return jnp.stack([real_out, imag_out], axis=0)
def complex_mul_conj(
    a__ri: Array,
    b__ri: Array,
) -> Array:
    """Multiply a complex tensor by the conjugate of another: a * conj(b).

    Both operands use the stacked layout in which index 0 of the leading
    dimension is the real component and index 1 is the imaginary component.
    The result is returned in that same layout.

    Args:
        a__ri: Complex operand with shape (2, ...).
        b__ri: Complex operand with shape (2, ...); its conjugate is used.

    Returns:
        result__ri: Complex product with shape (2, ...); index 0 holds the
        real part and index 1 the imaginary part.

    Notes:
        With a = ar + i*ai and b = br + i*bi:
            a * conj(b) = (ar*br + ai*bi) + i*(ai*br - ar*bi)
    """
    re_a, im_a = a__ri[0], a__ri[1]
    re_b, im_b = b__ri[0], b__ri[1]

    # Same operand order as the closed-form expansion above, which keeps
    # floating-point rounding unchanged.
    re_out = re_a * re_b + im_a * im_b
    im_out = im_a * re_b - re_a * im_b
    return jnp.stack([re_out, im_out], axis=0)
# Names exported by ``from ... import *``.
__all__ = [
    "complex_mul",
    "complex_mul_conj",
]