# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib
import logging
from typing import Tuple
import torch
from torch.autograd import Function
from physicsnemo.core.version_check import check_version_spec
logger = logging.getLogger(__name__)
NV_FUSER_AVAILABLE = check_version_spec("nvfuser", hard_fail=False)
if NV_FUSER_AVAILABLE:
nvfuser = importlib.import_module("nvfuser")
FusionDefinition = nvfuser.FusionDefinition
DataType = nvfuser.DataType
_torch_dtype_to_nvfuser = {
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
        torch.long: DataType.Int,
torch.int32: DataType.Int32,
torch.bool: DataType.Bool,
torch.bfloat16: DataType.BFloat16,
torch.cfloat: DataType.ComplexFloat,
torch.cdouble: DataType.ComplexDouble,
}
@functools.lru_cache(maxsize=None)
def silu_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implmentation of SiLU backward as a fused kernel and with
activations recomputation
Parameters
----------
fd : FusionDefition
nvFuser's FusionDefition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
        # grad = y * (1 + x * (1 - y))
grad_input = fd.ops.mul(y, fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y))))
grad_input = fd.ops.cast(grad_input, dtype)
fd.add_output(grad_input)
@functools.lru_cache(maxsize=None)
def silu_double_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implmentation of SiLU double backward as a fused kernel and with
activations recomputation
Parameters
----------
fd : FusionDefition
nvFuser's FusionDefition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# z = 1 + x * (1 - y)
z = fd.ops.add(one, fd.ops.mul(x, fd.ops.sub(one, y)))
# term1 = dy * z
term1 = fd.ops.mul(dy, z)
# term2 = y * ((1 - y) - x * dy)
term2 = fd.ops.mul(y, fd.ops.sub(fd.ops.sub(one, y), fd.ops.mul(x, dy)))
grad_input = fd.ops.add(term1, term2)
grad_input = fd.ops.cast(grad_input, dtype)
fd.add_output(grad_input)
@functools.lru_cache(maxsize=None)
def silu_triple_backward_for(
fd: FusionDefinition,
dtype: torch.dtype,
dim: int,
size: torch.Size,
stride: Tuple[int, ...],
): # pragma: no cover
"""
nvfuser frontend implmentation of SiLU triple backward as a fused kernel and with
activations recomputation
Parameters
----------
fd : FusionDefition
nvFuser's FusionDefition class
dtype : torch.dtype
Data type to use for the implementation
dim : int
Dimension of the input tensor
size : torch.Size
Size of the input tensor
stride : Tuple[int, ...]
Stride of the input tensor
"""
try:
dtype = _torch_dtype_to_nvfuser[dtype]
except KeyError:
raise TypeError("Unsupported dtype")
x = fd.define_tensor(
shape=[-1] * dim,
contiguity=nvfuser.compute_contiguity(size, stride),
dtype=dtype,
)
one = fd.define_constant(1.0)
two = fd.define_constant(2.0)
# y = sigmoid(x)
y = fd.ops.sigmoid(x)
# dy = y * (1 - y)
dy = fd.ops.mul(y, fd.ops.sub(one, y))
# ddy = (1 - 2y) * dy
ddy = fd.ops.mul(fd.ops.sub(one, fd.ops.mul(two, y)), dy)
# term1 = ddy * (2 + x - 2xy)
term1 = fd.ops.mul(
ddy, fd.ops.sub(fd.ops.add(two, x), fd.ops.mul(two, fd.ops.mul(x, y)))
)
# term2 = dy * (1 - 2 (y + x * dy))
term2 = fd.ops.mul(
dy, fd.ops.sub(one, fd.ops.mul(two, fd.ops.add(y, fd.ops.mul(x, dy))))
)
grad_input = fd.ops.add(term1, term2)
grad_input = fd.ops.cast(grad_input, dtype)
fd.add_output(grad_input)
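    # The autograd Functions below chain into one another: FusedSiLU.backward
    # applies FusedSiLU_deriv_1, whose backward applies FusedSiLU_deriv_2, and
    # so on. Gradients up to third order therefore run as fused nvfuser
    # kernels, while the fourth derivative is computed eagerly in PyTorch in
    # FusedSiLU_deriv_3.backward.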
class FusedSiLU(Function):
"""
Fused SiLU activation implementation using nvfuser for a custom fused backward
with activation recomputation
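        Examples
        --------
        A minimal usage sketch (assumes a CUDA input tensor and a working
        nvfuser install); as a ``torch.autograd.Function`` it is invoked via
        ``FusedSiLU.apply``:
        >>> x = torch.randn(4, 8, device="cuda", requires_grad=True)
        >>> y = FusedSiLU.apply(x)
        >>> y.sum().backward()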
"""
@staticmethod
def forward(ctx, x):
"""
Forward method for SiLU activation
Parameters
----------
ctx :
torch context
x :
input tensor
Returns
-------
output activation
"""
ctx.save_for_backward(x)
return torch.nn.functional.silu(x)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
"""
Backward method for SiLU activation
Parameters
----------
ctx :
torch context
grad_output :
output gradients
Returns
-------
input gradients
"""
(x,) = ctx.saved_tensors
return FusedSiLU_deriv_1.apply(x) * grad_output
class FusedSiLU_deriv_1(Function):
"""
Fused SiLU first derivative implementation using nvfuser
with activation recomputation
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
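            # Build the fused first-derivative kernel for this input's dtype,
            # rank, shape, and stride, then execute it on x; the higher-order
            # derivatives below follow the same pattern.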
with FusionDefinition() as fd:
silu_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
return FusedSiLU_deriv_2.apply(x) * grad_output
class FusedSiLU_deriv_2(Function):
"""
Fused SiLU second derivative implementation using nvfuser
with activation recomputation
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
with FusionDefinition() as fd:
silu_double_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
return FusedSiLU_deriv_3.apply(x) * grad_output
class FusedSiLU_deriv_3(Function):
"""
Fused SiLU third derivative implementation using nvfuser
with activation recomputation
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
with FusionDefinition() as fd:
silu_triple_backward_for(fd, x.dtype, x.dim(), x.size(), x.stride())
out = fd.execute([x])[0]
return out
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
(x,) = ctx.saved_tensors
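            # Fourth derivative of SiLU, evaluated eagerly in PyTorch rather
            # than as a fused kernel: differentiating
            # silu'''(x) = ddy * (2 + x - 2xy) + dy * z, with
            # z = 1 - 2 * (y + x * dy), yields the three terms below.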
y = torch.sigmoid(x)
dy = y * (1 - y)
ddy = (1 - 2 * y) * dy
dddy = (1 - 2 * y) * ddy - 2 * dy * dy
z = 1 - 2 * (y + x * dy)
term1 = dddy * (2 + x - 2 * x * y)
term2 = 2 * ddy * z
term3 = dy * (-2) * (2 * dy + x * ddy)
return (term1 + term2 + term3) * grad_output
else:
def raise_missing_nvfuser():
msg = "FusedSiLU:An error occured. Either nvfuser is not installed or the version is "
"incompatible. Please retry after installing correct version of nvfuser. "
"The new version of nvfuser should be available in PyTorch container version "
">= 23.10. "
"https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html. "
"If using a source install method, please refer nvFuser repo for installation "
("guidelines https://github.com/NVIDIA/Fuser.",)
raise ImportError(msg)
class FusedSiLU(Function):
"""Placeholder for when nvfuser is not available."""
def __init__(self):
raise_missing_nvfuser()
def silu_backward_for(*args, **kwargs):
raise_missing_nvfuser()
def silu_double_backward_for(*args, **kwargs):
raise_missing_nvfuser()
def silu_triple_backward_for(*args, **kwargs):
raise_missing_nvfuser()