-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathutils.py
67 lines (61 loc) · 2.17 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# ------------------------------------------------------------------------------
#
# Utility Functions for scripts
#
# ------------------------------------------------------------------------------
import sys
def get_args(argc, err_msg):
    """Return the script's positional command-line arguments.

    Args:
        argc: Number of positional arguments the script expects, not counting
            the script name itself.
        err_msg: Usage text appended to ``'Script usage: $ python '`` when the
            argument count is wrong.

    Returns:
        ``sys.argv[1:]`` when exactly ``argc`` arguments were supplied, or
        ``None`` when ``argc`` is 0 (``sys.argv`` is not inspected at all in
        that case).

    Raises:
        SystemExit: with status 1, after printing the usage line, when the
            number of supplied arguments differs from ``argc``.
    """
    if argc == 0:
        return None
    if len(sys.argv) != argc + 1:
        print('Script usage: $ python ' + err_msg)
        # Non-zero status so shells/CI treat a usage error as a failure;
        # the previous status of 0 reported success.
        sys.exit(1)
    return sys.argv[1:]
def error_exit(msg, status=1):
    """Print a framed message for the user and terminate the script.

    Args:
        msg: Text shown to the user; formatted with ``%s``, so non-string
            values are accepted as well.
        status: Process exit status. Defaults to 1 so the calling shell sees
            a failure; pass 0 to reproduce the old "success" status.

    Raises:
        SystemExit: always, with ``status`` as its exit code.
    """
    print("\nMessage for user:\n %s\nExiting\n" % (msg,))
    sys.exit(status)
def read_config_file(fname):
    """Parse a whitespace-separated ``key value`` config file into a dict.

    Each non-blank line contributes one entry. The first token is the key and
    the second token, parsed with ``int()``, is the value. Lines whose second
    token is missing or not an integer map their key to ``None``. Tokens after
    the second are ignored, blank lines are skipped, and a repeated key keeps
    the value from its last occurrence.

    Args:
        fname: Path of the config file to read.

    Returns:
        dict mapping each key (str) to an int or ``None``.
    """
    print('Reading config file: %s' % (fname,))
    key_to_value = {}
    # ``with`` guarantees the file is closed even if reading raises.
    with open(fname) as f:
        for line in f:
            tokens = line.split()  # split() already ignores surrounding whitespace
            if not tokens:
                continue  # blank or whitespace-only line
            try:
                val = int(tokens[1])
            except (IndexError, ValueError):
                # No value column, or a value that is not an integer.
                val = None
            key_to_value[tokens[0]] = val
    return key_to_value
# Write a summary file of an imported graph
def write_summary_file(name, sess, graph_def, imported=False):
f = open(name, 'w')
for node in graph_def.node:
f.write(node.name + "\n")
prefix = ''
if imported:
prefix = 'import/'
if node.op not in ['NoOp', 'SaveV2']:
f.write(' op = ' + node.op + "\n")
try:
f.write(' output size = ' + str(sess.graph.get_tensor_by_name(prefix + node.name + ':0').get_shape()) + "\n")
except:
f.write(' output size = 0' + "\n")
if node.op in ['Conv2D', 'DepthwiseConv2dNative']:
f.write(' padding = ' + str(node.attr['padding'].s) + "\n")
f.write(' stride = ' + str(node.attr['strides'].list.i) + "\n")
# f.write(' stride = ' + str(node.attr['strides'].list.i[1]) + "\n")
# f.write(' k = ' + str(sess.graph.get_tensor_by_name(prefix + node.input[1] + ':0').get_shape().as_list()[1]) + "\n")
f.write(' k = ' + str(sess.graph.get_tensor_by_name(prefix + node.input[1] + ':0').get_shape()) + "\n")
elif node.op == 'MaxPool':
f.write(' stride = ' + str(node.attr['strides'].list.i) + "\n")
f.write(' k = ' + str(node.attr['ksize'].list.i) + "\n")
input_idx = 0
for input in node.input:
f.write(' in' + str(input_idx) + ' = ' + str(input) + "\n")
input_idx += 1
f.close()