-
Notifications
You must be signed in to change notification settings - Fork 719
/
keras_utils.py
79 lines (66 loc) · 2.58 KB
/
keras_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import defaultdict
import numpy as np
from keras.models import save_model
import tensorflow as tf
import keras
from keras import backend as K
import tqdm_utils
class TqdmProgressCallback(keras.callbacks.Callback):
    """Keras callback that shows training progress as a tqdm bar per epoch.

    A new bar is opened at the start of each epoch and closed at its end.
    The bar's description shows, for every tracked metric, the mean of all
    values recorded so far in the current epoch.
    """

    def on_train_begin(self, logs=None):
        # Total epoch count, used only for the "Epoch i/N" header.
        self.epochs = self.params['epochs']

    def on_epoch_begin(self, epoch, logs=None):
        print('\nEpoch %d/%d' % (epoch + 1, self.epochs))
        # Some Keras versions always put a 'steps' key in params and set it
        # to None when fitting on in-memory arrays, so test the value rather
        # than key presence; otherwise the bar would get total=None.
        steps = self.params.get('steps')
        if steps is not None:
            # Step mode: the bar advances by one per batch.
            self.use_steps = True
            self.target = steps
        else:
            # Sample mode: the bar advances by each batch's size.
            self.use_steps = False
            self.target = self.params.get('samples')
        self.prog_bar = tqdm_utils.tqdm_notebook_failsafe(total=self.target)
        # Per-epoch history of each metric; averaged for the bar description.
        self.log_values_by_metric = defaultdict(list)

    def _set_prog_bar_desc(self, logs):
        """Record tracked metrics from ``logs`` and refresh the bar text.

        Args:
            logs: dict of values reported by Keras for the current
                batch or epoch.
        """
        metric_names = self.params.get('metrics')
        if metric_names is None:
            # Newer Keras versions do not list metric names in params; use
            # the keys present in logs, skipping per-batch bookkeeping.
            metric_names = [k for k in logs if k not in ('batch', 'size')]
        for k in metric_names:
            if k in logs:
                self.log_values_by_metric[k].append(logs[k])
        desc = "; ".join("{0}: {1:.4f}".format(k, np.mean(values))
                         for k, values in self.log_values_by_metric.items())
        if hasattr(self.prog_bar, "set_description_str"):  # for new tqdm versions
            self.prog_bar.set_description_str(desc)
        else:
            self.prog_bar.set_description(desc)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        if self.use_steps:
            self.prog_bar.update(1)
        else:
            # In sample mode Keras reports the batch size under 'size'.
            batch_size = logs.get('size', 0)
            self.prog_bar.update(batch_size)
        self._set_prog_bar_desc(logs)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self._set_prog_bar_desc(logs)
        self.prog_bar.update(1)  # workaround to show description
        self.prog_bar.close()
class ModelSaveCallback(keras.callbacks.Callback):
    """Keras callback that writes the full model to disk after each epoch.

    Args:
        file_name: a ``str.format`` template for the output path. It is
            formatted with the epoch index Keras passes to ``on_epoch_end``
            as its single positional argument (e.g. ``"model.{0:03d}.hdf5"``).
    """

    def __init__(self, file_name):
        super(ModelSaveCallback, self).__init__()
        # Path template, filled in once per epoch.
        self.file_name = file_name

    def on_epoch_end(self, epoch, logs=None):
        target_path = self.file_name.format(epoch)
        save_model(self.model, target_path)
        print("Model saved in {}".format(target_path))
# !!! remember to clear session/graph if you rebuild your graph to avoid out-of-memory errors !!!
def reset_tf_session():
    """Throw away the current TF session/graph and start a fresh session.

    Steps:
    1. Close the current default TensorFlow session, if there is one.
    2. Clear the Keras backend state, which resets the graph.
    3. Create an ``InteractiveSession`` with GPU ``allow_growth`` enabled.
    4. Register that session with Keras.

    Returns:
        The newly created ``tf.InteractiveSession``.
    """
    previous_session = tf.get_default_session()
    if previous_session is not None:
        previous_session.close()

    # Drop the old graph and Keras backend state.
    K.clear_session()

    session_config = tf.ConfigProto()
    # Let TF grab GPU memory on demand instead of reserving it all at once.
    session_config.gpu_options.allow_growth = True
    fresh_session = tf.InteractiveSession(config=session_config)
    K.set_session(fresh_session)
    return fresh_session