在实践中经常会遇到这样的情况:
1、用简单的模型预训练参数
2、把预训练的参数导入复杂的模型后训练复杂的模型
这时就产生一个问题:
如何加载预训练的参数。
下面就是我的总结。
为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。
卷积层的实现代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
import tensorflow as tf # PS:本篇的重担是saver,不过为了方便阅读还是说明下参数 # 参数 # name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope # input_data:输入数据 # width, high:卷积小窗口的宽、高 # deep_before, deep_after:卷积前后的神经元数量 # stride:卷积小窗口的移动步长 def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type = 'SAME' ): global parameters with tf.name_scope(name) asscope: weights = tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after], dtype = tf.float32,stddev = 0.01 ), trainable = True , name = 'weights' ) biases = tf.Variable(tf.constant( 0.1 , shape = [deep_after]), trainable = True , name = 'biases' ) conv = tf.nn.conv2d(input_data, weights, [ 1 , stride, stride, 1 ], padding = padding_type) bias = tf.add(conv,biases) bias = batch_norm(bias,deep_after, 1 ) # batch_norm是自己写的batchnorm函数 conv = tf.maximum( 0.1 * bias, bias) return conv |
简单的预训练模型就下面一句话
1
|
conv1 = make_conv( 'simple-conv1' , images, 3 , 3 , 3 , 32 , 1 ) |
复杂的模型是两个卷基层,如下:
1
2
3
|
conv1 = make_conv( 'complex-conv1' ,images, 3 , 3 , 3 , 32 , 1 ) pool1 = make_max_pool( 'layer1-pool1' , conv1, 2 , 2 ) conv2 = make_conv( 'complex-conv2' , pool1, 3 , 3 , 32 , 64 , 1 ) |
这时简简单单的在预训练模型中:
1
2
3
|
saver = tf.train.Saver() with tf.Session() as sess: saver.save(sess, 'model.ckpt' ) |
就不行了,因为:
1,如果你在预训练模型中使用下面的话打印所有tensor
1
2
|
all_v = tf.global_variables() for i in all_v: print i |
会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:
1
2
3
4
5
6
7
8
9
10
11
|
<tf.Variable 'simple-conv1/weights:0' shape = ( 3 , 3 , 3 , 32 ) dtype = float32_ref> <tf.Variable 'simple-conv1/biases:0' shape = ( 32 ,) dtype = float32_ref> <tf.Variable 'simple-conv1/Variable:0' shape = ( 32 ,)dtype = float32_ref> <tf.Variable 'simple-conv1/Variable_1:0' shape = ( 32 ,)dtype = float32_ref> <tf.Variable 'simple-conv1/Variable_2:0' shape = ( 32 ,)dtype = float32_ref> <tf.Variable 'simple-conv1/Variable_3:0' shape = ( 32 ,)dtype = float32_ref> |
同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。
2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。
解决方法:
1,在预训练模型中定义全局变量
1
|
parm_dict = {} |
并在“return conv”上面添加下面两行
1
2
|
parm_dict[ 'complex-conv1/weights' ] = weights parm_dict[ 'complex-conv1/' ] = biases |
然后在定义saver时使用下面这句话:
1
|
saver = tf.train.Saver(parm_dict) |
这样保存后的模型文件就对应到复杂模型上了。
2,在复杂模型中定义全局变量
1
|
parameters = [] |
并在“return conv”上面添加下面行
1
|
parameters + = [weights, biases] |
然后判断如果是第二个卷积层就不更新parameters。
接着在定义saver时使用下面这句话:
1
|
saver = tf.train.Saver(parameters) |
这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。
最后使用下面的代码加载就可以了
1
2
3
4
5
6
7
|
with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state( '.' ) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess,ckpt.model_checkpoint_path) else : print ' no saver.' exit() |
以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/xueyingxue001/article/details/70676253