-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathconvert_tf_to_pb.py
34 lines (29 loc) · 1.42 KB
/
convert_tf_to_pb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import tensorflow as tf
import os
from models.tf_modeling import BertModelMRC, BertConfig
import utils
from tensorflow.python.framework import graph_util
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
max_seq_length = 512
bert_config = BertConfig.from_json_file('check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json')
input_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='input_ids')
segment_ids = tf.placeholder(tf.int32, shape=[None, max_seq_length], name='segment_ids')
input_mask = tf.placeholder(tf.float32, shape=[None, max_seq_length], name='input_mask')
eval_model = BertModelMRC(config=bert_config,
is_training=False,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_float16=False)
utils.init_from_checkpoint('model.ckpt')
config = tf.ConfigProto()
config.allow_soft_placement = True
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
with tf.gfile.FastGFile('model.pb', 'wb') as f:
graph_def = sess.graph.as_graph_def()
output_nodes = ['start_logits', 'end_logits']
print('outputs:', output_nodes)
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, output_nodes)
f.write(output_graph_def.SerializeToString())