|
作者丨盘子正@知乎(已授权)
写在前面: 本文介绍我们组在CVPR 2023的工作:[Stitchable Neural Networks],下文简称SN-Net。一种全新的模型部署方法,利用现有的model family直接做少量epoch finetune就可以得到大量插值般存在的子网络,运行时任意切换网络结构满足不同resource constraint。
Paper: https://arxiv.org/abs/2302.06586
Code: https://github.com/ziplab/SN-Net
背景
去年一次组会上,在和导师们讨论未来的research方向的时候,偶然聊到一个问题:
视频网站的视频播放会自动根据网络带宽调整画质,如网速好的时候到4K,网速差就720P甚至更低。那同一个神经网络能不能随时根据计算资源的变化调整推理速度? 从2012的AlexNet到2023年火出圈的ChatGPT, AI/ML这一社区在十年间少说已经训练了上百万个模型。截至这篇文章写作时,HuggingFace上可以直接下载的模型就有14万个,涵盖各个模态和任务。每个模型各司其职,用自己在训练中学到的知识去处理某一种场景,互不叨扰。

模型虽然越来越多,但是资源浪费也越来越严重。训练一个模型的成本很高,尤其是大模型训练,耗费数个节点和几天的算力才能得到一个好权重,但最后却受限于应用场景只能重新调整结构,然后再重新训练,如网络backbone设计中通常会有不同scale来满足不同的推理速度要求: ResNet-18/50/101,DeiT-Ti/S/B,Swin-Ti/S/B等等。
传统方法当然能加速模型推理,如pruning,distillation,quantization。但问题是这些方法一次大都只能针对一个模型,一个资源场景。我们也可以用NAS搜出来若干个子网络来满足不同推理速度需求,即使如此,NAS中训练一个Supernet的成本也是巨大的,典型的如OFA (https://arxiv.org/abs/1908.09791)和BigNAS (https://arxiv.org/abs/2003.11142),花费上千GPU hours才得到一个好网络,资源消耗巨大。
看着huggingface上这么大的model zoo,我们不禁想,整个社区花了大量时间,金钱和人力资源去训练网络,得到了这么多的 pretrained model,但是能不能有效利用起来? 况且这些模型已经训练好了,当需要他们的时候,能不能用少量计算资源就可以满足目标场景?
对这一问题的思考也是随着模型被工业界越推越大引出的。几年前一张1080就能跑完的实验,现在8张卡都很难train得动一个model,特别是Transformer出来之后。最新的ViT已经scale到22B,BAAI的 EVA (https://arxiv.org/abs/2211.07636)也把ViT扩展到了1B的参数级别。留给小组的空间越来越小,在资源有限(缺卡)的场景下,我们需要寻求新的突破方向。
Stitchable Neural Networks
Industry和Academia所关注的问题可以有些区别。既然大模型不是所有人能做得起的,那我们不如去利用好已有的pretrained model。现在我们有了一组训练好的model family,如DeiT-Tiny/Small/Base。不同模型有不同大小,推理速度,显存占用。那么能不能利用这些已有的weights和结构快速得到一批新网络来满足不同的资源场景?
我们在CVPR 2023最新的工作Stitchable Neural Network (SN-Net) 给出了一个非常具有潜力的方案。

SN-Net的主要思想是:在一组已经训练好的model family中插入若干个stitching layer (即1x1 conv), 使得forward时activation可以在模型间的不同位置游走。当模型在不同位置缝合的时候,一个个新网络结构就出来了!!!
此时,我们把原先model family中的网络叫做anchors,缝合出来的新网络叫做stitches。单个SN-Net可以cover众多FLOPs-accuracy的trade-off,如在基于Swin的实验中,一个SN-Net的可以挑战timm中200个独立的模型,整个实验不过是50 epochs,八张V100上训练不到一天。

下面会介绍详细的做法,以及我们当时方法设计时候的考虑。想直接看效果的朋友可以移步最后的结果展示。
1. 模型这么多,怎么去选择
这里主要考虑了几个地方:
- 不同模型结构在网络中各层学习到的representation会有较大差别,缝合出来的网络不一定保证较好的performance。
- 不同数据集学到的东西差别也很大,为了保证性能最好保持在相同pretrained的dataset下。
- 不同网络的实现和训练方式有差别,工程上很难权衡超参和data augmentation的选择。而同一个结构通常在一个repo里,更容易实现。
因此,我们初步关注在相同dataset上训练好的model family上, 即结构相似,但是模型scale不一样,如DeiT-Ti/S/B。
不同family能不能缝合?也能,我们paper里有展示结果,但是工程上会比较麻烦,需要combine不同repo并且权衡超参。
2. 怎么去做缝合?
model stitching在原先工作中大都是以研究representation similarity的形式呈现的,如
- Lenc, Karel, and Andrea Vedaldi. "Understanding image representations by measuring their equivariance and equivalence." CVPR 2015.
- Kornblith, Simon, et al. "Similarity of neural network representations revisited." ICML, 2019.
- Csiszárik, Adrián, et al. "Similarity and matching of neural network representations." NeurIPS 2021.
总结过去这些工作:同一个网络,用不同seed训练之后可以在某些位置缝合起来,此时性能不会掉的很离谱。后续的研究发现结构不一样的网络甚至也能缝合。
而stitching能够work在于,假设前一个网络出来的feature map属于activation 空间A,而另一个网络在此位置的输入feature map属于activation空间B,那么stitching layer做的事情就是把feature map从A空间映射到B空间,使得此时的feature map能模拟下一网络在这个位置的输入。
当网络是已经是pretrained,那么stitching这一过程完全可以formulate成一个求解least squares的问题。也就是说stitching layer这个weights的matrix是可以直接求出来的 (参考 Csiszárik, Adrián, et al (https://arxiv.org/abs/2110.14633) 这篇)。所以此时求解出来的matrix可以天然作为stitching layer的初始化。
3. 缝合方向的设定

现在我们有一个大模型:性能好但是推理速度慢,还有一个小模型:性能差点但是推理速度快。我们怎么决定谁stitch到谁呢?我们主要考虑了两个方面:
- 参考当前backbone设计的惯例,随着网络不断深入,channel dimension是在不断增大的。Fast-to-Slow这方向比较符合常见的网络设计。
- 实验验证Fast-to-Slow得到的curve要比Slow-to-Fast要smooth一点,详见论文。
所以目前SN-Net在方向上是从小模型缝合到大模型。同时我们提出一个constraint: nearest stitching,限制stitching只在复杂度(FLOPs)相邻的两个anchor之间。如补充材料中的Figure 10所示,以DeiT-Ti/S/B为例,我们的方法目前限制在(a), (b)两个case。

这个限制是因为我们发现anchor的gap比较大的时候,缝合出来的网络并不在一个optimal的区间。实验部分也证明直接stitch DeiT-Ti和DeiT-B效果不如中间加一个DeiT-S。
4. 怎么配置Stitching Layer

网络设计地千奇百怪,怎么去缝合是个问题。
我们以DeiT为例,在相同depth的缝合实验上采取了Paired Stitching这种策略。这种策略的启发来自于过去一些工作发现:相邻layer之间的representation是有较高的相似度的。所以我们选择在DeiT得相邻blocks中share同一个stitching layer,如滑窗一般进行stitching。
share的情况下,原先的初始化方法就是简单地对不同solution得到的matrix做一个average。选择share stitching layer还有其他好处,如减少过多stitching layer带来的参数量,同时扩大缝合出来的结构数量,即扩大stitching space。
另外一种情况是两个模型的depth不一样,小模型一般比较浅,block的数量要比大模型少。比如Swin-Ti的第三个stage只有6个block,而Swin-S在第三个stage有18个block。此时我们进行Unpaired Stitching,每个小模型的block都stitch到大模型的若干个block中。这样两个case就都解决了
5. SN-Net能缝出来多少网络?
这个由多种因素决定。
- 看选择的model family,即anchors的depth。显然anchor越深,那么能stitch的位置就越多,新网络结构也会更多。
- 相同depth下看stitching时sliding window的设置。
- 不加nearest stitching的时候得到的网络更多 (DeiT上的实验是十倍的差距,71 vs. 731)。但是此时不optimal。后续潜力尚待挖掘。
对比NAS中 10^{20}10^{20} 级别的search space, SN-Net在基于同一组model family得到的网络数量是有限的。但有一点不得不提,纵使search space再大,真正需要的时候也只是用pareto frontier上的网络结构,而SN-Net缝合出来的网络几乎天然落在pareto frontier上,同时部署的时候完全可以直接查表,几乎没有什么search cost。
另外一点是,SN-Net的潜力在于整个pretrained model zoo。有多少model familiy,就有多少潜在的SN-Net变种。这是NAS的单一supernet所不能比拟的。这意味着我们可以轻易缝合已有的model family达到NAS耗费大量计算资源搜出来的网络性能,比如简单缝合两个LeViT (https://arxiv.org/abs/2104.01136)就可以用更低的FLOPs(977M vs. 1040M) 达到媲美 BigNASModel-XL (https://arxiv.org/abs/2003.11142)的性能(80.7% vs. 80.9%),如下图所示

6. 简单的训练策略
训练SN-Net尤为简单。先提前把所有需要训练的stitches定义好,训练中每次iteration都随机sample出来一个stitch,后面和正常的训练一样进行loss回传,梯度下降。为了进一步提升stitches的性能,我们初步实验同时采用了RegNetY-160作为teacher model去做distillation。

结果展示
为了验证Joint Training和原有网络从头train的差距,我们选择了若干个和stitches相同的网络结构,然后在ImageNet上训满300 epochs。从下表可以看到,对比用了大量计算资源训练出来的网络,SN-Net利用已有的DeiT family只用50个epoch就可以得到比肩甚至更好的性能。同时整个网络只要118.4M的参数,而这71个stitches的总量如果单独训练需要2630M,耗费 71 × 300 epochs,和SN-Net比是22倍的差距。

基于DeiT和Swin Transformer, 我们验证了缝合plain ViT和hierarchical ViT的可行性。性能曲线如在anchors中进行插值一般。

值得一提的是,图中不同点所表示的子网络,即stitch,是可以在运行时随时切换的。这意味着网络在runtime完全可以依靠查表进行瞬时推理速度调整。这个是诸多网络无法实现的,但颇具现实意义。比如现在很多手机都有省电模式,一旦进行power saving, 手机掉帧,系统运行速度变慢,而此时neural network也可以调整推理速度,做一个speed-accuracy的trade-off。
我们当然也尝试了stitch cnn,甚至不同的family,结果非常promising。

更多实验内容和分析请移步我们的arxiv论文:Stitchable Neural Networks
SN-Net的可扩展空间
SN-Net生于large model zoo的时代。我们初版方法给出了一个最简单的baseline,相信未来有很大的扩展空间,比如
- 当前的训练策略比较简单,每次iteration sample出来一个stitch,但是当stitches特别多的时候,可能导致某些stitch训练的不够充分,除非增加训练时间。所以训练策略上可以继续改进。
- anchor的performance会比之前下降一些,虽然不大。直觉上,在joint training过程中,anchor为了保证众多stitches的性能在自身weights上做了一些trade-off。目前补充材料里发现finetune更多epoch可以把这部分损失补回来。
- 不用nearest stitching可以明显扩大space,但此时大部分网络不在pareto frontier上,未来可以结合训练策略进行改进,或者在其他地方发现advantage。
- 未来能否有个更好方法和统一的框架去缝合任意网络。到那时,整个model zoo就像积木一样,可操作空间更大,玩法更多,这一点NUS的Xingyi Yang (https://adamdad.github.io/)之前有尝试,参考Deep Model Reassembly (https://arxiv.org/abs/2210.17409).
更多探索就留给future work了。代码已经开源至https://github.com/ziplab/SN-Net,硬件要求十分友好,50个epoch (用8卡V100大约半天时间) 就可以复现结果。欢迎有兴趣的同学进行尝试!
个人主页:https://zizhengpan.github.io/
实验室主页:https://ziplab.github.io/
文中如有错误,欢迎指出,同时欢迎各位进行学术交流~
往期精选
数据集汇总:
- 人脸识别常用数据集大全
- 行人检测数据集汇总
- 10个开源工业检测数据集汇总
- 21个深度学习开源数据集分类汇总(持续更新)
- 小目标检测、图像分类、图像识别等开源数据集汇总
- 人体姿态估计相关开源数据集介绍及汇总
- 小目标检测相关开源数据集介绍及汇总
- 医学图像开源数据集汇总
- 自动驾驶方向开源数据集资源汇总
- 目标检测开源数据集汇总(二)
- RGB-T 相关开源数据集资源汇总
- 图像去雾开源数据集资源汇总
- 图像分类方向优质开源数据集汇总(附下载链接)
顶会资源:
- ECCV22 最新54篇论文分方向整理|包含目标检测、图像分割、监督学习等(附下载)
- CVPR 2022 全面盘点:最新250篇论文分方向汇总 / 代码 / 解读 / 直播 / 项目(更新中)
- CVPR'22 最新106篇论文分方向整理|包含目标检测、动作识别、图像处理等32个方向
- 一文看尽 CVPR2022 最新 22 篇论文(附打包下载)
- 17 篇 CVPR 2022 论文速递|涵盖 3D 目标检测、医学影像、车道线检测等方向CVPR 2021 结果出炉!最新500篇CVPR'21论文分方向汇总(更新中)
- CVPR 2021 结果出炉!最新600篇CVPR'21论文分方向汇总(更新中)
- CVPR 2020 Oral 汇总:论文/代码/解读(更新中)
- CVPR 2019 最全整理:全部论文下载,Github源码汇总、直播视频、论文解读等
- CVPR 2018 论文解读集锦(9月27日更新)
- CVPR 2018 目标检测(object detection)算法总览
- ECCV 2018 目标检测(object detection)算法总览(部分含代码)
- CVPR 2017 论文解读集锦(12-13更新)
- 2000 ~2020 年历届 CVPR 最佳论文汇总
技术综述:
- 万字长文 | 手把手教你优化轻量姿态估计模型(算法篇)
- 工业应用中如何选取合适的损失函数(MAE、MSE、Huber)-Pytorch版
- 综述:图像处理中的注意力机制
- 搞懂Transformer结构,看这篇PyTorch实现就够了
- 熬了一晚上,我从零实现了Transformer模型,把代码讲给你听
- YOLO算法最全综述:从YOLOv1到YOLOv5
- 图像匹配大领域综述!涵盖 8 个子领域,近 20年经典方法汇总
- 一文读懂深度学习中的各种卷积
- 万字综述|核心开发者全面解读Pytorch内部机制
- 19个损失函数汇总,以Pytorch为例
- 一文看尽深度学习中的15种损失函数
- 14种异常检测方法总结
- PyTorch常用代码段合集
- 神经网络压缩综述
理论深挖:
- 深入探讨:为什么要做特征归一化/标准化?
- 令人“细思极恐”的Faster-R-CNN
论文盘点:
- 图像分割二十年,盘点影响力最大的10篇论文
- 2020年54篇最新CV领域综述论文速递!涵盖14个方向:目标检测/图像分割/医学影像/人脸识别等
实践/面经/求学:
- 如何配置一台深度学习工作站?
- 国内外优秀的计算机视觉团队汇总
- 50种Matplotlib科研论文绘图合集,含代码实现
- 图像处理知多少?准大厂算法工程师30+场秋招后总结的面经问题详解
- 深度学习三十问!一位算法工程师经历30+场CV面试后总结的常见问题合集(含答案)
- 深度学习六十问!一位算法工程师经历30+场CV面试后总结的常见问题合集下篇(含答案)
- 一位算法工程师从30+场秋招面试中总结出的目标检测算法面经(含答案)
- 一位算法工程师从30+场秋招面试中总结出的语义分割超强面经(含答案)
Github优质资源:
- 25个【Awsome】GitHub 资源汇总(更新中)
- 超强合集:OCR 文本检测干货汇总(含论文、源码、demo 等资源)
- 2019-2020年目标跟踪资源全汇总(论文、模型代码、优秀实验室)
|
|