-
Notifications
You must be signed in to change notification settings - Fork 57
/
utils.py
109 lines (86 loc) · 4.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright 2018 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Short utility functions for LEO."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
from six.moves import range
import tensorflow as tf
import config
import data
def unpack_data(problem_instance):
  """Flattens a data.ProblemInstance into a list, e.g. to feed map_fn.

  Args:
    problem_instance: Either a `data.ProblemInstance` or any other object.

  Returns:
    `list(problem_instance)` when the argument is a `data.ProblemInstance`;
    otherwise the argument itself, untouched.
  """
  # Anything that is not a ProblemInstance is passed through unchanged.
  if not isinstance(problem_instance, data.ProblemInstance):
    return problem_instance
  return list(problem_instance)
def copy_checkpoint(checkpoint_path, global_step, accuracy):
  """Copies the checkpoint of `global_step` into `<dir>/best_checkpoint`.

  The files are first gathered in `<dir>/tmp_best_checkpoint` and that
  directory is then renamed to `best_checkpoint`, replacing any previous one.
  Nothing is copied when the stored best checkpoint has a strictly higher
  accuracy than `accuracy`.

  Args:
    checkpoint_path: Directory holding `model.ckpt-<step>.*` and `graph.pbtxt`.
    global_step: Step number identifying which checkpoint files to copy.
    accuracy: Accuracy of this checkpoint; stored alongside the copied files.
  """
  best_dir = os.path.join(checkpoint_path, "best_checkpoint")
  staging_dir = os.path.join(checkpoint_path, "tmp_best_checkpoint")

  if _is_previous_accuracy_better(best_dir, accuracy):
    tf.logging.info("Not copying the checkpoint: there is a better one from "
                    "before a preemption.")
    return

  # Every file of this step's checkpoint, plus the graph definition.
  files_to_copy = tf.gfile.Glob(
      os.path.join(checkpoint_path, "model.ckpt-{}.*".format(global_step)))
  files_to_copy.append(os.path.join(checkpoint_path, "graph.pbtxt"))
  _save_files_in_tmp_directory(staging_dir, files_to_copy, accuracy)

  # The index file already names the final location, because the staging
  # directory is renamed to it below.
  index_entry = "model_checkpoint_path: \"{}/model.ckpt-{}\"\n".format(
      best_dir, global_step)
  with tf.gfile.Open(os.path.join(staging_dir, "checkpoint"), "w") as index:
    index.write(index_entry)

  # The staging directory is fully populated before it replaces the previous
  # best checkpoint, so a preemption during the copy cannot leave a
  # half-written `best_checkpoint` behind.
  # NOTE(review): delete + rename are two separate steps; a preemption between
  # them leaves only the staging directory, which
  # _save_files_in_tmp_directory removes on its next call - confirm this is
  # acceptable.
  if tf.gfile.Exists(best_dir):
    tf.gfile.DeleteRecursively(best_dir)
  tf.gfile.Rename(staging_dir, best_dir)
  tf.logging.info("Copied new best checkpoint with accuracy %.5f", accuracy)
def _save_files_in_tmp_directory(tmp_checkpoint_path, checkpoint_files,
                                 accuracy):
  """Fills a fresh temporary directory with checkpoint files and metadata.

  Args:
    tmp_checkpoint_path: Directory to (re)create and populate.
    checkpoint_files: Paths of the files to copy into the directory; each is
      stored under the part of its path after the last "/".
    accuracy: Value pickled into the directory's `accuracy` file.
  """
  # Start from an empty directory: drop whatever an interrupted run left.
  if tf.gfile.Exists(tmp_checkpoint_path):
    tf.logging.info("The temporary directory exists, because job was preempted "
                    "before it managed to move it. We're removing it.")
    tf.gfile.DeleteRecursively(tmp_checkpoint_path)
  tf.gfile.MkDir(tmp_checkpoint_path)

  def write_pickle(payload, name):
    # Pickles `payload` into the file `name` inside the temporary directory.
    target = os.path.join(tmp_checkpoint_path, name)
    with tf.gfile.Open(target, "wb") as out:
      pickle.dump(payload, out)

  for source in checkpoint_files:
    base_name = source.rpartition("/")[2]
    destination = os.path.join(tmp_checkpoint_path, base_name)
    tf.gfile.Copy(source, destination, overwrite=False)

  # Both model configs and the accuracy are stored next to the weights.
  write_pickle(config.get_inner_model_config(), "inner_config")
  write_pickle(config.get_outer_model_config(), "outer_config")
  write_pickle(accuracy, "accuracy")
def _is_previous_accuracy_better(best_checkpoint_path, accuracy):
  """Tells whether the stored best checkpoint beats the given accuracy.

  Args:
    best_checkpoint_path: Directory expected to contain a pickled `accuracy`
      file when it exists.
    accuracy: Accuracy of the candidate checkpoint.

  Returns:
    False when `best_checkpoint_path` does not exist; otherwise whether the
    unpickled stored accuracy is strictly greater than `accuracy`.
  """
  if tf.gfile.Exists(best_checkpoint_path):
    accuracy_file_path = os.path.join(best_checkpoint_path, "accuracy")
    # NOTE(review): pickle.load is only safe on trusted files; this one is
    # presumably the file written by _save_files_in_tmp_directory - confirm.
    with tf.gfile.Open(accuracy_file_path, "rb") as accuracy_file:
      stored_accuracy = pickle.load(accuracy_file)
    return stored_accuracy > accuracy
  return False
def evaluate_and_average(session, tensor, num_estimates):
  """Evaluates `tensor` several times and returns the mean of the results.

  Args:
    session: Object exposing `run(fetch)`; it is called once per estimate.
    tensor: Fetch handed unchanged to `session.run` on every call.
    num_estimates: Number of evaluations to average over; must be >= 1.

  Returns:
    The sum of the evaluated values divided by `num_estimates`.

  Raises:
    ValueError: If `num_estimates` is smaller than 1.
  """
  # Without this guard, 0 fails with an opaque ZeroDivisionError and negative
  # counts silently yield -0.0 without ever running the tensor.
  if num_estimates < 1:
    raise ValueError(
        "num_estimates must be at least 1, got {}.".format(num_estimates))

  # Running total instead of a list: only one estimate is alive at a time.
  # Starting from 0 keeps the same addition order as sum() over the estimates.
  total = 0
  for _ in range(num_estimates):
    total = total + session.run(tensor)
  return total / num_estimates