-
Notifications
You must be signed in to change notification settings - Fork 20
/
dump_filters.py
56 lines (47 loc) · 2.08 KB
/
dump_filters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from IPython import display
from matplotlib import pyplot
import numpy as np
import os
import sys
# Make sure that you set this to the location your caffe2 library lies.
# Setting the CAFFE2_ROOT environment variable overrides the default below,
# so the script can be run on other machines without editing it.
caffe2_root = os.environ.get('CAFFE2_ROOT', '/home/awesomebox/code/caffe2/')
# The generated python bindings live under <caffe2_root>/gen; put them first
# on the import path so the caffe2 imports below resolve to this build.
sys.path.insert(0, os.path.join(caffe2_root, 'gen'))
# After setting the caffe2 root path, we will import all the caffe2 libraries needed.
from caffe2.proto import caffe2_pb2
from pycaffe2 import core, net_drawer, workspace, visualize

# Serialized protobufs are binary data: open them in 'rb' so ParseFromString
# receives bytes (required on Python 3, harmless on Python 2), and use `with`
# so the file handles are closed as soon as the data has been read.

# net is the network definition.
net = caffe2_pb2.NetDef()
with open('inception_net.pb', 'rb') as net_file:
    net.ParseFromString(net_file.read())

# tensors contain the parameter tensors.
tensors = caffe2_pb2.TensorProtos()
with open('inception_tensors.pb', 'rb') as tensors_file:
    tensors.ParseFromString(tensors_file.read())

DEVICE_OPTION = caffe2_pb2.DeviceOption()
# Let's use CPU in our example.
DEVICE_OPTION.device_type = caffe2_pb2.CPU
# If you have a GPU and want to run things there, uncomment the below two lines.
# If you have multiple GPUs, you also might want to specify a gpu id.
#DEVICE_OPTION.device_type = caffe2_pb2.CUDA
#DEVICE_OPTION.cuda_gpu_id = 0

# Caffe2 has a concept of "workspace", which is similar to that of Matlab. Each workspace
# is a self-contained set of tensors and networks. In this case, we will just use the default
# workspace so we won't dive too deep into it.
workspace.SwitchWorkspace('default')

# First, we feed all the parameters to the workspace, each under its own
# blob name and with the device option chosen above.
for param in tensors.protos:
    workspace.FeedBlob(param.name, param, DEVICE_OPTION)

# The network expects an input blob called "input", which we create here.
# The content of the input blob is going to be fed when we actually do
# classification.
workspace.CreateBlob("input")

# Specify the device option of the network, and then create it.
net.device_option.CopyFrom(DEVICE_OPTION)
workspace.CreateNet(net)
########################################
### MY CODE ############################
# Dump every parameter tensor fed into the workspace above to its own HDF5
# file: <DUMP_DIR>/<blob name>.h5, containing one dataset named after the blob.
import h5py  # imported once here rather than on every loop iteration

DUMP_DIR = 'dump'
# h5py.File(..., 'w') does not create missing parent directories, so make
# sure the output directory exists before writing into it.
if not os.path.isdir(DUMP_DIR):
    os.makedirs(DUMP_DIR)

for param in tensors.protos:
    print(param.name)
    filters = workspace.FetchBlob(param.name)
    # If a blob name contains '/', it would otherwise be read as a
    # sub-directory in the output path; flatten it for the filename only.
    # The dataset inside the file keeps the original blob name.
    out_path = os.path.join(DUMP_DIR, param.name.replace('/', '_') + '.h5')
    # `with` guarantees the HDF5 file is closed even if create_dataset raises.
    with h5py.File(out_path, 'w') as h5f:
        h5f.create_dataset(param.name, data=filters)