昨天讨论了基于TensorFlow的迁移学习,提到了网络的存储和恢复的问题,然而我并没有说清楚,而且网络的恢复要考虑的问题其实挺多的。。。

概括为如下三个问题:

  1. 变量的命名空间问题;
  2. 网络的恢复问题;
  3. 网络中tensor名称和数据的获取

变量命名空间

首先,来说明命名空间的问题。通常我们利用tf.Variabletf.placeholder等初始话一个变量,这两个类都有缺省的命名参数,例如tf.Variable的参数name=Variable。相应的,tf在我们新建变量的过程中按顺序在Variable后添加数字,例如Variable_1,Variable_2等。

这种缺省命名虽然方便,但对网络的存储和恢复会造成很大的影响,所以如果要保存我们的网络,最好的方法是给每个变量设置命名空间。这些变量包括Variables,placeholder以及optimizer

网络的恢复问题

如上一篇博客所述,tensorflow的tf.train.Saver类既可以存储网络,也可以恢复网络。其中Saver保存的.ckpt文件包含checkpointmetadata,分别存储了graph的命名空间和元数据。恢复的时候便是基于他们读取数据到网络中。

但是,单有checkpoint和metadata是没有用的,tensorflow的核心就是graph,所以我们需要在恢复网络前重新搭建graph,并且这个graph的命名空间要与checkpoint的相同。因此,一定要养成好习惯,在定义网络的时候,给每个变量都设置固定的名称。

除此之外,在一个ipython环境或者notebook下,如果要保存网络,建议只搭建一个graph。因为我发现,tf.train.Saver类在保存checkpoint的时候,会将目前存在的graph全部保存。

下面给一段代码,从这篇文章复制来的。。。

保存模型
1
2
3
4
5
6
7
8
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.InteractiveSession()
sess.run(init_op)
savepath = "./model.ckpt"
saver.save(sess, savepath)
恢复模型
1
2
3
4
5
6
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver()
sess = tf.InteractiveSession()
modelpath = "./model.ckpt"
sess = saver.restore(sess, modelpath)

从上面可以看出,二者的区别在于是否需要初始化变量,这也是网络恢复的核心问题,因为我们的目的就是恢复参数。

网络中tensor名称和数据的获取

如何从存储的网络中提取变量的命名及其数据,可以参考下面的程序,也是抄来的。。。

1
2
3
4
5
6
7
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key) # Ouput variables name
print(reader.get_tensor(key)) # Output variables data

输出结果类似这样,

1
2
3
4
tensor_name: Conv_En_b2/cae-optimizer
tensor_name: Conv_De_b1/cae-optimizer_1
tensor_name: Conv_En_W0
tensor_name: De_b/cae-optimizer

References