# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

import modulus  # noqa: F401 for docs
from modulus.models.layers import FCLayer, get_activation

from ..meta import ModelMetaData
from ..module import Module
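

# Default model metadata passed to Module.__init__ by FullyConnected. A minimal
# sketch: the ``name`` field is definite, while the capability flags below are
# assumed defaults rather than authoritative values.
@dataclass
class MetaData(ModelMetaData):
    name: str = "FullyConnected"
    # Optimization
    jit: bool = True
    cuda_graphs: bool = True
    amp_cpu: bool = True
    amp_gpu: bool = True
    # Inference
    onnx: bool = True
    onnx_runtime: bool = True
    # Physics informed
    func_torch: bool = True
    auto_grad: bool = True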


class FullyConnected(Module):
"""A densely-connected MLP architecture
Parameters
----------
in_features : int, optional
Size of input features, by default 512
layer_size : int, optional
Size of every hidden layer, by default 512
out_features : int, optional
Size of output features, by default 512
num_layers : int, optional
Number of hidden layers, by default 6
activation_fn : Union[str, List[str]], optional
Activation function to use, by default 'silu'
skip_connections : bool, optional
Add skip connections every 2 hidden layers, by default False
adaptive_activations : bool, optional
Use an adaptive activation function, by default False
weight_norm : bool, optional
Use weight norm on fully connected layers, by default False
weight_fact : bool, optional
        Use weight factorization on fully connected layers, by default False

    Example
-------
>>> model = modulus.models.mlp.FullyConnected(in_features=32, out_features=64)
>>> input = torch.randn(128, 32)
>>> output = model(input)
>>> output.size()
torch.Size([128, 64])
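
    An illustrative sketch of a variant configuration, using per-layer
    activations (a shorter list is padded to ``num_layers``) together with
    skip connections:

    >>> model = modulus.models.mlp.FullyConnected(
    ...     in_features=32,
    ...     out_features=8,
    ...     num_layers=4,
    ...     activation_fn=["relu", "silu"],
    ...     skip_connections=True,
    ... )
    >>> model(torch.randn(16, 32)).size()
    torch.Size([16, 8])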
"""
def __init__(
self,
in_features: int = 512,
layer_size: int = 512,
out_features: int = 512,
num_layers: int = 6,
activation_fn: Union[str, List[str]] = "silu",
skip_connections: bool = False,
adaptive_activations: bool = False,
weight_norm: bool = False,
weight_fact: bool = False,
) -> None:
super().__init__(meta=MetaData())
        self.skip_connections = skip_connections

        # A single learnable scaling parameter is shared by all layers when
        # adaptive activations are enabled.
        if adaptive_activations:
            activation_par = nn.Parameter(torch.ones(1))
        else:
            activation_par = None

        # Broadcast a single activation to every hidden layer; a shorter list
        # is padded by repeating its last entry up to num_layers.
        if not isinstance(activation_fn, list):
            activation_fn = [activation_fn] * num_layers
        if len(activation_fn) < num_layers:
            activation_fn = activation_fn + [activation_fn[-1]] * (
                num_layers - len(activation_fn)
            )
        activation_fn = [get_activation(a) for a in activation_fn]

        # Hidden layers: the first maps in_features -> layer_size, subsequent
        # layers map layer_size -> layer_size.
        self.layers = nn.ModuleList()
        layer_in_features = in_features
        for i in range(num_layers):
self.layers.append(
FCLayer(
layer_in_features,
layer_size,
activation_fn[i],
weight_norm,
weight_fact,
activation_par,
)
)
            layer_in_features = layer_size

        # Final projection: no activation, weight norm, or weight factorization.
        self.final_layer = FCLayer(
in_features=layer_size,
out_features=out_features,
activation_fn=None,
weight_norm=False,
weight_fact=False,
activation_par=None,
        )

    def forward(self, x: Tensor) -> Tensor:
        x_skip: Optional[Tensor] = None
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Every second hidden layer, add the activation saved two layers
            # back (if any) and stash the current activation for the next skip.
            if self.skip_connections and i % 2 == 0:
                if x_skip is not None:
                    x, x_skip = x + x_skip, x
                else:
                    x_skip = x
        x = self.final_layer(x)
        return x
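

if __name__ == "__main__":
    # Illustrative smoke test, not part of the library API. Run as a module
    # (e.g. ``python -m modulus.models.mlp.fully_connected``; the module path
    # is an assumption) so the relative imports above resolve.
    net = FullyConnected(
        in_features=3,
        layer_size=64,
        num_layers=4,
        out_features=2,
        skip_connections=True,
    )
    x = torch.randn(8, 3, requires_grad=True)
    y = net(x)
    assert y.shape == (8, 2)
    # Check that gradients flow back to the input.
    y.sum().backward()
    assert x.grad is not None and x.grad.shape == x.shape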