Source code for polygraphy.common.struct

#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from polygraphy import mod
from polygraphy.common.interface import TypedDict

np = mod.lazy_import("numpy")


class MetadataTuple(object):
    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape

    def __iter__(self):
        yield from [self.dtype, self.shape]

    def __repr__(self):
        return "MetadataTuple({:}, {:})".format(self.dtype, self.shape)

    def __str__(self):
        ret = ""
        meta_items = []
        if self.dtype is not None:
            meta_items.append("dtype={:}".format(np.dtype(self.dtype).name))
        if self.shape is not None:
            meta_items.append("shape={:}".format(tuple(self.shape)))
        if meta_items:
            ret += "[" + ", ".join(meta_items) + "]"
        return ret


[docs]@mod.export() class TensorMetadata(TypedDict(lambda: str, lambda: MetadataTuple)): """ An OrderedDict[str, MetadataTuple] that maps input names to their data types and shapes. Shapes may include negative values, ``None``, or strings to indicate dynamic dimensions. Example: :: shape = tensor_meta["input0"].shape dtype = tensor_meta["input0"].dtype """
[docs] @staticmethod def from_feed_dict(feed_dict): """ Constructs a new TensorMetadata using information from the provided feed_dict. Args: feed_dict (OrderedDict[str, numpy.ndarray]): A mapping of input tensor names to corresponding input NumPy arrays. Returns: TensorMetadata """ meta = TensorMetadata() for name, arr in feed_dict.items(): meta.add(name, arr.dtype, arr.shape) return meta
[docs] def add(self, name, dtype, shape): """ Convenience function for adding entries. Args: name (str): The name of the input. dtype (numpy.dtype): The data type of the input. shape (Sequence[Union[int, str]]]): The shape of the input. Dynamic dimensions may be indicated by negative values, ``None``, or a string. Returns: The newly added entry. """ self[name] = MetadataTuple(dtype, shape) return self
def __repr__(self): ret = "TensorMetadata()" for name, (dtype, shape) in self.items(): ret += ".add('{:}', {:}, {:})".format(name, dtype, shape) return ret def __str__(self): sep = ",\n " elems = ["{:} {:}".format(name, meta_tuple).strip() for name, meta_tuple in self.items()] return "{" + sep.join(elems) + "}"