节省显存的重计算技巧也有了Keras版了
别忘了对比不同版本的修订内容,有时旧版也有其价值。 #生活知识# #购物技巧# #图书选购建议#
29 Apr
By 苏剑林 | 2020-04-29 | 49369位读者 |
不少读者最近可能留意到了公众号文章《BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)》,里边介绍了一个叫做“重计算”的技巧,简单来说就是用来省显存的方法,让平均训练速度慢一点,但batch_size可以增大好几倍。该技巧首先发布于论文《Training Deep Nets with Sublinear Memory Cost》,其实在2016年就已经提出了,只不过似乎还没有特别流行起来。
探索 #
公众号文章提到该技巧在pytorch和paddlepaddle都有原生实现了,但tensorflow还没有。但事实上从tensorflow 1.8开始,tensorflow就已经自带了该功能了,当时被列入了tf.contrib这个子库中,而从tensorflow 1.15开始,它就被内置为tensorflow的主函数之一,那就是tf.recompute_grad。
找到tf.recompute_grad之后,笔者就琢磨了一下它的用法,经过一番折腾,最终居然真的成功地用起来了,居然成功地让batch_size从48增加到了144!然而,在继续整理测试的过程中,发现这玩意居然在tensorflow 2.x是失效的...于是再折腾了两天,查找了各种资料并反复调试,最终算是成功地补充了这一缺陷。
最后是笔者自己的开源实现:
该实现已经内置在bert4keras中,使用bert4keras的读者可以升级到最新版本(0.7.5+)来测试该功能。
使用 #
笔者的实现也命名为recompute_grad,它是一个装饰器,用于自定义Keras层的call函数,比如
from recompute import recompute_grad class MyLayer(Layer): @recompute_grad def call(self, inputs): return inputs * 2
对于已经存在的层,可以通过继承的方式来装饰:
from recompute import recompute_grad class MyDense(Dense): @recompute_grad def call(self, inputs): return super(MyDense, self).call(inputs)
自定义好层之后,在代码中嵌入自定义层,然后在执行代码之前,加入环境变量RECOMPUTE=1来启用重计算。
注意:不是在总模型里插入了@recompute_grad,就能达到省内存的目的,而是要在每个层都插入@recompute_grad才能更好地省显存。简单来说,就是插入的@recompute_grad越多,就省显存。具体原因请仔细理解重计算的原理。
效果 #
bert4keras 0.7.5+已经内置了重计算,直接传入环境变量RECOMPUTE=1就会启用重计算,读者可以自行尝试,大概的效果是:
1、在BERT Base版本下,batch_size可以增大为原来的3倍左右;
2、在BERT Large版本下,batch_size可以增大为原来的4倍左右;
3、平均每个样本的训练时间大约增加25%;
4、理论上,层数越多,batch_size可以增大的倍数越大。
环境 #
在下面的环境下测试通过:
tensorflow 1.14 + keras 2.3.1
tensorflow 1.15 + keras 2.3.1
tensorflow 2.0 + keras 2.3.1
tensorflow 2.1 + keras 2.3.1
tensorflow 2.0 + 自带tf.keras
tensorflow 2.1 + 自带tf.keras
确认不支持的环境:
tensorflow 1.x + 自带tf.keras欢迎报告更多的测试结果。
顺便说一下,强烈建议用keras 2.3.1配合tensorflow 1.x/2.x来跑,强烈不建议使用tensorflow 2.x自带的tf.keras来跑。
参考 #
最后,笔者的实现主要参考自如下两个源码,在此表示感谢。
https://github.com/davisyoshida/tf2-gradient-checkpointing
https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/ops/custom_gradient.py#L454-L499
转载到请包括本文地址:https://spaces.ac.cn/archives/7367
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Apr. 29, 2020). 《节省显存的重计算技巧也有了Keras版了 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/7367
@online{kexuefm-7367,
title={节省显存的重计算技巧也有了Keras版了},
author={苏剑林},
year={2020},
month={Apr},
url={\url{https://spaces.ac.cn/archives/7367}},
}
网址:节省显存的重计算技巧也有了Keras版了 https://www.yuejiaxmz.com/news/view/460893
相关内容
把显存用在刀刃上!17 种 pytorch 节约显存技巧PyTorch 节省显存的策略总结
了解:20个节省时间的技巧,以提高设计师的工作流程
4个小技巧帮你节省装修预算
keras遇到的问题
存款技巧:如何在日常生活中节省开支并存钱
计算机技巧(计算机技巧分享)
钢筋重量计算器:精准计算,省时省力(钢筋 重量计算器)
有哪些存钱或省钱的小技巧?
节省显存新思路,在 PyTorch 里使用 2 bit 激活压缩训练神经网络