#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import functools
import io
import json
from collections import OrderedDict
from polygraphy import constants, mod, util
from polygraphy.logger import G_LOGGER
np = mod.lazy_import("numpy")
torch = mod.lazy_import("torch>=1.13.0")
def legacy_str_from_type(typ):
    """
    Builds the legacy (older Polygraphy) JSON key name for a type.

    Args:
        typ (type): The type whose name should be used.

    Returns:
        str: ``"__polygraphy_encoded_"`` followed by ``typ.__name__``.
    """
    return f"__polygraphy_encoded_{typ.__name__}"
def str_from_type(typ):
    """
    Returns the name under which a type is recorded in encoded JSON.

    Args:
        typ (type): The type whose name should be used.

    Returns:
        str: ``typ.__name__``.
    """
    name = typ.__name__
    return name
class BaseCustomImpl:
    """
    Base class for Polygraphy's JSON encoder/decoder.

    ``register`` stores functions in the calling class's ``polygraphy_registered``
    dictionary. Only ``Encoder`` and ``Decoder`` are supported as registration targets.
    """

    @classmethod
    def register(cls, typ, alias=None):
        """
        Decorator that registers JSON encoding/decoding functions for types.

        For the documentation that follows, assume we have a class:
        ::

            class Dummy:
                def __init__(self, x):
                    self.x = x

        ========
        Encoders
        ========

        Encoder functions should accept instances of the specified type and
        return dictionaries.

        For example:
        ::

            @Encoder.register(Dummy)
            def encode(dummy):
                return {"x": dummy.x}


        To use the custom encoder, use the `to_json` helper:
        ::

            d = Dummy(x=1)
            d_json = to_json(d)


        ========
        Decoders
        ========

        Decoder functions should accept dictionaries, and return instances of the
        type.

        For example:
        ::

            @Decoder.register(Dummy)
            def decode(dct):
                return Dummy(x=dct["x"])


        To use the custom decoder, use the `from_json` helper:
        ::

            from_json(d_json)


        Args:
            typ (type): The type of the class for which to register the function.
            alias (str):
                    An alias under which to also register the decoder function.
                    This can be used to retain backwards-compatibility when a class
                    name changes.

        Returns:
            Callable:
                    A decorator. The decorated name is bound to the registered wrapper:
                    for encoders, the wrapper adds the type marker to the returned dictionary;
                    for decoders, it strips the current and legacy type markers before
                    calling the decorated function.
        """

        def register_impl(func):
            def add(key, val):
                # Refuse to silently overwrite an existing (de)serialization function.
                if key in cls.polygraphy_registered:
                    G_LOGGER.critical(
                        f"Duplicate serialization function for type: {key}.\nNote: Existing function: {cls.polygraphy_registered[key]}, New function: {func}"
                    )
                cls.polygraphy_registered[key] = val

            if cls == Encoder:

                def wrapped(obj):
                    dct = func(obj)
                    # Tag the output so the decoder can find the matching function.
                    dct[constants.TYPE_MARKER] = str_from_type(typ)
                    return dct

                add(typ, wrapped)
                return wrapped

            elif cls == Decoder:

                def wrapped(dct):
                    # Remove the markers written by the encoder (current and legacy formats)
                    # so that `func` only sees the payload fields.
                    if constants.TYPE_MARKER in dct:
                        del dct[constants.TYPE_MARKER]

                    type_name = legacy_str_from_type(typ)
                    if type_name in dct:
                        del dct[type_name]

                    return func(dct)

                # Register under the legacy name, the current name, and the optional alias so
                # that JSON produced under any of those names can be decoded.
                add(legacy_str_from_type(typ), wrapped)
                add(str_from_type(typ), wrapped)
                if alias is not None:
                    add(alias, wrapped)
                # Return the wrapper (like the Encoder branch) so the decorated name is not
                # rebound to ``None``.
                return wrapped

            else:
                G_LOGGER.critical(f"Cannot register for unrecognized class type: {cls.__name__}")

        return register_impl
@mod.export()
class Encoder(BaseCustomImpl, json.JSONEncoder):
    """
    Polygraphy's custom JSON Encoder implementation.

    Objects whose *exact* type has a registered encoder are converted via that
    encoder; all other objects are handed to :meth:`json.JSONEncoder.default`.
    """

    # Maps a type to the (wrapped) function that encodes instances of it.
    polygraphy_registered = {}

    def default(self, o):
        encode_fn = self.polygraphy_registered.get(type(o))
        if encode_fn is None:
            return super().default(o)
        return encode_fn(o)
@mod.export()
class Decoder(BaseCustomImpl):
    """
    Polygraphy's custom JSON Decoder implementation.

    Instances are meant to be passed as ``object_pairs_hook`` to :func:`json.loads`.
    """

    # Maps a type-name string (current, legacy, or alias) to the (wrapped) decoding function.
    polygraphy_registered = {}

    def __call__(self, pairs):
        registry = self.polygraphy_registered
        dct = OrderedDict(pairs)

        # Legacy format first: a key named after the type whose value is LEGACY_TYPE_MARKER.
        # JSON written by recent Polygraphy versions should not contain such keys.
        for key, decode_fn in registry.items():
            if key in dct and dct[key] == constants.LEGACY_TYPE_MARKER:
                return decode_fn(dct)

        # Current format: the type name is stored under TYPE_MARKER.
        type_name = dct.get(constants.TYPE_MARKER)
        if type_name is None:
            # Plain dictionary; nothing to decode.
            return dct

        if type_name not in registry:
            friendly_names = {"Tensor": "torch.Tensor", "ndarray": "np.ndarray"}
            user_type_name = friendly_names.get(type_name, type_name)
            G_LOGGER.critical(
                f"Could not decode serialized type: {user_type_name}. This could be because a required module is missing. "
            )
        return registry[type_name](dct)
# Flags recording which groups of JSON encoders/decoders `try_register_common_json` has
# already registered (NumPy, PyTorch, and common Polygraphy types), so each group is only
# registered once.
NUMPY_REGISTRATION_SUCCESS = TORCH_REGISTRATION_SUCCESS = COMMON_REGISTRATION_SUCCESS = False
def _register_numpy_json():
    """Registers the JSON encoder/decoder pair for ``np.ndarray``."""

    @Encoder.register(np.ndarray)
    def encode(array):
        outfile = io.BytesIO()
        np.save(outfile, array, allow_pickle=False)
        outfile.seek(0)
        data = base64.b64encode(outfile.read()).decode()
        return {"array": data}

    @Decoder.register(np.ndarray)
    def decode(dct):
        def load(mode="base64"):
            if mode == "base64":
                data = base64.b64decode(dct["array"].encode(), validate=True)
            elif mode == "latin-1":
                data = dct["array"].encode(mode)
            else:
                # Use an exception rather than `assert`, which is stripped under `python -O`.
                raise ValueError(f"Unsupported mode: {mode}")
            infile = io.BytesIO(data)
            return np.load(infile, allow_pickle=False)

        try:
            arr = load()
        except Exception:
            # Legacy latin-1 encoding, kept for backwards compatibility. Catch `Exception`
            # (not a bare `except:`) so KeyboardInterrupt/SystemExit still propagate.
            arr = load("latin-1")

        if isinstance(arr, np.ndarray):
            return arr
        # For backwards compatibility: when `np.load` returns a mapping (e.g. for an `.npz`
        # archive) rather than an array, use its first entry.
        return list(arr.values())[0]


def _register_torch_json():
    """Registers the JSON encoder/decoder pair for ``torch.Tensor``."""

    @Encoder.register(torch.Tensor)
    def encode(tensor):
        outfile = io.BytesIO()
        torch.save(tensor, outfile)
        outfile.seek(0)
        data = base64.b64encode(outfile.read()).decode()
        return {"tensor": data}

    @Decoder.register(torch.Tensor)
    def decode(dct):
        data = base64.b64decode(dct["tensor"].encode(), validate=True)
        infile = io.BytesIO(data)
        # SECURITY NOTE(review): `torch.load` is pickle-based; depending on the installed torch
        # version and its `weights_only` default, it may execute arbitrary code from the
        # payload. Only decode JSON from trusted sources.
        return torch.load(infile)


def _register_common_types():
    """Imports Polygraphy types whose modules register their own (de)serialization functions."""
    # Note: We can only do this here for submodules with no external dependencies.
    # That means, for example, nothing from `backend/` can be imported here.
    from polygraphy.common import FormattedArray  # noqa: F401
    from polygraphy.comparator import RunResults  # noqa: F401


def try_register_common_json(func):
    """
    Decorator that attempts to register common JSON encode/decode methods
    if the methods have not already been registered.

    This needs to be attempted multiple times because dependencies may become available in the
    middle of execution - for example, if using dependency auto-installation.

    Args:
        func (Callable): The function to wrap.

    Returns:
        Callable: A wrapper that performs any outstanding registrations and then calls ``func``.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        global NUMPY_REGISTRATION_SUCCESS, TORCH_REGISTRATION_SUCCESS, COMMON_REGISTRATION_SUCCESS

        # Registration lives alongside load_json/save_json so it is guaranteed to happen
        # before NumPy arrays need to be encoded/decoded.
        if not NUMPY_REGISTRATION_SUCCESS and np.is_installed() and np.is_importable():
            _register_numpy_json()
            NUMPY_REGISTRATION_SUCCESS = True

        if not TORCH_REGISTRATION_SUCCESS and torch.is_installed() and torch.is_importable():
            _register_torch_json()
            TORCH_REGISTRATION_SUCCESS = True

        if not COMMON_REGISTRATION_SUCCESS:
            # Pulls in common types so users need not import them manually.
            _register_common_types()
            COMMON_REGISTRATION_SUCCESS = True

        return func(*args, **kwargs)

    return wrapped
[docs]
@mod.export()
@try_register_common_json
def to_json(obj):
    """
    Serialize an object to a JSON string using Polygraphy's custom ``Encoder``.

    NOTE: For Polygraphy objects, you should use the ``to_json()`` method instead.

    Args:
        obj: The object to encode.

    Returns:
        str: A JSON representation of the object.
    """
    encoded = json.dumps(obj, indent=constants.TAB, cls=Encoder)
    return encoded
[docs]
@mod.export()
@try_register_common_json
def from_json(src):
    """
    Deserialize a JSON string using Polygraphy's custom ``Decoder``.

    NOTE: For Polygraphy objects, you should use the ``from_json()`` method instead.

    Args:
        src (str):
                The JSON representation of the object

    Returns:
        object: The decoded instance
    """
    pairs_hook = Decoder()
    return json.loads(src, object_pairs_hook=pairs_hook)
[docs]
@mod.export()
@try_register_common_json
def save_json(obj, dest, description=None):
    """
    Encode an object as JSON and save it to a file.

    NOTE: For Polygraphy objects, you should use the ``save()`` method instead.

    Args:
        obj: The object to save.
        dest (Union[str, file-like]): The path or file-like object to save to.
        description (str):
                A description of what is being saved; forwarded to ``util.save_file``.
                Defaults to None.
    """
    util.save_file(to_json(obj), dest, mode="w", description=description)
[docs]
@mod.export()
@try_register_common_json
def load_json(src, description=None):
    """
    Loads a file and decodes the JSON contents.

    NOTE: For Polygraphy objects, you should use the ``load()`` method instead.

    Args:
        src (Union[str, file-like]): The path or file-like object to load from.
        description (str):
                A description of what is being loaded; forwarded to ``util.load_file``.
                Defaults to None.

    Returns:
        object: The decoded object.
    """
    # The contents are always passed to `json.loads` (via `from_json`), so unreadable or
    # invalid contents raise rather than yielding None.
    return from_json(util.load_file(src, mode="r", description=description))
@mod.export()
def add_json_methods(description=None):
    """
    Class-decorator factory that attaches four JSON helpers to a class:

        - to_json(): Convert to JSON string
        - from_json(): Convert from JSON string (static method)
        - save(): Convert to JSON and save to file
        - load(): Load from file and convert from JSON (static method)

    Args:
        description (str):
                A description of what is being saved or loaded.
                Forwarded to ``save_json``/``load_json``.

    Returns:
        Callable: A decorator that adds the methods to a class and returns that same class.
    """

    def add_json_methods_impl(cls):
        def ensure_instance(decoded):
            # Reject JSON that decoded into something other than an instance of `cls`.
            if isinstance(decoded, cls):
                return decoded
            G_LOGGER.critical(
                f"Provided JSON cannot be decoded into a {cls.__name__}.\nNote: JSON was decoded into a {type(decoded)}:\n{decoded}"
            )
            return decoded

        def _to_json_method(self):
            """
            Encode this instance as a JSON object.

            Returns:
                str: A JSON representation of this instance.
            """
            return to_json(self)

        def _from_json_method(src):
            return ensure_instance(from_json(src))

        _from_json_method.__doc__ = f"""
            Decode a JSON object and create an instance of this class.

            Args:
                src (str):
                        The JSON representation of the object

            Returns:
                {cls.__name__}: The decoded instance

            Raises:
                PolygraphyException:
                        If the JSON cannot be decoded to an instance of {cls.__name__}
            """

        def _save_method(self, dest):
            """
            Encode this instance as a JSON object and save it to the specified path
            or file-like object.

            Args:
                dest (Union[str, file-like]):
                        The path or file-like object to write to.
            """
            save_json(self, dest, description=description)

        def _load_method(src):
            return ensure_instance(load_json(src, description=description))

        _load_method.__doc__ = f"""
            Loads an instance of this class from a JSON file.

            Args:
                src (Union[str, file-like]): The path or file-like object to read from.

            Returns:
                {cls.__name__}: The decoded instance

            Raises:
                PolygraphyException:
                        If the JSON cannot be decoded to an instance of {cls.__name__}
            """

        # Attach everything in one place; from_json/load are static since they build instances.
        new_members = {
            "to_json": _to_json_method,
            "from_json": staticmethod(_from_json_method),
            "save": _save_method,
            "load": staticmethod(_load_method),
        }
        for attr_name, member in new_members.items():
            setattr(cls, attr_name, member)
        return cls

    return add_json_methods_impl