-
Notifications
You must be signed in to change notification settings - Fork 134
/
testCheckpoint.py
executable file
·81 lines (59 loc) · 1.96 KB
/
testCheckpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/python
#
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import tensorflow as tf
import os
import sys
import numpy as np
import argparse
parser = argparse.ArgumentParser(description='Checkpoint tester.')
parser.add_argument('-stats', type=int, default=0, help='Enable statistics')
parser.add_argument('-n', type=str, default="", help='Network checkpoint')
opt=parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = ""
reader = tf.contrib.framework.load_checkpoint(opt.n)
sumSize=0.0
sizes = []
names = []
sMap = reader.get_variable_to_shape_map()
for v, s in sMap.items():
size=0.0
if len(s)>0:
size = np.prod(s)
size *= 4.0
sumSize += size
size /= 1024.0*1024.0
sizes.append(size)
names.append(v)
i = np.argsort(sizes)[::-1]
sizes = np.array(sizes)[i]
names = np.array(names)[i]
cumulativeSize = 0
for i in range(len(sizes)):
sSize= "%.2f Mb" % sizes[i]
sName = names[i]
cumulativeSize += sizes[i]
sTotalSize= "%.2f Mb" % cumulativeSize
print("%10s %20s %15s \t %s" % (sSize, str(sMap[sName]), sTotalSize, sName))
print("Total size: %.2f Mb" % (sumSize/(1024.0*1024.0)))
if opt.stats==1:
print("-------------------------------------------------")
print("Statistics:")
for i in range(len(sizes)):
t = reader.get_tensor(names[i])
s = " %f %f %f" % (t.min(), t.mean(), t.max())
print("%60s \t %s" % (s, names[i]))