# Extracted from documentation page:
#   deeplearning/modulus/modulus-core-v030/_modules/modulus/models/sfno/factorizations.html
# Modulus Core v0.3.0
# Source code for modulus.models.sfno.factorizations

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

import tensorly as tl

tl.set_backend("pytorch")

from functools import partial

from modulus.models.sfno.contractions import (
    _contract_diagonal,
    _contract_dhconv,
    _contract_sep_diagonal,
    _contract_sep_dhconv,
    _contract_diagonal_real,
    _contract_dhconv_real,
    _contract_sep_diagonal_real,
    _contract_sep_dhconv_real,
)

from tltorch.factorized_tensors.core import FactorizedTensor

# Pool of single-letter einsum labels (all lower-case letters, then all
# upper-case letters). The contraction helpers below hand them out in order.
einsum_symbols = "abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz".upper()


def _contract_dense(
    x, weight, separable=False, operator_type="diagonal"
):  # pragma: no cover
    """Contract ``x`` with a fully materialized spectral weight via einsum.

    Axis 0 of ``x`` is treated as the batch axis and axis 1 as the channel
    axis. Every axis of ``x`` gets its own einsum label, taken in order from
    ``einsum_symbols``.

    Parameters
    ----------
    x : tensor
        Input tensor. Its number of dimensions (``tl.ndim(x)``) determines how
        many labels are used.
    weight : tensor or factorized tensor
        Spectral weight. Anything that is not a ``torch`` tensor is first
        materialized with ``weight.to_tensor()``.
    separable : bool, default is False
        If true, the weight has no output-channel axis, and the channel axis
        of ``x`` is carried through to the result. Otherwise the weight has an
        output-channel axis right after its input-channel axis, and that axis
        takes the place of the channel axis in the result.
    operator_type : {"diagonal", "block-diagonal", "dhconv"}
        "diagonal": the weight axes line up one-to-one with the non-batch axes
        of ``x``.
        "block-diagonal": the weight has an extra axis just before its last
        one. The weight's last axis is summed against the last axis of ``x``,
        and the extra axis becomes the last axis of the result.
        "dhconv": the weight lacks the last axis of ``x``, so the contraction
        broadcasts along it.

    Returns
    -------
    tensor
        ``tl.einsum`` result. The batch axis comes first, then the
        output-channel axis (or the input-channel axis if separable), then
        the remaining axes.

    Raises
    ------
    ValueError
        If ``operator_type`` is not one of the values above.
    """
    ndim = tl.ndim(x)
    in_syms = list(einsum_symbols[:ndim])

    # The weight axes mirror x without its batch axis.
    w_syms = in_syms[1:]
    if separable:
        res_syms = [in_syms[0], *w_syms]
    else:
        # Extra output-channel label right after the input channels; it
        # replaces the channel label in the result.
        w_syms.insert(1, einsum_symbols[ndim])
        res_syms = [in_syms[0], *w_syms[1:]]

    if operator_type == "block-diagonal":
        w_syms.insert(-1, einsum_symbols[ndim + 1])
        res_syms[-1] = w_syms[-2]  # the label that was just inserted
    elif operator_type == "dhconv":
        w_syms.pop()
    elif operator_type != "diagonal":
        raise ValueError(f"Unkonw operator type {operator_type}")

    equation = f"{''.join(in_syms)},{''.join(w_syms)}->{''.join(res_syms)}"

    dense_weight = weight if torch.is_tensor(weight) else weight.to_tensor()
    return tl.einsum(equation, x, dense_weight)


def _contract_cp(
    x, cp_weight, separable=False, operator_type="diagonal"
):  # pragma: no cover
    """Contract ``x`` directly with the pieces of a CP-factorized weight.

    The full weight is never reconstructed. ``cp_weight.weights`` and every
    matrix in ``cp_weight.factors`` share one rank label. Each factor also
    carries the label of the weight axis it spans.

    Parameters
    ----------
    x : tensor
        Input tensor. Axis 0 is treated as batch and axis 1 as input channels.
    cp_weight : CP-factorized tensor
        Must expose ``weights`` and ``factors``. Presumably a tensorly-torch
        CP tensor; TODO confirm.
    separable : bool, default is False
        If false, an output-channel factor is added, and its label replaces
        the channel label of ``x`` in the result.
    operator_type : {"diagonal", "block-diagonal", "dhconv"}
        "block-diagonal": a trailing factor with a fresh label is added, and
        that label becomes the last axis of the result.
        "dhconv": the factor for the last axis of ``x`` is dropped, so the
        result broadcasts along that axis.

    Returns
    -------
    tensor
        ``tl.einsum`` result.

    Raises
    ------
    ValueError
        If ``operator_type`` is not one of the values above.
    """
    ndim = tl.ndim(x)
    x_part = einsum_symbols[:ndim]
    rank_label = einsum_symbols[ndim]
    out_label = einsum_symbols[ndim + 1]
    result = list(x_part)

    # First factor spans the input channels; the rest follow the spatial axes.
    factors = [einsum_symbols[1] + rank_label]
    if not separable:
        result[1] = out_label
        factors.append(out_label + rank_label)
    factors.extend(axis + rank_label for axis in x_part[2:])

    if operator_type == "block-diagonal":
        result[-1] = einsum_symbols[ndim + 2]
        factors.append(result[-1] + rank_label)
    elif operator_type == "dhconv":
        factors.pop()
    elif operator_type != "diagonal":
        raise ValueError(f"Unkonw operator type {operator_type}")

    equation = f"{x_part},{rank_label},{','.join(factors)}->{''.join(result)}"
    return tl.einsum(equation, x, cp_weight.weights, *cp_weight.factors)


def _contract_tucker(
    x, tucker_weight, separable=False, operator_type="diagonal"
):  # pragma: no cover
    """Contract ``x`` directly with the core and factors of a Tucker weight.

    The full weight is never reconstructed. Each factor matrix carries the
    label of the weight axis it spans, plus the label of the matching core
    axis.

    Parameters
    ----------
    x : tensor
        Input tensor. Axis 0 is treated as batch and axis 1 as input channels.
    tucker_weight : Tucker-factorized tensor
        Must expose ``core`` and ``factors``. Presumably a tensorly-torch
        Tucker tensor; TODO confirm.
    separable : bool, default is False
        If false, the core and factors include an output-channel axis, whose
        label replaces the channel label of ``x`` in the result.
    operator_type : {"diagonal", "dhconv"}
        "diagonal": one factor per non-batch axis of ``x`` (plus the output
        channels if not separable).
        "dhconv": the weight has no factor for the last axis of ``x``, so the
        result broadcasts along that axis. This mirrors the other
        contractions in this module.

    Returns
    -------
    tensor
        ``tl.einsum`` result.

    Raises
    ------
    NotImplementedError
        For ``operator_type == "block-diagonal"``.
    ValueError
        For any other unsupported ``operator_type``.
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)
    if separable:
        # One core axis per non-batch axis of x: in, x, y, ...
        core_syms = einsum_symbols[order + 1 : 2 * order]
        factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)]
    else:
        # One extra core axis for the output channels: in, out, x, y, ...
        core_syms = einsum_symbols[order + 1 : 2 * order + 1]
        out_syms[1] = out_sym
        factor_syms = [
            einsum_symbols[1] + core_syms[0],  # in
            out_sym + core_syms[1],  # out
        ]
        factor_syms += [
            xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])
        ]  # x, y, ...

    if operator_type == "diagonal":
        pass
    elif operator_type == "dhconv":
        # The weight has no last spatial axis. Drop the matching core label
        # and factor; the last label of x then appears only in x and the
        # output, so einsum broadcasts along it.
        core_syms = core_syms[:-1]
        factor_syms.pop()
    elif operator_type == "block-diagonal":
        raise NotImplementedError(
            f"Operator type {operator_type} not implemented for Tucker"
        )
    else:
        raise ValueError(f"Unkonw operator type {operator_type}")

    eq = (
        x_syms
        + ","
        + core_syms
        + ","
        + ",".join(factor_syms)
        + "->"
        + "".join(out_syms)
    )

    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)


def _contract_tt(
    x, tt_weight, separable=False, operator_type="diagonal"
):  # pragma: no cover
    """Contract ``x`` directly with the cores of a tensor-train weight.

    Core ``i`` is labelled ``(r_i, axis_i, r_{i+1})``, so consecutive cores
    share a rank label and are chained together inside one einsum.

    Parameters
    ----------
    x : tensor
        Input tensor. Axis 0 is treated as batch and axis 1 as input channels.
    tt_weight : TT-factorized tensor
        Must expose ``factors`` (the TT cores, one per weight axis).
        Presumably a tensorly-torch TT tensor; TODO confirm.
    separable : bool, default is False
        If false, an output-channel core is inserted right after the
        input-channel core. Its label replaces the channel label of ``x`` in
        the result.
    operator_type : {"diagonal", "block-diagonal", "dhconv"}
        "block-diagonal": an extra core is inserted before the last one, and
        its axis label becomes the last axis of the result.
        "dhconv": the core for the last axis of ``x`` is absent, so the result
        broadcasts along that axis.

    Returns
    -------
    tensor
        ``tl.einsum`` result.

    Raises
    ------
    ValueError
        If ``operator_type`` is not one of the values above.
    """
    ndim = tl.ndim(x)
    in_syms = list(einsum_symbols[:ndim])

    # Weight axes follow x without its batch axis.
    core_axes = in_syms[1:]
    if separable:
        res_syms = list(in_syms)
    else:
        core_axes.insert(1, einsum_symbols[ndim])  # output channels
        res_syms = [in_syms[0], *core_axes[1:]]

    if operator_type == "block-diagonal":
        core_axes.insert(-1, einsum_symbols[ndim + 1])
        res_syms[-1] = core_axes[-2]  # the label that was just inserted
    elif operator_type == "dhconv":
        core_axes.pop()
    elif operator_type != "diagonal":
        raise ValueError(f"Unkonw operator type {operator_type}")

    # Rank labels start after the labels reserved above.
    ranks = einsum_symbols[ndim + 2 :]
    cores = ",".join(
        ranks[i] + axis + ranks[i + 1] for i, axis in enumerate(core_axes)
    )
    equation = "".join(in_syms) + "," + cores + "->" + "".join(res_syms)

    return tl.einsum(equation, x, *tt_weight.factors)


# jitted PyTorch contractions:
def _contract_dense_pytorch(
    x, weight, separable=False, operator_type="diagonal", complex=True
):  # pragma: no cover
    """Contract ``x`` with a plain-tensor ``weight`` using the jitted kernels.

    The kernel is chosen from ``modulus.models.sfno.contractions`` according
    to ``separable``, ``operator_type`` and ``complex``. The kernel always
    operates on a real view of ``x``, and its result is turned back into a
    complex tensor.

    Parameters
    ----------
    x : torch.Tensor
        Complex input (it is passed through ``torch.view_as_real``).
    weight : torch.Tensor
        Weight handed unchanged to the selected kernel. Its expected layout
        is defined by that kernel; see ``contractions``.
    separable : bool, default is False
        Selects the ``_contract_sep_*`` kernels.
    operator_type : {"diagonal", "dhconv"}
        Which family of kernels to use.
    complex : bool, default is True
        If false, the ``*_real`` kernel variants are used instead.

    Returns
    -------
    torch.Tensor
        The kernel output, viewed as complex via ``torch.view_as_complex``.

    Raises
    ------
    ValueError
        If ``operator_type`` is neither "diagonal" nor "dhconv".
    """
    # Work on a real view (view_as_real appends a trailing axis of size 2);
    # the original design notes this is done to cheat the fused optimizers.
    x = torch.view_as_real(x)

    is_separable = bool(separable)
    if operator_type not in ("diagonal", "dhconv"):
        raise ValueError(f"Unkonw operator type {operator_type}")

    # (separable, operator_type, complex) -> kernel
    kernels = {
        (False, "diagonal", True): _contract_diagonal,
        (False, "diagonal", False): _contract_diagonal_real,
        (False, "dhconv", True): _contract_dhconv,
        (False, "dhconv", False): _contract_dhconv_real,
        (True, "diagonal", True): _contract_sep_diagonal,
        (True, "diagonal", False): _contract_sep_diagonal_real,
        (True, "dhconv", True): _contract_sep_dhconv,
        (True, "dhconv", False): _contract_sep_dhconv_real,
    }
    kernel = kernels[(is_separable, operator_type, bool(complex))]

    # Convert the real-view result back to complex.
    return torch.view_as_complex(kernel(x, weight))


def get_contract_fun(
    weight,
    implementation="reconstructed",
    separable=False,
    operator_type="diagonal",
    complex=True,
):  # pragma: no cover
    """Generic ND implementation of Fourier Spectral Conv contraction

    Parameters
    ----------
    weight : torch.Tensor or tensorly-torch's FactorizedTensor
        Spectral weight; its type selects the contraction returned in
        'factorized' mode.
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the
        input (factorized)
    separable : bool, default is False
        Bound into the returned function only when ``weight`` is a plain
        tensor and ``implementation == 'factorized'``.
    operator_type : str, default is "diagonal"
        Bound into the returned function under the same condition as
        ``separable``.
    complex : bool, default is True
        Bound into the returned function under the same condition as
        ``separable``.

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space
        For a plain tensor in 'factorized' mode, this is
        ``_contract_dense_pytorch`` with ``separable``, ``operator_type`` and
        ``complex`` bound as keywords. Otherwise it is one of the einsum-based
        contractions, returned unbound: the caller must pass ``separable`` and
        ``operator_type`` when calling it.

    Raises
    ------
    ValueError
        If ``implementation`` is unknown, if ``weight`` is neither a tensor
        nor a FactorizedTensor, or if the factorization name is unknown.
    """
    if implementation == "reconstructed":
        return _contract_dense
    if implementation != "factorized":
        raise ValueError(
            f'Got {implementation}, expected "reconstructed" or "factorized"'
        )

    if torch.is_tensor(weight):
        # Plain tensors go through the jitted PyTorch kernels; bind options now.
        return partial(
            _contract_dense_pytorch,
            separable=separable,
            complex=complex,
            operator_type=operator_type,
        )

    if not isinstance(weight, FactorizedTensor):
        raise ValueError(
            f"Got unexpected weight type of class {weight.__class__.__name__}"
        )

    name = weight.name.lower()
    if name in ("complexdense", "dense"):
        return _contract_dense
    if name == "complextucker":
        return _contract_tucker
    if name == "complextt":
        return _contract_tt
    if name == "complexcp":
        return _contract_cp
    raise ValueError(f"Got unexpected factorized weight type {weight.name}")
# © Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.