reformat_model
nvmath.sparse.ust.interfaces.torch_interface.reformat_model(model, *, func=None)

This function converts some or all of the linear weights of a model into UST format. If func is None, all weights are unconditionally converted into the UST COO format (stress testing the system with exactly the same weights but using sparse operations). If a user-defined function func=reformat is given, it is applied to every weight, and the corresponding parameter is replaced only if the function returns a value other than None.

Parameters:
model – the model to be reformatted
func – if set, user-defined reformatting function
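The replacement rule described above can be illustrated with a plain-Python sketch. It is not the library's implementation: ToyLinear, reformat_weights_sketch, and the dict-based "model" are hypothetical stand-ins for a PyTorch model and its linear weights, and the tuples returned here merely stand in for the UST objects the real function produces. The point is only the control flow: with func=None every weight is converted, otherwise func is called once per weight and the weight is swapped out only when func returns something other than None.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class ToyLinear:
    """Hypothetical stand-in for a linear layer holding a dense weight."""
    weight: Any


def reformat_weights_sketch(
    layers: Dict[str, ToyLinear],
    func: Optional[Callable[[Any], Any]] = None,
    default: Callable[[Any], Any] = lambda w: ("COO", w),
) -> None:
    """Hypothetical sketch of the 'replace only if func returns not None' rule.

    When func is None, every weight is converted with `default`
    (standing in for the unconditional UST COO conversion).
    """
    for layer in layers.values():
        if func is None:
            layer.weight = default(layer.weight)
            continue
        new_weight = func(layer.weight)
        if new_weight is not None:  # None means "keep the original weight"
            layer.weight = new_weight


# Usage: only the mostly-zero layer gets replaced.
layers = {"fc1": ToyLinear([[0, 0], [0, 1]]), "fc2": ToyLinear([[1, 2], [3, 4]])}


def mostly_zero(weight):
    flat = [x for row in weight for x in row]
    zeros = sum(1 for x in flat if x == 0)
    return ("SPARSE", weight) if zeros / len(flat) >= 0.5 else None


reformat_weights_sketch(layers, func=mostly_zero)
print(layers["fc1"].weight[0])  # SPARSE
print(layers["fc2"].weight)     # [[1, 2], [3, 4]]
```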
Examples
>>> import torch
>>> from nvmath.sparse.ust.interfaces.torch_interface import TorchUST
Inside the reformat function, inspect the weight sparsity (note that we could even prune here, but it is more common to rely on other pruning frameworks like torch.nn.utils.prune in combination with fine-tuning for accuracy).

If the condition is met, pick a suitable format for weight and then return TorchUST.from_torch(weight). Otherwise, just return None.

>>> sparse_threshold = 90.0  # hypothetical threshold, in percent
>>> def reformat(weight):
...     nel = weight.numel()
...     nnz = torch.count_nonzero(weight)
...     sparsity = (1.0 - float(nnz) / float(nel)) * 100.0
...     if sparsity >= sparse_threshold:
...         # TODO: Pick suitable format for weight
...         return TorchUST.from_torch(weight)
...     return None
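The sparsity test in reformat is plain arithmetic: sparsity = (1 - nnz / nel) * 100. As a stdlib-only illustration, the hypothetical helper sparsity_percent below uses nested lists in place of torch tensors and mirrors weight.numel() and torch.count_nonzero(weight). A 4 x 4 weight with 3 nonzeros has (1 - 3/16) * 100 = 81.25% sparsity, so it would pass a 75% threshold but not a 90% one.

```python
def sparsity_percent(weight):
    """Percentage of zero entries in a 2-D nested list (hypothetical helper).

    Mirrors the computation in `reformat`:
    nel ~ weight.numel(), nnz ~ torch.count_nonzero(weight).
    """
    nel = sum(len(row) for row in weight)
    nnz = sum(1 for row in weight for x in row if x != 0)
    return (1.0 - float(nnz) / float(nel)) * 100.0


weight = [
    [0.0, 1.5, 0.0, 0.0],
    [0.0, 0.0, 0.0, -2.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.7, 0.0, 0.0, 0.0],
]

s = sparsity_percent(weight)
print(s)          # 81.25
print(s >= 75.0)  # True
print(s >= 90.0)  # False
```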
This approach enables experimenting with novel formats to speed up sparsified models during inference simply by calling ust.reformat_model(model, func=reformat). No source code changes inside the model are required! If used during training, always make sure to construct the optimizer (e.g., by calling torch.optim.Adam(model.parameters(), lr=0.001)) after calling the reformatting function, so that the new parameters are included in the optimizer steps.
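Why the ordering matters: an optimizer keeps references to the parameter objects it was given at construction, so a parameter that is replaced afterwards is never updated. The stdlib-only analogy below uses hypothetical ToyParam, ToyModel, and ToyOptimizer classes in place of torch.nn.Parameter, a model, and torch.optim.Adam; it illustrates object identity only and does not model any real optimizer math.

```python
class ToyParam:
    """Hypothetical stand-in for a trainable parameter."""
    def __init__(self, value):
        self.value = value


class ToyModel:
    """Hypothetical model with a single weight parameter."""
    def __init__(self):
        self.weight = ToyParam(1.0)

    def parameters(self):
        return [self.weight]


class ToyOptimizer:
    """Captures the parameter objects at construction, as torch.optim optimizers do."""
    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def step(self, grad=1.0):
        for p in self.params:
            p.value -= self.lr * grad


# Wrong order: optimizer built before the weight is replaced.
stale_model = ToyModel()
opt = ToyOptimizer(stale_model.parameters(), lr=0.5)
stale_model.weight = ToyParam(1.0)  # stands in for reformat_model swapping the parameter
opt.step()
print(stale_model.weight.value)  # 1.0 (the new parameter was never updated)

# Right order: replace first, then build the optimizer.
fresh_model = ToyModel()
fresh_model.weight = ToyParam(1.0)
opt = ToyOptimizer(fresh_model.parameters(), lr=0.5)
opt.step()
print(fresh_model.weight.value)  # 0.5
```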