Source code for polygraphy.datatype.datatype

# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from textwrap import dedent

from polygraphy import mod
from polygraphy.exception import PolygraphyException, DataTypeConversionException
from polygraphy.logger import G_LOGGER, LogMode

import enum

class _SkipImporterException(PolygraphyException):

class _DataTypeKind(enum.Enum):
    INTEGRAL = 1
    _OTHER = 3

[docs] @mod.export() class DataTypeEntry: """ Represents a data type. Can be transformed to and from data type classes of external modules, like NumPy. Do *not* construct objects of this type directly. Instead, use the predefined data types provided in the ``DataType`` class. """ def __init__(self, name, itemsize, type: _DataTypeKind): = name """The human-readable name of the data type""" self.itemsize = itemsize """The size in bytes of a single element of this data type""" # self._type describes the basic kind of the type we have. # For example, this can self._type = type @property def is_floating(self): return self._type == _DataTypeKind.FLOATING_POINT @property def is_integral(self): return self._type == _DataTypeKind.INTEGRAL def __str__(self): return def __repr__(self): return f"DataType.{}"
[docs] @mod.export() class DataType: # Docstring will be populated by the loop below """ Aggregates supported Polygraphy data types. Each data type is accessible via this class as a class member of type ``DataTypeEntry``. Members: """ _IMPORTER_FUNCS = {} _EXPORTER_FUNCS = {} __members__ = { "FLOAT64": DataTypeEntry("float64", 8, _DataTypeKind.FLOATING_POINT), "FLOAT32": DataTypeEntry("float32", 4, _DataTypeKind.FLOATING_POINT), "FLOAT16": DataTypeEntry("float16", 2, _DataTypeKind.FLOATING_POINT), "INT16": DataTypeEntry("int16", 2, _DataTypeKind.INTEGRAL), "INT32": DataTypeEntry("int32", 4, _DataTypeKind.INTEGRAL), "INT64": DataTypeEntry("int64", 8, _DataTypeKind.INTEGRAL), "INT8": DataTypeEntry("int8", 1, _DataTypeKind.INTEGRAL), "INT4": DataTypeEntry("int4", 0.5, _DataTypeKind.INTEGRAL), "UINT16": DataTypeEntry("uint16", 2, _DataTypeKind.INTEGRAL), "UINT32": DataTypeEntry("uint32", 4, _DataTypeKind.INTEGRAL), "UINT64": DataTypeEntry("uint64", 8, _DataTypeKind.INTEGRAL), "UINT8": DataTypeEntry("uint8", 1, _DataTypeKind.INTEGRAL), "BOOL": DataTypeEntry("bool", 1, _DataTypeKind._OTHER), "STRING": DataTypeEntry("string", 0, _DataTypeKind._OTHER), "BFLOAT16": DataTypeEntry("bfloat16", 2, _DataTypeKind.FLOATING_POINT), "FLOAT8E4M3FN": DataTypeEntry("float8e4m3fn", 1, _DataTypeKind.FLOATING_POINT), "FLOAT8E4M3FNUZ": DataTypeEntry( "float8e4m3fnuz", 1, _DataTypeKind.FLOATING_POINT ), "FLOAT8E5M2": DataTypeEntry("float8e5m2", 1, _DataTypeKind.FLOATING_POINT), "FLOAT8E5M2FNUZ": DataTypeEntry( "float8e5m2fnuz", 1, _DataTypeKind.FLOATING_POINT ), }
[docs] @staticmethod def from_dtype(dtype, source_module=None): """ Converts a data type from any known external libraries to a corresponding Polygraphy data type. Args: dtype (Any): A data type from an external library. source_module (str): The name of the module from where the provided `dtype` originates. If this is not provided, Polygraphy will attempt to guess the module in order to convert the data type. Returns: DataTypeEntry: The corresponding Polygraphy data type. Raises: PolygraphyException: If the data type could not be converted. """ if dtype is None: G_LOGGER.critical(f"Could not convert: {dtype} to a Polygraphy data type") if isinstance(dtype, DataTypeEntry): return dtype if source_module is not None: if source_module not in DataType._IMPORTER_FUNCS: G_LOGGER.critical( f"Could not find source module: {source_module} in known importers. " f"Note: Importer functions have been registered for the following modules: {list(DataType._IMPORTER_FUNCS.keys())}" ) return DataType._IMPORTER_FUNCS[source_module](dtype) for func in DataType._IMPORTER_FUNCS.values(): try: ret = func(dtype) except _SkipImporterException: pass else: return ret msg = f"Could not convert: {dtype} to a corresponding Polygraphy data type. Leaving this type in its source format." G_LOGGER.warning(msg, mode=LogMode.ONCE) G_LOGGER.internal_error(msg) return dtype
[docs] @staticmethod def to_dtype(dtype, target_module): """ Converts a Polygraphy data type to one from any known external libraries. Args: dtype (DataType): A Polygraphy data type. If something other than a Polygraphy data type is provided, then this function will return it without modifying it. target_module (str): The name of the module whose data type class to convert this data type to. Returns: Any: The corresponding data type from the target module. Raises: PolygraphyException: If the data type could not be converted. """ if not isinstance(dtype, DataTypeEntry): G_LOGGER.internal_error( f"Received input of type other than DataType: {dtype}" ) return dtype if target_module not in DataType._EXPORTER_FUNCS: G_LOGGER.critical( f"Could not find target module: {target_module} in known exporters. " f"Note: Exporter functions have been registered for the following modules: {list(DataType._EXPORTER_FUNCS.keys())}" ) return DataType._EXPORTER_FUNCS[target_module](dtype)
DataType.__doc__ = dedent(DataType.__doc__) for name, value in DataType.__members__.items(): setattr(DataType, name, value) DataType.__doc__ += f"\t- {name}\n" def register_dtype_importer(source_module): """ Registers an importer function with the DataType class. IMPORTANT: You *must* ensure that the importer function does not attempt to automatically install or import modules which are not already installed. With a lazily imported module, `module.is_installed()/is_importable()` is an easy way to guard the code against this. We do not want to automatically install heavy modules like PyTorch or TensorRT just for the sake of DataType. For example: :: @register_dtype_importer("numpy") def func(dtype): ... The importer function should return `None` if no corresponding data type could be found or if the input type did not match what was expected. The newly registered function is then usable via `from_dtype`: :: dtype = DataType.from_dtype(np.int64) """ def register_importer_impl(func): @functools.wraps(func) def new_func(dtype): val = func(dtype) if val is None: # We raise an exception to indicate that `from_dtype` should skip this importer and try a different one. # We have to do it this way since we don't necessarily know which importer is the right one to use. raise _SkipImporterException() return val DataType._IMPORTER_FUNCS[source_module] = new_func return new_func return register_importer_impl def register_dtype_exporter(target_module): """ Registers an exporter function with the DataType class. For example: :: @register_dtype_exporter("numpy") def func(dtype): ... The newly registered function is then accessible with, for example: :: np_dtype = DataType.FLOAT32.numpy() """ def register_exporter_impl(func): @functools.wraps(func) def new_func(dtype): val = func(dtype) if val is None: G_LOGGER.critical( f"Could not convert Polygraphy data type: {dtype} to a corresponding {target_module} data type. ", ExceptionType=DataTypeConversionException, ) return val new_func.__name__ = target_module setattr(DataTypeEntry, new_func.__name__, new_func) DataType._EXPORTER_FUNCS[target_module] = new_func return new_func return register_exporter_impl