# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO(Dallas) Introduce Distributed Class for computation.

import torch

Tensor = torch.Tensor


def _check_counts_and_edges(p: Tensor, bin_edges: Tensor) -> None:
    """Validate that ``p`` ([N, ...]) is defined over ``bin_edges`` ([N+1, ...]).

    Raises
    ------
    ValueError
        If the trailing (non-zeroth) dimensions of ``p`` and ``bin_edges``
        differ, or if ``bin_edges`` does not have exactly one more entry than
        ``p`` along dimension 0.
    """
    if bin_edges.shape[1:] != p.shape[1:]:
        raise ValueError(
            "Expected bins and pdf to have compatible non-zeroth dimensions but have shapes "
            + str(bin_edges.shape[1:])
            + " and "
            + str(p.shape[1:])
            + "."
        )
    if bin_edges.shape[0] != p.shape[0] + 1:
        raise ValueError(
            "Expected zeroth dimension of bins to be equal to the zeroth dimension of pdf + 1 but have shapes "
            + str(bin_edges.shape[0])
            + " and "
            + str(p.shape[0])
            + "+1."
        )


def _normalize_counts(p: Tensor, bin_mids: Tensor, eps: float) -> Tensor:
    """Rescale ``p`` to unit integral over ``bin_mids`` and add ``eps`` so ``log`` stays finite."""
    # The integral is a trapezoidal rule over the bin midpoints along dim 0;
    # the resulting [...] tensor broadcasts back against the [N, ...] counts.
    # eps is added *after* normalisation, so the result integrates to slightly
    # more than one; it only exists to keep log(p) away from log(0).
    return p / torch.trapz(p, bin_mids, dim=0) + eps


def entropy_from_counts(
    p: Tensor, bin_edges: Tensor, normalized: bool = True, eps: float = 1e-8
) -> Tensor:
    r"""Computes the Statistical Entropy of a random variable using a histogram.

    Uses the formula:

    .. math::

        Entropy(X) = -\int p(x) \log( p(x) ) dx

    Parameters
    ----------
    p : Tensor
        Tensor [N, ...] containing counts/pdf, defined over bins. The non-zeroth
        dimensions of bin_edges and p must be equal.
    bin_edges : Tensor
        Tensor [N+1, ...] containing bin edges. The leading dimension must
        represent the N+1 bin edges.
    normalized : bool, optional
        Boolean flag determining whether the returned statistical entropy is
        normalized (default ``True``). The entropy is rescaled so that a
        pseudo-dirac reference distribution maps to 0 and a uniform
        distribution over [bin_edges[0], bin_edges[-1]] maps to 1.
    eps : float, optional
        Small constant added to the normalized pdf so ``log`` never receives an
        exact zero. Defaults to ``1e-8``.

    Returns
    -------
    Tensor
        Tensor of shape ``p.shape[1:]`` containing the Information/Statistical
        Entropy.

    Raises
    ------
    ValueError
        If ``p`` and ``bin_edges`` have incompatible shapes.
    """
    _check_counts_and_edges(p, bin_edges)

    bin_mids = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    p = _normalize_counts(p, bin_mids, eps)
    ent = torch.trapz(-1.0 * p * torch.log(p), bin_mids, dim=0)
    if not normalized:
        return ent

    # Upper reference: differential entropy of a uniform density on
    # [bin_edges[0], bin_edges[-1]], i.e. log(width of the support).
    max_ent = torch.log(bin_edges[-1] - bin_edges[0])
    # Lower reference: 0.5 * log(2 * pi * e * sigma**2), the entropy of a
    # Gaussian whose standard deviation equals the first bin width.
    # NOTE(review): using only the first bin width presumes uniformly spaced bins.
    first_width = bin_edges[1] - bin_edges[0]
    min_ent = 0.5 + 0.5 * torch.log(2 * torch.pi * first_width**2)
    return (ent - min_ent) / (max_ent - min_ent)


def relative_entropy_from_counts(
    p: Tensor,
    q: Tensor,
    bin_edges: Tensor,
    eps: float = 1e-8,
) -> Tensor:
    r"""Computes the Relative Statistical Entropy, or KL Divergence of two random
    variables using their histograms.

    Uses the formula:

    .. math::

        D_{KL}(P \| Q) = \int p(x) \log( p(x)/q(x) ) dx

    Parameters
    ----------
    p : Tensor
        Tensor [N, ...] containing counts/pdf, defined over bins. The non-zeroth
        dimensions of bin_edges and p must be equal.
    q : Tensor
        Tensor [N, ...] containing counts/pdf, defined over bins. Must have the
        same shape as p.
    bin_edges : Tensor
        Tensor [N+1, ...] containing bin edges. The leading dimension must
        represent the N+1 bin edges.
    eps : float, optional
        Small constant added to both normalized pdfs so ``log`` and the ratio
        ``p / q`` stay finite. Defaults to ``1e-8``.

    Returns
    -------
    Tensor
        Tensor of shape ``p.shape[1:]`` containing the KL divergence of p from q.

    Raises
    ------
    ValueError
        If ``p`` and ``bin_edges`` have incompatible shapes, or if ``p`` and
        ``q`` differ in shape.
    """
    _check_counts_and_edges(p, bin_edges)
    if p.shape != q.shape:
        raise ValueError(
            "Expected p and q to have compatible shapes but have shapes "
            + str(p.shape)
            + " and "
            + str(q.shape)
            + "."
        )

    bin_mids = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    p = _normalize_counts(p, bin_mids, eps)
    q = _normalize_counts(q, bin_mids, eps)
    return torch.trapz(p * torch.log(p / q), bin_mids, dim=0)