reformat_model

nvmath.sparse.ust.interfaces.torch_interface.reformat_model(model, *, func=None)[source]

This function converts the linear weights in a model into UST format. If func is None, all weights are unconditionally converted into the UST COO format (useful for stress testing the system with exactly the same weights, but using sparse operations). If a user-defined function func=reformat is given, it is applied to every weight, and the parameter is replaced only when the function returns a value other than None.

Parameters:
  • model – the model to be reformatted

  • func – optional user-defined reformatting function; applied to each weight, it should return a replacement parameter or None to leave the weight unchanged

Examples

>>> import torch
>>> from nvmath.sparse.ust.interfaces.torch_interface import TorchUST

Inside the reformat function, inspect the weight's sparsity (note that we could even prune here, although it is more common to rely on dedicated pruning frameworks such as torch.nn.utils.prune, combined with fine-tuning to recover accuracy).

If the condition is met, pick a suitable format for the weight and return TorchUST.from_torch(weight); otherwise, return None.

>>> sparse_threshold = 90.0  # replace weights that are at least 90% zero
>>> def reformat(weight):
...     nel = weight.numel()
...     nnz = torch.count_nonzero(weight)
...     sparsity = (1.0 - float(nnz) / float(nel)) * 100.0
...     if sparsity >= sparse_threshold:
...         # TODO: Pick a suitable format for the weight
...         return TorchUST.from_torch(weight)
...     return None
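The sparsity metric used above can be checked in isolation. A minimal plain-Python sketch of the same computation (no torch required; the helper name sparsity_percent is ours, not part of the API):

```python
def sparsity_percent(values):
    """Percentage of zero entries, matching the formula used in reformat."""
    nel = len(values)
    nnz = sum(1 for v in values if v != 0)
    return (1.0 - nnz / nel) * 100.0


# 6 zeros out of 8 entries -> 75% sparse
weights = [0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, -2.0]
print(sparsity_percent(weights))  # 75.0
```

With sparse_threshold = 90.0, this weight would be left dense; only weights at or above the threshold are converted.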

This approach enables experimenting with novel formats to speed up sparsified models during inference simply by calling ust.reformat_model(model, func=reformat); no source code changes inside the model are required. If used during training, always construct the optimizer (e.g. by calling torch.optim.Adam(model.parameters(), lr=0.001)) after the reformatting call, so that the new parameters are included in the optimizer steps.
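To make the ordering concrete, here is a sketch of the workflow. It uses torch's built-in Tensor.to_sparse() as a hypothetical stand-in for TorchUST.from_torch, so the snippet runs without nvmath; the key point is that the optimizer is constructed only after the weights have been reformatted:

```python
import torch


def reformat(weight):
    # Same sparsity test as in the example above; to_sparse() stands in
    # for TorchUST.from_torch so this sketch runs without nvmath.
    nel = weight.numel()
    nnz = int(torch.count_nonzero(weight))
    sparsity = (1.0 - nnz / nel) * 100.0
    if sparsity >= 90.0:
        return weight.to_sparse()
    return None


model = torch.nn.Linear(8, 4, bias=False)
with torch.no_grad():
    model.weight.zero_()
    model.weight[0, 0] = 1.0  # 1 nonzero out of 32 -> ~96.9% sparse

converted = reformat(model.weight.detach())
print(converted.is_sparse)  # True

# Construct the optimizer only AFTER reformatting, so the (potentially
# replaced) parameters are the ones the optimizer actually updates.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
```

In the real workflow, the reformat call on the single weight would be replaced by ust.reformat_model(model, func=reformat), which walks the whole model and swaps parameters in place.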