精简ckpt

下面是一个网络的inference文件,忽略一些不重要的.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse

import cv2
import numpy as np
import tensorflow as tf
import neuralgym as ng

from inpaint_model import InpaintCAModel


parser = argparse.ArgumentParser()
parser.add_argument('--image', default='', type=str,
                    help='The filename of image to be completed.')
parser.add_argument('--mask', default='', type=str,
                    help='The filename of mask, value 255 indicates mask.')
parser.add_argument('--output', default='output.png', type=str,
                    help='Where to write output.')
parser.add_argument('--checkpoint_dir', default='', type=str,
                    help='The directory of tensorflow checkpoint.')


if __name__ == "__main__":
    ng.get_gpus(1)
    args = parser.parse_args()

    model = InpaintCAModel()
    image = cv2.imread(args.image)
    mask = cv2.imread(args.mask)
    ### 添加了两个placeholder
    input_img = tf.placeholder(tf.float32,shape=[1,512,512,3],name = 'image')
    input_mask = tf.placeholder(tf.float32,shape=[1,512,512,3],name = 'mask')
    assert image.shape == mask.shape

    h, w, _ = image.shape
    grid = 8
    image = image[:h//grid*grid, :w//grid*grid, :]
    mask = mask[:h//grid*grid, :w//grid*grid, :]
    print('Shape of image: {}'.format(image.shape))

    image = np.expand_dims(image, 0)
    mask = np.expand_dims(mask, 0)
    # input_image = np.concatenate([image, mask], axis=2)
    input_image = tf.concat([input_img,input_mask],axis = 2)
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # input_image = tf.constant(input_image, dtype=tf.float32)
        output = model.build_server_graph(input_image)
        ### 这边对最后的输出重新取一个名字
        output = tf.identity(output, name="output")
        output = (output + 1.) * 127.5
        output = tf.reverse(output, [-1])
        output = tf.saturate_cast(output, tf.uint8)
        # load pretrained model
        vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        assign_ops = []
        for var in vars_list:
            vname = var.name
            # print(vname)            
            from_name = vname
            var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
            assign_ops.append(tf.assign(var, var_value))
        sess.run(assign_ops)
        print('Model loaded.')
        result = sess.run(output,feed_dict={input_img:image,input_mask:mask})
        cv2.imwrite(args.output, result[0][:, :, ::-1])
        ### 这里是添加的saver,重新保存一遍ckpt
        saver = tf.train.Saver()
        saver.save(sess,'./model')

具体的措施:

  • 在网络中添加placeholder,取一个容易记住的名字,取代网络的输入,其实也为了后面更好的封装,同样的道理也对输出取一个名字,使用tf.identity.因为一般的网络定义是把numpy的数组直接扔进网络的,在后面转pb,转lite的时候不方便.
  • 在inference文件中重新用saver保存一下CKPT,为什么要这样做呢? 主要原因是网络训练过程中保存的CKPT包含了训练的一些参数,这些参数在inference的时候是不需要的,所以这样操作可以把CKPT减小,在这个网络中model从130M -> 14M.

下面再放一个改完名字之后的inference文件,对比一下,代码是不是简化了很多,封装性更好.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tensorflow as tf
import cv2
import sys
import numpy as np
img = cv2.imread(sys.argv[1])
msk = cv2.imread(sys.argv[2])
sess = tf.Session()
## 从meta恢复网络
saver = tf.train.import_meta_graph('./model.meta')
saver.restore(sess,'./model')

graph = tf.get_default_graph()

input_img = graph.get_tensor_by_name('image:0')
input_mask = graph.get_tensor_by_name('mask:0')

output = graph.get_tensor_by_name('output:0')

out = sess.run(output,feed_dict={input_img: [img],input_mask: [msk]})

out = (out + 1)*127.5 
out = out.astype(np.uint8)

cv2.imwrite('out.png',out[0])
print(' ok !')