prune_model

nvmath.sparse.ust.interfaces.torch_interface.prune_model(model, *, local=True, amount=0.5)

This is a convenience wrapper around torch.nn.utils.prune that prunes the weights of all linear layers in a model, either locally (each layer separately) or globally (across all linear layers at once), by the given amount.
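The same idea can be approximated directly with torch.nn.utils.prune. The sketch below is not the nvmath source: the function name prune_linear_layers is hypothetical, and it assumes L1 (magnitude) unstructured pruning, using prune.l1_unstructured for the local case and prune.global_unstructured for the global case. The final prune.remove call, which folds the mask into the weight so the zeros become permanent, is also a choice of this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def prune_linear_layers(model: nn.Module, *, local: bool = True, amount: float = 0.5) -> nn.Module:
    """Hypothetical illustration, not the nvmath implementation.

    Assumes L1 (smallest absolute value) unstructured pruning of every nn.Linear weight.
    """
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    if local:
        # Local: drop `amount` of the smallest-magnitude weights in each layer separately.
        for layer in linears:
            prune.l1_unstructured(layer, name="weight", amount=amount)
    else:
        # Global: rank the weights of all linear layers together and drop `amount` of the total.
        prune.global_unstructured(
            [(layer, "weight") for layer in linears],
            pruning_method=prune.L1Unstructured,
            amount=amount,
        )
    # Make pruning permanent: replace weight_orig * weight_mask with a plain weight tensor.
    for layer in linears:
        prune.remove(layer, "weight")
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    prune_linear_layers(model, local=True, amount=0.5)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # Fraction of zeroed weights in this layer
            print(float((m.weight == 0).float().mean()))
```

With local=True each of the two layers ends up exactly half zero; with local=False only the total across both layers is guaranteed to be half.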

Parameters:
  • model – the model to be pruned

  • local – if True (the default), prune each linear layer independently; if False, prune globally across all linear layers

  • amount – fraction of the weights to drop (default 0.5); for example, 0.50 drops 50%
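To see how local and global pruning can differ for the same amount, here is a framework-free sketch on two tiny "layers" stored as flat lists. It assumes weights are ranked by absolute value (the criterion is not stated on this page), and the helpers local_prune and global_prune are hypothetical names used only for illustration.

```python
def local_prune(layers, amount):
    """Zero the round(amount * n) smallest-magnitude weights in each layer separately."""
    pruned = []
    for w in layers:
        k = round(amount * len(w))
        drop = set(sorted(range(len(w)), key=lambda i: abs(w[i]))[:k])
        pruned.append([0.0 if i in drop else x for i, x in enumerate(w)])
    return pruned


def global_prune(layers, amount):
    """Zero the round(amount * N) smallest-magnitude weights across all layers together."""
    flat = [(abs(x), li, i) for li, w in enumerate(layers) for i, x in enumerate(w)]
    k = round(amount * len(flat))
    drop = {(li, i) for _, li, i in sorted(flat)[:k]}
    return [[0.0 if (li, i) in drop else x for i, x in enumerate(w)]
            for li, w in enumerate(layers)]


layers = [[0.1, -0.2, 0.3, 0.4], [1.0, -2.0, 3.0, 4.0]]
print(local_prune(layers, 0.5))   # [[0.0, 0.0, 0.3, 0.4], [0.0, 0.0, 3.0, 4.0]]
print(global_prune(layers, 0.5))  # [[0.0, 0.0, 0.0, 0.0], [1.0, -2.0, 3.0, 4.0]]
```

Both calls drop four of the eight weights, but the global version removes every weight from the small-magnitude layer, so under global pruning individual layers can end up much sparser or denser than amount.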