Pb文件跑前向

直接使用转换得到的 pb 文件进行前向推理（inference）。

import tensorflow as tf
# from tensorflow.python.framework import graph_util
import cv2
import numpy as np

# Run a single forward pass through a serialized GraphDef (.pb) file:
# load one image, feed it to the tensor named "input:0", fetch "out:0",
# and print which indices win the argmax over the output's last axis.

image_path = 'img_250.jpg'
img = cv2.imread(image_path)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here with the path instead of an AttributeError below.
    raise IOError('Could not read image: %s' % image_path)
# Rescale pixel values; for the usual 8-bit image this maps [0, 255] to
# [-1, 1].  NOTE(review): cv2.imread yields BGR channel order -- confirm
# that this matches what the exported model was trained on.
img = img.astype(np.float32) * 2 / 255 - 1

output_graph_path = './graph.pb'
with tf.Session() as sess:
    # No variable initializer is run: at this point the session graph is
    # still empty, so an initializer would be a no-op that only adds an
    # extra "init" node ahead of the name="" import below.
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    # name="" keeps the original node names so the tensors can be looked
    # up as "input:0" / "out:0" without an "import/" prefix.
    tf.import_graph_def(output_graph_def, name="")

    input_x = sess.graph.get_tensor_by_name("input:0")
    print(input_x)

    output = sess.graph.get_tensor_by_name("out:0")
    print(output)

    # Wrapping img in a list adds the leading batch dimension.
    res = sess.run(output, feed_dict={input_x: [img]})
    # res[0] is the result for the single fed image; argmax over its axis 2
    # and np.unique list the distinct winning indices (presumably per-pixel
    # class ids of a segmentation model -- confirm against the export).
    print(np.unique(np.argmax(res[0], axis=2)))