1. 将 TensorFlow 模型文件打包成 PB 文件
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
# Export a TensorFlow 1.x model to a frozen binary .pb file.
#
# Steps: rebuild the model graph, restore its weights from a checkpoint, write
# the (unfrozen) GraphDef to disk, then run freeze_graph, which reloads that
# GraphDef plus the checkpoint and bakes the variables into constants.
# All paths/names below are placeholders -- replace them for your model.
# NOTE(review): tf.Session / tf.ConfigProto are TF 1.x APIs; under TF 2.x use
# tf.compat.v1 and call tf.compat.v1.disable_eager_execution() first.
import os

import tensorflow as tf
from tensorflow.python.tools import freeze_graph

CKPT_PATH = "/your/model/path"
SAVE_DIR = "/your/save/path/"
GRAPH_NAME = "save_name.pb"
FROZEN_GRAPH_PATH = "frozen_name.pb"
# Comma-separated op names with no surrounding spaces: freeze_graph matches
# them literally, so "output / node / name" would not be found.
OUTPUT_NODE_NAMES = "output/node/name"

with tf.Graph().as_default():
    with tf.device("/cpu:0"):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config).as_default() as sess:
            model = Your_Model_Name()  # placeholder: your model class
            model.build_graph()
            # Create the Saver only after the graph is built so it covers the
            # model's variables.
            saver = tf.train.Saver()
            # restore() assigns every variable the Saver manages, so running an
            # initializer first is unnecessary.
            saver.restore(sess, CKPT_PATH)
            # Write the graph structure (no weights) in binary form;
            # freeze_graph reads it back below with input_binary=True.
            # (In-session alternative: tf.graph_util.convert_variables_to_constants
            # followed by tf.graph_util.remove_training_nodes.)
            tf.train.write_graph(sess.graph_def, SAVE_DIR, GRAPH_NAME, as_text=False)

# Called outside the model graph: freeze_graph imports the saved GraphDef into
# the current default graph, which must not already contain the model's nodes.
freeze_graph.freeze_graph(
    input_graph=os.path.join(SAVE_DIR, GRAPH_NAME),
    input_saver="",
    input_binary=True,
    input_checkpoint=CKPT_PATH,
    output_node_names=OUTPUT_NODE_NAMES,
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    output_graph=FROZEN_GRAPH_PATH,
    clear_devices=True,
    initializer_nodes="",
)
2. 读取并使用 PB 文件
1
2
3
4
5
6
7
8
9
10
|
# Load a frozen binary .pb graph and run inference on it (TF 1.x API).
# File and tensor names are placeholders; `in_data` is your input batch.
import tensorflow as tf

PB_PATH = "your_name.pb"
# get_tensor_by_name needs the tensor name ("<op_name>:<output_index>"); a bare
# op name raises ValueError because it refers to an Operation, not a Tensor.
INPUT_TENSOR_NAME = "input_node_name:0"
OUTPUT_TENSOR_NAME = "out_node_name:0"

output_graph_def = tf.GraphDef()
with open(PB_PATH, "rb") as f:
    output_graph_def.ParseFromString(f.read())

# Import into a dedicated graph and bind the session to that same graph, so the
# tensors looked up below actually exist in the session's graph.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(output_graph_def, name="")

with tf.Session(graph=graph) as sess:
    node_in = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
    model_out = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)
    feed_dict = {node_in: in_data}
    pred = sess.run(model_out, feed_dict)
以上这篇将 TensorFlow 模型打包成 PB 文件及读取 PB 文件的方法就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/hustchenze/article/details/83660960