Tensorflow namespace and network restoration
昨天讨论了基于TensorFlow的迁移学习,提到了网络的存储和恢复的问题,然而我并没有说清楚,而且网络的恢复要考虑的问题其实挺多的。。。
概括为如下三个问题:
- 变量的命名空间问题;
- 网络的恢复问题;
- 网络中tensor名称和数据的获取
变量命名空间
首先,来说明命名空间的问题。通常我们利用tf.Variable
或tf.placeholder
等初始话一个变量,这两个类都有缺省的命名参数,例如tf.Variable
的参数name=Variable
。相应的,tf在我们新建变量的过程中按顺序在Variable
后添加数字,例如Variable_1
,Variable_2
等。
这种缺省命名虽然方便,但对网络的存储和恢复会造成很大的影响,所以如果要保存我们的网络,最好的方法是给每个变量设置命名空间。这些变量包括Variables
,placeholder
以及optimizer
。
网络的恢复问题
如上一篇博客所述,tensorflow的tf.train.Saver
类既可以存储网络,也可以恢复网络。其中Saver保存的.ckpt
文件包含checkpoint和metadata,分别存储了graph的命名空间和元数据。恢复的时候便是基于他们读取数据到网络中。
但是,单有checkpoint和metadata是没有用的,tensorflow的核心就是graph,所以我们需要在恢复网络前重新搭建graph,并且这个graph的命名空间要与checkpoint的相同。因此,一定要养成好习惯,在定义网络的时候,给每个变量都设置固定的名称。
除此之外,在一个ipython环境或者notebook下,如果要保存网络,建议只搭建一个graph。因为我发现,tf.train.Saver
类在保存checkpoint的时候,会将目前存在的graph全部保存。
下面给一段代码,从这篇文章复制来的。。。
保存模型
|
|
恢复模型
|
|
从上面可以看出,二者的区别在于是否需要初始化变量,这也是网络恢复的核心问题,因为我们的目的就是恢复参数。
网络中tensor名称和数据的获取
如何从存储的网络中提取变量的命名及其数据,可以参考下面的程序,也是抄来的。。。
输出结果类似这样,