# Torch-Pruning
**Repository Path**: PolarisF/Torch-Pruning
## Basic Information
- **Project Name**: Torch-Pruning
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: MIT
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 0
- **Created**: 2023-12-14
- **Last Updated**: 2023-12-14
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
Towards Any Structural Pruning
Torch-Pruning (TP) is a general-purpose framework for structural network pruning. Its main features are:
* **A general structural pruning toolkit:** supports *[LLaMA](https://github.com/horseee/LLaMA-Pruning), [Vision Transformers](benchmarks/prunability), [Yolov7](benchmarks/prunability/readme.md#3-yolo-v7), [yolov8](benchmarks/prunability/readme.md#2-yolo-v8), FasterRCNN, SSD, KeypointRCNN, MaskRCNN, ResNe(X)t, ConvNext, DenseNet, RegNet, FCN, DeepLab*, and more. Unlike [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html), which "simulates" pruning via masking, Torch-Pruning uses a (non-deep) graph algorithm named DepGraph to physically remove coupled parameters and channels from the model.
* **Reproducible [performance benchmarks](benchmarks) and [prunability benchmarks](benchmarks/prunability):** Torch-Pruning currently covers **81/85 = 95.3%** of the Torchvision pre-trained models (v0.13.1). You can try the [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for a quick hands-on with pruning Torchvision pre-trained models.
For more technical details, please refer to our paper:
> [**DepGraph: Towards Any Structural Pruning**](https://arxiv.org/abs/2301.12900)
> [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Mingli Song](https://person.zju.edu.cn/en/msong), [Michael Bi Mi](https://dblp.org/pid/317/0937.html), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
### Update:
* 2023.04.15 [Pruning and Post-training for YOLOv7 / YOLOv8](benchmarks/prunability)
* 2023.04.10 [Structural Pruning for LLaMA (pruning-only)](https://github.com/horseee/LLaMA-Pruning)
* 2023.04.21 Join our discussion groups (WeChat or Telegram):
  * Wechat:
  * Telegram: https://t.me/+NwjbBDN2ao1lZjZl
For any questions about the framework or the paper, please open a [discussion](https://github.com/VainF/Torch-Pruning/discussions) or an [issue](https://github.com/VainF/Torch-Pruning/issues). We are happy to answer your questions.
### **Features:**
- [x] Structural (channel) pruning for [CNNs](benchmarks/prunability/torchvision_pruning.py#L19), [Transformers](benchmarks/prunability/torchvision_pruning.py#L11), various detectors, and language models.
- [x] High-level pruners: [MagnitudePruner](https://arxiv.org/abs/1608.08710), [BNScalePruner](https://arxiv.org/abs/1708.06519), [GroupNormPruner](https://arxiv.org/abs/2301.12900) (the group-level pruner used in the paper), RandomPruner, etc.
- [x] DepGraph for computational graph construction and dependency modeling.
- [x] Supported basic modules: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding, MultiheadAttention, nn.Parameter, and [customized modules](tests/test_customized_layer.py).
- [x] Supported operations: split, concatenation, skip connection, flatten, reshape, view, all element-wise ops, etc.
- [x] [Low-level pruning functions](torch_pruning/pruner/function.py)
- [x] [Benchmarks](benchmarks) and [tutorials](tutorials)
- [x] A [resource list](practical_structural_pruning.md) for practical structural pruning.
### **Roadmap:**
- [ ] Prunability benchmarks covering common model hubs such as [Torchvision](https://pytorch.org/vision/stable/models.html) (**81/85=95.3%**, :heavy_check_mark:) and [timm](https://github.com/huggingface/pytorch-image-models).
- [ ] Pruning from Scratch / at Initialization.
- [ ] Pruning for language, speech, and generative models.
- [ ] More high-level pruners, e.g. [FisherPruner](https://arxiv.org/abs/2108.00708), [GrowingReg](https://arxiv.org/abs/2012.09243), etc.
- [ ] More standard layers: GroupNorm, InstanceNorm, Shuffle Layers, etc.
- [ ] More Transformer networks: Vision Transformers (:heavy_check_mark:), Swin Transformers, PoolFormers.
- [ ] Block/Layer/Depth Pruning
- [ ] Performance benchmarks on CIFAR, ImageNet, and COCO.
## Installation
Torch-Pruning is compatible with both PyTorch 1.x and 2.0. This project is mainly developed and tested with PyTorch >= 1.13.1.
```bash
pip install torch-pruning # v1.1.6
```
or
```bash
git clone https://github.com/VainF/Torch-Pruning.git # recommended
```
## Quickstart
This section gives quick examples of Torch-Pruning's main features. More details can be found in the [tutorials](./tutorials/).
### 0. How It Works
In complex network architectures, parameters may depend on each other, and such dependencies require the coupled parameters to be removed together to keep the resulting structure valid. This raises the problem of grouping coupled parameters, and our work provides an automated mechanism for this grouping. Concretely, Torch-Pruning runs your model with dummy inputs, traces the computational graph, and records the dependencies between layers. When you prune one layer, Torch-Pruning identifies all coupled layers and returns a ``tp.Group`` that contains this coupling information. In addition, if operations such as torch.split or torch.cat are present, all pruning indices are aligned automatically, as illustrated in the short sketch below.
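For instance, index alignment across a concatenation can be seen in the following minimal sketch (the two-branch model here is purely illustrative): pruning the output channels of one branch also prunes the matching input channels of the layer that consumes the concatenated tensor.
```python
import torch
import torch.nn as nn
import torch_pruning as tp

class TwoBranch(nn.Module):
    """A toy model whose head consumes torch.cat([branch_a, branch_b], dim=1)."""
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Conv2d(3, 8, 3, padding=1)
        self.branch_b = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Conv2d(16, 4, 1)
    def forward(self, x):
        return self.head(torch.cat([self.branch_a(x), self.branch_b(x)], dim=1))

model = TwoBranch()
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 3, 32, 32))
group = DG.get_pruning_group(model.branch_a, tp.prune_conv_out_channels, idxs=[0, 1])
print(group)   # the group also contains a prune_in_channels on `head`, with aligned indices
group.prune()
print(model)   # branch_a now has 6 output channels and head has 14 input channels
```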
### 1. A minimal example
```python
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. Build the dependency graph
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))

# 2. Specify the channel indices to prune
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs)
print(pruning_group.details())  # or print(pruning_group)

# 3. Make sure the group does not remove all remaining channels, then prune
if DG.check_pruning_group(pruning_group):
    pruning_group.prune()
```
This example demonstrates the basic pruning workflow with DepGraph. Note that model.conv1 is actually coupled with several other layers. By printing the returned group, we can see how pruning one layer in the group "triggers" pruning in the others. In the output below, "A => B" means that pruning operation "A" triggers pruning operation "B". group[0] is the pruning operation given by the user in DG.get_pruning_group.
```
--------------------------------
Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------
```
For more details, please refer to [tutorials/2 - Exploring Dependency Groups](https://github.com/VainF/Torch-Pruning/blob/master/tutorials/2%20-%20Exploring%20Dependency%20Groups.ipynb).
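The items inside a group can also be inspected programmatically. Below is a small sketch, assuming the ``pruning_group`` obtained above: each item pairs a dependency with its pruning indices, and the affected module and pruning function can be read from that dependency.
```python
# Minimal sketch, assuming the `pruning_group` from the example above.
for dep, idxs in pruning_group:
    layer = dep.target.module  # the module affected by this dependency
    pruning_fn = dep.handler   # e.g. prune_out_channels / prune_in_channels
    print(layer, pruning_fn, idxs)
```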
#### How to scan all groups:
As implemented in [MetaPruner](https://github.com/VainF/Torch-Pruning/blob/b607ae3aa61b9dafe19d2c2364f7e4984983afbf/torch_pruning/pruner/algorithms/metapruner.py#L197), we can use ``DG.get_all_groups(ignored_layers, root_module_types)`` to scan all groups sequentially. Each group starts from a layer whose type is listed in ``root_module_types``. By default, these groups carry the full index list ``idxs=[0,1,2,3,...,K]``, covering all prunable parameters. To actually prune a group, pass the concrete channel/dimension indices via ``group.prune(idxs=idxs)``.
```python
# Note: this reuses the `DG` and `model` from the example above and assumes `import torch.nn as nn`.
for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2, 4, 6]  # your pruning indices
    group.prune(idxs=idxs)
    print(group)
```
### 2. High-level Pruners
With DependencyGraph, we provide several high-level pruners that let you prune a model in just a few lines. By specifying the desired channel sparsity, you can prune the whole model and then fine-tune it with your own training code. For details of this process, we recommend [Tutorial-1](https://github.com/VainF/Torch-Pruning/blob/master/tutorials/1%20-%20Customize%20Your%20Own%20Pruners.ipynb), which shows how to quickly implement the classic [slimming](https://arxiv.org/abs/1708.06519) algorithm with Torch-Pruning. More practical examples can be found in [benchmarks/main.py](benchmarks/main.py).
```python
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# Importance criterion (L2-norm magnitude)
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)  # DO NOT prune the final classifier!

iterative_steps = 5  # iterative pruning: remove 50% of the channels in 5 steps (10% -> 20% -> ... -> 50%)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5,  # remove 50% of the channels overall, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
```
#### Sparse Training
Some pruners, such as [BNScalePruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py#L45) and [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/group_norm_pruner.py#L53), rely on sparse training to identify redundant parameters. This can be enabled by adding a single line, ``pruner.regularize(model)``, to your training loop. It adds the sparsity-inducing gradients on top of the network's ordinary gradients, so any optimizer can be used.
```python
for epoch in range(epochs):
    model.train()
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward()
        pruner.regularize(model)  # <== for sparse learning
        optimizer.step()
```
#### Interactive Pruning
All high-level pruners support interactive pruning. Call ``pruner.step(interactive=True)`` to obtain all groups that are about to be pruned, and then call ``group.prune()`` on each of them as needed. This is useful for controlling or monitoring the whole pruning process.
```python
for i in range(iterative_steps):
    # Warning: groups must be handled in sequential order, since pruning changes
    # the model structure and therefore the channel indices.
    for group in pruner.step(interactive=True):
        print(group)
        # do whatever you like with the group
        # ...
        group.prune()  # group.prune() must be called manually here
        # group.prune(idxs=[0, 2, 6])  # you may even change the pruning indices manually
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
```
#### Group-level Pruning
With DepGraph, it is straightforward to design group-level importance criteria that estimate the importance of a whole group of coupled parameters, rather than of a single layer as in previous work. In our paper we build a simple group pruner, [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/745f6d6bafba7432474421a8c1e5ce3aad25a5ef/torch_pruning/pruner/algorithms/group_norm_pruner.py#L8) (illustrated in Figure (c) of the paper). It uses group-level sparse training so that coupled parameters learn consistent importance, ensuring that the removed parameters all have small importance scores.
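A minimal sketch of this workflow is shown below. It assumes that ``GroupNormPruner`` accepts the same constructor arguments as the ``MagnitudePruner`` above, reuses the ``model``, ``example_inputs``, ``imp`` and ``ignored_layers`` from Section 2, and uses a hypothetical ``train_one_epoch`` placeholder for your own training loop.
```python
# Sketch only: `train_one_epoch` is a placeholder for your training code and is
# expected to call pruner.regularize(model) after loss.backward(), as shown above.
pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs,
    importance=imp,              # reuse the importance criterion from Section 2
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

sparse_training_epochs = 10      # example value
# 1. Group-level sparse training: coupled parameters learn consistent importance.
for epoch in range(sparse_training_epochs):
    train_one_epoch(model, pruner)

# 2. Remove the coupled channels with the smallest group-level importance.
pruner.step()

# 3. Fine-tune the pruned model with your usual training code.
```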
### 3. Save and Load Models
#### The simplest way
The following code serializes the whole model object into a .pth file. This is simple, but the saved file tends to be large and inconvenient to share over the internet.
```python
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object
```
#### Pruning History
We provide a ``pruning_history`` mechanism for saving and loading pruned models, which works much like PyTorch's ``state_dict``. See [tests/test_load.py](https://github.com/VainF/Torch-Pruning/blob/master/tests/test_load.py) for an example.
```python
...
# Save
state_dict = {
    'model': model.state_dict(),          # the standard PyTorch way of saving parameters
    'pruning': pruner.pruning_history(),  # DependencyGraph offers the same DG.pruning_history & DG.load_pruning_history interface
}
torch.save(state_dict, 'pruned_model.pth')

# Load
model = resnet18()  # create an unpruned model
DG = tp.DependencyGraph().build_dependency(model, example_inputs)  # build a DependencyGraph (or a pruner)
state_dict = torch.load('pruned_model.pth')  # load the saved state
DG.load_pruning_history(state_dict['pruning'])  # replay the pruning history to reproduce the same pruning
model.load_state_dict(state_dict['model'])  # after re-pruning, the model can load the saved parameters
print(model)
```
### 4. Low-level Pruning Functions
Although you can prune a model manually with the low-level functions, this approach can be quite tedious because all related dependencies have to be handled by hand. We therefore recommend the high-level pruners above to simplify the pruning process.
```python
tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )
# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...
```
The following pruning functions are available:
```python
tp.prune_conv_out_channels,
tp.prune_conv_in_channels,
tp.prune_depthwise_conv_out_channels,
tp.prune_depthwise_conv_in_channels,
tp.prune_batchnorm_out_channels,
tp.prune_batchnorm_in_channels,
tp.prune_linear_out_channels,
tp.prune_linear_in_channels,
tp.prune_prelu_out_channels,
tp.prune_prelu_in_channels,
tp.prune_layernorm_out_channels,
tp.prune_layernorm_in_channels,
tp.prune_embedding_out_channels,
tp.prune_embedding_in_channels,
tp.prune_parameter_out_channels,
tp.prune_parameter_in_channels,
tp.prune_multihead_attention_out_channels,
tp.prune_multihead_attention_in_channels,
```
### 5. Customized Layers
Please refer to [tests/test_customized_layer.py](https://github.com/VainF/Torch-Pruning/blob/master/tests/test_customized_layer.py), which demonstrates how to prune user-defined (customized) layers.
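As a rough sketch of the idea (the class names below are illustrative, and the exact ``BasePruningFunc``/``register_customized_layer`` interface should be checked against the linked test), you implement a small pruner that tells DepGraph how to physically shrink your layer and register it before building the dependency graph:
```python
import torch
import torch.nn as nn
import torch_pruning as tp

class CustomizedLayer(nn.Module):
    """A toy layer with a per-channel parameter (illustrative only)."""
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.scale = nn.Parameter(torch.ones(in_dim))
    def forward(self, x):
        return x * self.scale

class CustomizedLayerPruner(tp.pruner.BasePruningFunc):
    """Tells DepGraph how to remove channels from CustomizedLayer (assumed interface)."""
    def prune_out_channels(self, layer, idxs):
        keep_idxs = [i for i in range(layer.in_dim) if i not in idxs]
        layer.in_dim = len(keep_idxs)
        layer.scale = nn.Parameter(layer.scale.data[keep_idxs])
        return layer
    prune_in_channels = prune_out_channels  # this layer has a single channel dimension
    def get_out_channels(self, layer):
        return layer.in_dim
    get_in_channels = get_out_channels

model = nn.Sequential(nn.Linear(10, 16), CustomizedLayer(16), nn.Linear(16, 5))
example_inputs = torch.randn(1, 10)

DG = tp.DependencyGraph()
DG.register_customized_layer(CustomizedLayer, CustomizedLayerPruner())  # register before build_dependency
DG.build_dependency(model, example_inputs=example_inputs)

group = DG.get_pruning_group(model[0], tp.prune_linear_out_channels, idxs=[0, 1])
group.prune()  # CustomizedLayer.scale is pruned together with the coupled Linear layers
print(model)
```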
### 6. Benchmarks
Our results on {ResNet-56 / CIFAR-10 / 2.00x}
| Method | Base (%) | Pruned (%) | $\Delta$ Acc (%) | Speed Up |
|:-- |:--: |:--: |:--: |:--: |
| NIPS [[1]](#1) | - | - |-0.03 | 1.76x |
| Geometric [[2]](#2) | 93.59 | 93.26 | -0.33 | 1.70x |
| Polar [[3]](#3) | 93.80 | 93.83 | +0.03 |1.88x |
| CP [[4]](#4) | 92.80 | 91.80 | -1.00 |2.00x |
| AMC [[5]](#5) | 92.80 | 91.90 | -0.90 |2.00x |
| HRank [[6]](#6) | 93.26 | 92.17 | -1.09 |2.00x |
| SFP [[7]](#7) | 93.59 | 93.36 | -0.23 |2.11x |
| ResRep [[8]](#8) | 93.71 | 93.71 | +0.00 |2.12x |
||
| Ours-L1 | 93.53 | 92.93 | -0.60 | 2.12x |
| Ours-BN | 93.53 | 93.29 | -0.24 | 2.12x |
| Ours-Group | 93.53 | 93.77 | +0.38 | 2.13x |
Please refer to [benchmarks](benchmarks) for more details.
## Citation
```
@article{fang2023depgraph,
title={DepGraph: Towards Any Structural Pruning},
author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
```