.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples/magnitude_pruning.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_magnitude_pruning.py:


.. _magnitude_pruning_tutorial:

Magnitude Pruning
=================

.. GENERATED FROM PYTHON SOURCE LINES 11-19

In this tutorial, you learn how to train a simple convolutional neural network on
`MNIST `_ using :py:class:`~.pruning.MagnitudePruner`.
The tutorial demonstrates how to achieve ~75% sparsity without incurring a significant
loss in model accuracy.

Learn more about other pruners and schedulers in the coremltools
`Training-Time Pruning Documentation `_.

.. GENERATED FROM PYTHON SOURCE LINES 21-25

Network and Dataset Definition
------------------------------
First, define your network, which consists of a single convolution layer (followed by
ReLU activation and max pooling) and a dense (linear) layer.

.. GENERATED FROM PYTHON SOURCE LINES 25-47

.. code-block:: default


    from collections import OrderedDict

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def mnist_net(num_classes=10):
        return nn.Sequential(
            OrderedDict(
                [
                    # 1x28x28 input -> 12x28x28 feature maps ('same' padding keeps the size)
                    ('conv', nn.Conv2d(1, 12, 3, padding='same')),
                    ('relu', nn.ReLU()),
                    # 2x2 max pooling halves the spatial size: 12x28x28 -> 12x14x14
                    ('pool', nn.MaxPool2d(2, stride=2, padding=0)),
                    ('flatten', nn.Flatten()),
                    # 12 * 14 * 14 = 2352 input features
                    ('dense', nn.Linear(2352, num_classes)),
                    # Log-probabilities over the class dimension, as F.nll_loss expects
                    ('softmax', nn.LogSoftmax(dim=1)),
                ]
            )
        )

.. GENERATED FROM PYTHON SOURCE LINES 48-51

Use the `MNIST dataset provided by PyTorch `_ for training.
Apply a very simple transformation to the input images to normalize them.

.. GENERATED FROM PYTHON SOURCE LINES 51-68

.. code-block:: default


    import os

    from torchvision import datasets, transforms


    def mnist_dataset(data_dir="~/.mnist_pruning_data"):
        # Convert images to tensors and normalize them with the MNIST mean and std
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        data_path = os.path.expanduser(f"{data_dir}/mnist")
        os.makedirs(data_path, exist_ok=True)
        # download=True fetches the MNIST files if they are not already present
        train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
        test = datasets.MNIST(data_path, train=False, transform=transform)
        return train, test

.. GENERATED FROM PYTHON SOURCE LINES 69-70

Next, initialize the model and the data loaders.

.. GENERATED FROM PYTHON SOURCE LINES 70-79

.. code-block:: default


    model = mnist_net()

    batch_size = 128
    train_dataset, test_dataset = mnist_dataset()
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

.. GENERATED FROM PYTHON SOURCE LINES 80-83

Training the Model Without Pruning
----------------------------------
Train the model without any pruning applied, to establish a baseline accuracy.

.. GENERATED FROM PYTHON SOURCE LINES 83-135

.. code-block:: default


    optimizer = torch.optim.Adam(model.parameters(), eps=1e-07)
    accuracy_unpruned = 0.0
    num_epochs = 4


    def train_step(model, optimizer, train_loader, data, target, batch_idx, epoch):
        # Standard supervised step: forward pass, NLL loss, backward pass, update
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


    def eval_model(model, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                # Sum per-sample losses so they can be averaged over the whole test set
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # The predicted class is the index of the highest log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)

        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {:.1f}%\n".format(
                test_loss, accuracy
            )
        )
        return accuracy


    for epoch in range(num_epochs):
        # train one epoch
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)

        # evaluate
        accuracy_unpruned = eval_model(model, test_loader)


    print("Accuracy of unpruned network: {:.1f}%\n".format(accuracy_unpruned))

.. GENERATED FROM PYTHON SOURCE LINES 136-148

Installing the Pruner in the Model
----------------------------------
Install :py:class:`~.pruning.MagnitudePruner` in the trained model.

First, construct a :py:class:`~.pruning.pruning_scheduler.PruningScheduler`, which
specifies how the sparsity of the pruned layers evolves over the course of training.
For this tutorial, use a :py:class:`~.pruning.PolynomialDecayScheduler`, which follows
the gradual pruning schedule introduced in the paper `"To prune or not to prune" `_.

Begin pruning at step ``0`` and update the sparsity every ``100`` steps for two epochs:
``list(range(0, 900, 100))`` yields the nine update steps ``0, 100, ..., 800``, which fall
within the two fine-tuning epochs (one epoch over the 60,000 MNIST training images is 469
batches of size 128, so two epochs are about 938 steps). As the pruner steps through this
schedule, the sparsity of the pruned modules increases gradually from the initial value to
the target value.

.. GENERATED FROM PYTHON SOURCE LINES 148-153

.. code-block:: default


    from coremltools.optimize.torch.pruning import PolynomialDecayScheduler

    # Sparsity is updated at steps 0, 100, ..., 800
    scheduler = PolynomialDecayScheduler(update_steps=list(range(0, 900, 100)))

.. GENERATED FROM PYTHON SOURCE LINES 154-161

Next, create an instance of the :py:class:`~.pruning.MagnitudePrunerConfig` class to
specify how you want different submodules to be pruned, and pass the scheduler to each
module-level config. Set the target sparsity of the convolution layer to ``70%`` and that
of the dense layer to ``80%``. This demonstrates that different layers can be assigned
different sparsity levels. In practice, the sparsity level of a layer is a hyperparameter
that you need to tune for your requirements and for how well the layer tolerates
sparsification.

.. GENERATED FROM PYTHON SOURCE LINES 161-177

.. code-block:: default


    from coremltools.optimize.torch.pruning import (
        MagnitudePruner,
        MagnitudePrunerConfig,
        ModuleMagnitudePrunerConfig,
    )

    # Pass the polynomial decay scheduler created above; without it, each module
    # config falls back to its default schedule.
    conv_config = ModuleMagnitudePrunerConfig(target_sparsity=0.7, scheduler=scheduler)
    linear_config = ModuleMagnitudePrunerConfig(target_sparsity=0.8, scheduler=scheduler)

    # Apply each module-level config to every module of the given type
    config = MagnitudePrunerConfig().set_module_type(torch.nn.Conv2d, conv_config)
    config = config.set_module_type(torch.nn.Linear, linear_config)

    pruner = MagnitudePruner(model, config)

.. GENERATED FROM PYTHON SOURCE LINES 178-184

Next, call :py:meth:`~.pruning.MagnitudePruner.prepare` to insert pruning forward
pre-hooks on the modules configured previously. A forward pre-hook runs before each call
to the module's ``forward`` method. These hooks multiply the parameter by a pruning mask:
a tensor with the same shape as the parameter, in which each element is either ``1``
or ``0``.

.. GENERATED FROM PYTHON SOURCE LINES 184-187

.. code-block:: default


    pruner.prepare(inplace=True)

.. GENERATED FROM PYTHON SOURCE LINES 188-193

Fine-Tuning the Pruned Model
----------------------------
Next, fine-tune the model with pruning applied. To advance through the pruning schedule,
call the :py:meth:`~.pruning.MagnitudePruner.step` method on the pruner after every call
to ``optimizer.step()``.

.. GENERATED FROM PYTHON SOURCE LINES 193-208

.. code-block:: default


    optimizer = torch.optim.Adam(model.parameters(), eps=1e-07)
    accuracy_pruned = 0.0
    num_epochs = 2

    for epoch in range(num_epochs):
        # train one epoch
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)
            # Advance the pruning schedule once per optimizer step
            pruner.step()

        # evaluate
        accuracy_pruned = eval_model(model, test_loader)

.. GENERATED FROM PYTHON SOURCE LINES 209-214

The evaluation shows that you can train a pruned network without a significant loss of
accuracy: the check below asserts that the pruned and unpruned accuracies differ by at
most two percentage points. In practice, for more complex models, there is a trade-off
between sparsity and the validation accuracy the model can achieve. Finding the right
sweet spot on this trade-off curve depends on the model and task.

.. GENERATED FROM PYTHON SOURCE LINES 214-220

.. code-block:: default


    print("Accuracy of pruned network: {:.1f}%\n".format(accuracy_pruned))
    print("Accuracy of unpruned network: {:.1f}%\n".format(accuracy_unpruned))

    # Fail if the two accuracies differ by more than 2 percentage points
    np.testing.assert_allclose(accuracy_pruned, accuracy_unpruned, atol=2)

.. GENERATED FROM PYTHON SOURCE LINES 221-231

Finalizing the Model for Export
-------------------------------

The example shows that you can prune the model with a few changes to your existing
PyTorch training code. Now you can deploy this model on a device.

To finalize the model for export, call :py:meth:`~.pruning.MagnitudePruner.finalize`
on the pruner. This removes all the forward pre-hooks attached to the submodules.
It also freezes the state of the pruner and multiplies the pruning mask with the
corresponding weight matrix.

.. GENERATED FROM PYTHON SOURCE LINES 231-235

.. code-block:: default


    model.eval()
    pruner.finalize(inplace=True)

.. GENERATED FROM PYTHON SOURCE LINES 236-246

Exporting the Model for On-Device Execution
-------------------------------------------

To deploy the model, convert it to a Core ML model.
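
Before converting, you may want to confirm that finalization actually left zeros in the
weights. The following is a minimal sketch, not part of the coremltools API: the
hypothetical helper ``weight_sparsity`` counts the exactly-zero entries in every entry of
a name-to-array mapping whose name ends with ``.weight``, such as the mapping returned by
``model.state_dict()``. It assumes the tensors live on the CPU, so that NumPy can convert
them with ``np.asarray``.

.. code-block:: python

    import numpy as np


    def weight_sparsity(state_dict, suffix=".weight"):
        """Hypothetical helper: fraction of exactly-zero entries per weight tensor.

        Accepts any mapping of parameter name -> array-like (for example, a PyTorch
        ``state_dict()`` holding CPU tensors). Returns a dict of per-tensor sparsity
        and the overall sparsity across all matching tensors.
        """
        per_tensor = {}
        total_zeros = 0
        total_elems = 0
        for name, value in state_dict.items():
            # Only inspect weight tensors; biases and other entries are skipped
            if not name.endswith(suffix):
                continue
            arr = np.asarray(value)
            zeros = int(np.count_nonzero(arr == 0))
            per_tensor[name] = zeros / arr.size
            total_zeros += zeros
            total_elems += arr.size
        overall = total_zeros / total_elems if total_elems else 0.0
        return per_tensor, overall

For this tutorial, calling ``weight_sparsity(model.state_dict())`` after
``pruner.finalize`` should report fractions close to the ``0.7`` and ``0.8`` targets for
``conv.weight`` and ``dense.weight``, assuming the finalized weights hold the masked values
as described above. The exact fractions depend on how the pruner rounds the number of
elements to prune.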

Follow the same steps in Core ML Tools that you would use to export a regular PyTorch
model (for details, see `Converting from PyTorch `_). Passing
``pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING`` signals to the converter that the model
being converted is pruned, and allows the model weights to be represented as sparse
matrices, which have a smaller memory footprint than dense matrices.

.. GENERATED FROM PYTHON SOURCE LINES 246-260

.. code-block:: default


    import coremltools as ct

    # Trace the finalized model with a random input of the MNIST image shape
    example_input = torch.rand(1, 1, 28, 28)
    traced_model = torch.jit.trace(model, example_input)

    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(shape=example_input.shape)],
        # Tell the converter that the weights are sparse so they can be stored compactly
        pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING,
        minimum_deployment_target=ct.target.iOS16,
    )

    coreml_model.save("~/.mnist_pruning_data/pruned_model.mlpackage")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download__examples_magnitude_pruning.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: magnitude_pruning.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: magnitude_pruning.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_