Skip to content

Commit

Permalink
Fixed prediction using reset data instead of prepare (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick authored Jan 8, 2019
1 parent d158969 commit 9347420
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions calamari_ocr/ocr/backends/tensorflow_backend/tensorflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(self, network_proto, graph, session, graph_type="train", batch_size
self.output_seq_len, self.time_major_logits, self.time_major_softmax, self.logits, self.softmax, self.decoded, self.sparse_decoded, self.scale_factor = \
self.create_network(self.inputs, self.input_seq_len, self.dropout_rate, reuse_variables=reuse_weights)

self.uninitialized_variable_initializer = None
self.all_variable_initializer = None

def is_gpu_available(self):
# create a dummy session and list available devices
Expand Down Expand Up @@ -347,12 +349,17 @@ def prepare(self, uninitialized_variables_only=True):
super().prepare()
self.reset_data()
with self.graph.as_default():
# only create the initializers once, else the graph is growing...
if not self.uninitialized_variable_initializer:
self.uninitialized_variable_initializer = tf.variables_initializer(self.uninitialized_variables())
if not self.all_variable_initializer:
self.all_variable_initializer = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

# run the desired initializer
if uninitialized_variables_only:
self.session.run(tf.variables_initializer(self.uninitialized_variables()))
self.session.run(self.uninitialized_variable_initializer)
else:
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
self.session.run(init_op)
self.session.run(self.all_variable_initializer)

def load_weights(self, filepath, restore_only_trainable=True):
with self.graph.as_default() as g:
Expand Down
2 changes: 1 addition & 1 deletion calamari_ocr/ocr/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def predict_input_dataset(self, input_dataset: InputDataset, progress_bar=True):
"""

self.network.set_input_dataset(input_dataset, self.codec)
self.network.prepare(uninitialized_variables_only=True)
self.network.reset_data()

if progress_bar:
out = tqdm(self.network.prediction_step(), desc="Prediction", total=len(input_dataset))
Expand Down

0 comments on commit 9347420

Please sign in to comment.