# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import torch


def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    """Fill ``tensor`` in place with truncated-normal samples, without autograd.

    Sampling uses the inverse-CDF method: draw uniformly between the CDF
    values of the two cutoffs, map through ``erfinv`` to obtain a truncated
    standard normal, then scale by ``std`` and shift by ``mean``.

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution. When
            it is zero the tensor is filled with ``mean`` clamped to
            ``[a, b]`` instead of dividing by zero.
        a: lower cutoff.
        b: upper cutoff.
        generator: optional ``torch.Generator`` forwarded to
            ``Tensor.uniform_``; ``None`` uses the default generator.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):  # pragma: no cover
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        if std == 0:
            # Degenerate distribution: all mass sits at ``mean``. The general
            # path below would divide by ``std``, so handle it directly and
            # apply the same final clamp as the general path.
            tensor.fill_(mean)
            tensor.clamp_(min=a, max=b)
            return tensor

        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then
        # translate to [2 * lower - 1, 2 * upper - 1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1, generator=generator)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range (the uniform draw can hit
        # an endpoint, where erfinv is infinite).
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: "torch.Generator | None" = None,
) -> torch.Tensor:  # pragma: no cover
    r"""Fills the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds. The method used
    for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution; a value of
            zero fills the tensor with ``mean`` clamped to :math:`[a, b]`
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: optional `torch.Generator` used for sampling; ``None``
            uses the default generator

    Returns:
        The input ``tensor``, modified in place.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)