Source code for nemo.core.classes.mixins.access_mixins
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Optional
import torch
from omegaconf import DictConfig
_ACCESS_CFG = DictConfig({"detach": False, "convert_to_cpu": False})
_ACCESS_ENABLED = False
def set_access_cfg(cfg: 'DictConfig'):
if cfg is None or not isinstance(cfg, DictConfig):
raise TypeError(f"cfg must be a DictConfig")
global _ACCESS_CFG
_ACCESS_CFG = cfg
[docs]class AccessMixin(ABC):
"""
Allows access to output of intermediate layers of a model
"""
def __init__(self):
super().__init__()
self._registry = {} # dictionary of lists
[docs] def register_accessible_tensor(self, name, tensor):
"""
Register tensor for later use.
"""
if self.access_cfg.get('convert_to_cpu', False):
tensor = tensor.cpu()
if self.access_cfg.get('detach', False):
tensor = tensor.detach()
if not hasattr(self, '_registry'):
self._registry = {}
if name not in self._registry:
self._registry[name] = []
self._registry[name].append(tensor)
[docs] @classmethod
def get_module_registry(cls, module: torch.nn.Module):
"""
Extract all registries from named submodules, return dictionary where
the keys are the flattened module names, the values are the internal registry
of each such module.
"""
module_registry = {}
for name, m in module.named_modules():
if hasattr(m, '_registry') and len(m._registry) > 0:
module_registry[name] = m._registry
return module_registry
[docs] def reset_registry(self: torch.nn.Module, registry_key: Optional[str] = None):
"""
Reset the registries of all named sub-modules
"""
if hasattr(self, "_registry"):
if registry_key is None:
self._registry.clear()
else:
if registry_key in self._registry:
self._registry.pop(registry_key)
else:
raise KeyError(
f"Registry key `{registry_key}` provided, but registry does not have this key.\n"
f"Available keys in registry : {list(self._registry.keys())}"
)
for _, m in self.named_modules():
if hasattr(m, "_registry"):
if registry_key is None:
m._registry.clear()
else:
if registry_key in self._registry:
self._registry.pop(registry_key)
else:
raise KeyError(
f"Registry key `{registry_key}` provided, but registry does not have this key.\n"
f"Available keys in registry : {list(self._registry.keys())}"
)
# Explicitly disable registry cache after reset
AccessMixin.set_access_enabled(access_enabled=False)
@property
def access_cfg(self):
"""
Returns:
The global access config shared across all access mixin modules.
"""
global _ACCESS_CFG
return _ACCESS_CFG
[docs] @classmethod
def update_access_cfg(cls, cfg: dict):
global _ACCESS_CFG
_ACCESS_CFG.update(cfg)
[docs] @classmethod
def is_access_enabled(cls):
global _ACCESS_ENABLED
return _ACCESS_ENABLED
[docs] @classmethod
def set_access_enabled(cls, access_enabled: bool):
global _ACCESS_ENABLED
_ACCESS_ENABLED = access_enabled