产生的原因与对应解决方案

根本原因:传播链路太深 + 每层出点小问题

梯度消失与梯度爆炸其实可以看成是同一个问题,其根本上是两个因素造成的。

因素一:传播链路太深

如果传播链路太深,且梯度是叠加前向传播的,容易使得梯度的小问题积累成大问题。例如:

  • 网络太深,例如 ResNet 之前的 网络。
  • RNN 传播链路太长。

解决这一因素的关键是 切断积累,例如:

  • 梯度裁切(Gradient Clipping),限制梯度的最大值。
  • 截断反向传播,最大不超过若干个 time step。
  • 使用 LSTMGRU 等模型,从结构上缓解 RNN 这一问题(状态的延续加入了累加的形式)。

因素二:每一层出一点点问题

每一层出一点点问题,累积下来都会导致最终出大问题,例如:

  • 使用了 Sigmoid 激活函数,其导数最大值只有 0.25,则容易梯度消失。
  • 网络权值 初始化 太大 / 太小,则容易梯度爆炸 / 消失。

解决这一因素的关键是 解决小问题,例如:

  • Sigmoid 换成的激活函数如 ReLULReLU 等。
  • 正确的初始化,均值为 1。

其他解决方案

针对以上问题,很多都可以针对性解决。当然,也存在一些相对通用的方法,例如:

  • 预训练 + 微调:相当于站在巨人的肩膀上,会一开始就比较稳。
  • Regularization:如果发生梯度爆炸,那么权值也会变得非常大,因此通过 Regularization 可以一定程度上限制梯度爆炸的发生。
  • Batch Normalization:强行把每一层的输出分布拉回到一个比较正常、稳定的正态分布(先 0 均值 1 方差再经过一定的 transformation)。
  • Shortcuts:残差结构,可以相隔多层进行连接,增强梯度的流动,防止梯度消失。

参考