How to apply the batch-normalized net
继续填这篇文章的坑,如何测试和应用包含了Batch Normalization层的网络? 在训练过程中,每个BN层直接从输入样本中求取mean
和variance
量,不是通过学习获取的固定值。因此,在测试网络时,需要人工提供这两个值。
在BN的文章里的处理方法是,对所有参与训练的mini-batch的均值和方差进行收集,采用无偏估计的方式估计总体样本的均值和方差,来表征测试样本的均值和方差,其公式如下,
进而,BN layer的输出定义为,
那么有如下几个问题需要解决,
- 训练和测试过程中如何给BN传递
mean
和variance
?即如何在计算图上体现这一运算? - 如何动态收集每个mini-batch的mean和variance,用于总体样本的无偏估计moving_mean, moving_variance
针对以上问题,TensorFlow的解决思路是设定is_training
这个flag,如果为真,则每个mini-batch都会计算均值和方差,训练网络; 如果为假,则进入测试流程。
基于tf.nn.batch_normalization的底层实现
TF提供了tf.nn.batch_normalization
函数从底层搭建网络,其直接参考了Ioeff\&Szegdy的论文,这里需要利用tf.nn.moments
求取mini-batch的均值和方差,详细的实现代码参考这里.
基于tf.contrib.layers.batch_norm的实现
在tf.contrib.layers提供了batch_norm
方法,该方法是对tf.nn.batch_normalization
的封装,增加了如center
,is_training
等变量,并对BN的基础算法做了更新,用滑动平均来实现均值和房车的估计。
那么,如何实现包含BN层的网络的训练和测试? 其核心是利用is_training作为flag控制输入给BN的mean和variance的来源,以及如何将moving_mean和moving_variance加入网络的训练过程中。
TF官方的建议方法解释是,
Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:
参考这篇博客,作者对此做了更棒的解释!!!!!
When you execute an operation (such as train_step), only the subgraph components relevant to train_step will be executed. Unfortunately, the update_moving_averages operation is not a parent of train_step in the computational graph, so we will never update the moving averages!
作者的解决方法:Personally, I think it makes more sense to attach the update ops to the train_step itself. So I modified the code a little and created the following training function
以上代码在tf.slim.batch_norm
中也有体现,slim是对tf的一个更高层的封装,利用slim实现的ResNet-v2-152可以参考这里。
最后,贴上基于tf.contrib.layers.batch_norm
的实现样例,更详细的实现见我的notebook。
MLP是否采用BN的结果对比
最后,贴一个是否采用BN层的结果对比,效果还是比较显著的。但是我也发现由于我设置的网络层数和FC长度都比较可观,随着Epochs增大,BN的优势并没有那么明显了。。。

Enjoy it !! 我终于把这个问题看懂了,开心
References
[1] Ioffe, S. and Szegedy, C., 2015, June. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning (pp. 448-456).
[2] tensorflow 中batch normalize 的使用
[3] docs: batch normalization usage in slim #7469
[4] tf.layers.batch_normalization
[5] TENSORFLOW GUIDE: BATCH NORMALIZATION