-
Notifications
You must be signed in to change notification settings - Fork 26
/
nn_function_caller.py
105 lines (93 loc) · 3.55 KB
/
nn_function_caller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
A function caller to work with MLPs.
-- kandasamy@cs.cmu.edu
"""
# pylint: disable=invalid-name
# pylint: disable=abstract-class-not-used
# pylint: disable=no-member
# pylint: disable=signature-differs
from glob import glob
import shutil
import numpy as np
from time import time, sleep
# Local imports
from opt.function_caller import FunctionCaller, EVAL_ERROR_CODE
from utils.reporters import get_reporter
import tempfile
# Probability with which the synthetic (debug-mode) evaluation returns
# EVAL_ERROR_CODE instead of the real result, to simulate failed evaluations.
# Set to 0.0 to disable simulated failures.
_DEBUG_ERROR_PROB = 0.1
class NNFunctionCaller(FunctionCaller):
  """ Function Caller for NN evaluations.

  Evaluates a neural network either with a user-supplied synthetic function
  (debug mode) or with _eval_validation_score, which must be implemented by a
  subclass for the specific application.
  """

  def __init__(self, descr, domain, train_params, debug_mode=False,
               debug_function=None, reporter=None, tmp_dir='/tmp'):
    """ Constructor.
      Args:
        descr: Description, passed through to FunctionCaller.
        domain: Domain, passed through to FunctionCaller.
        train_params: Application-specific parameters, stored as
          self.train_params for use by _eval_validation_score.
        debug_mode: If True, evaluate with debug_function instead of
          _eval_validation_score.
        debug_function: Callable taking an nn and returning its value.
          Required when debug_mode is True.
        reporter: Passed to get_reporter to obtain the reporter used for logs.
        tmp_dir: Parent directory under which a fresh temporary directory is
          created for every non-debug evaluation.
      Raises:
        ValueError: if debug_mode is True and debug_function is None.
    """
    super(NNFunctionCaller, self).__init__(None, domain,
                                           opt_pt=None, opt_val=None,
                                           noise_type='none', noise_params=None,
                                           descr=descr)
    if debug_mode and debug_function is None:
      raise ValueError('If in debug mode, debug_function cannot be None.')
    # Always set, so the attribute exists regardless of debug_mode.
    self.debug_function = debug_function
    self.train_params = train_params
    self.debug_mode = debug_mode
    self.reporter = get_reporter(reporter)
    self.root_tmp_dir = tmp_dir

  def eval_single(self, nn, qinfo, noisy=False):
    """ Over-rides eval_single.
      Evaluates nn and records the value and point on qinfo. If the
      evaluation failed (value == EVAL_ERROR_CODE), a message is reported.
      noisy is accepted for interface compatibility and is not used here.
      Returns:
        The tuple (value, qinfo).
    """
    # pylint: disable=unused-argument
    qinfo.val = self._func_wrapper(nn, qinfo)
    if qinfo.val == EVAL_ERROR_CODE:
      self.reporter.writeln(('Error occurred when evaluating %s. Returning ' +
                             'EVAL_ERROR_CODE: %s.')%(nn, EVAL_ERROR_CODE))
    qinfo.true_val = qinfo.val
    qinfo.point = nn
    return qinfo.val, qinfo

  def _func_wrapper(self, nn, qinfo):
    """ Evaluates the function at nn, writes the result to qinfo.result_file and
        returns it. Chooses between the synthetic function (debug mode) and the
        real validation score. In non-debug mode any exception raised during
        evaluation is reported and EVAL_ERROR_CODE is returned instead.
        The per-evaluation temporary directory, if one was created, is always
        removed before returning (or raising).
    """
    # pylint: disable=broad-except
    created_tmp_dir = None
    try:
      if self.debug_mode:
        ret = self._eval_synthetic_function(nn, qinfo)
      else:
        try:
          created_tmp_dir = tempfile.mkdtemp(dir=self.root_tmp_dir)
          # Exposed on self so that _eval_validation_score can use it.
          self.tmp_dir = created_tmp_dir
          ret = self._eval_validation_score(nn, qinfo)
        except Exception as exc:
          self.reporter.writeln('Exception when evaluating %s: %s'%(nn, exc))
          ret = EVAL_ERROR_CODE
      # Record on qinfo, then write to the result file.
      qinfo.val = ret
      qinfo.true_val = qinfo.val
      qinfo.point = nn
      self._write_result_to_file(ret, qinfo.result_file)
    finally:
      # Only remove the directory created by this call; cleanup failures are
      # deliberately ignored (best effort), but interrupts are not swallowed.
      if created_tmp_dir is not None:
        shutil.rmtree(created_tmp_dir, ignore_errors=True)
    return ret

  def _eval_synthetic_function(self, nn, qinfo):
    """ Evaluates the synthetic (debug) function at nn.
        Reseeds the global numpy RNG from the current time and
        qinfo.worker_id, sleeps for a random 2-12 seconds to mimic a real
        evaluation, and returns EVAL_ERROR_CODE with probability
        _DEBUG_ERROR_PROB; otherwise returns debug_function(nn).
    """
    result = self.debug_function(nn)
    np.random.seed(int(time() * 10 * int(qinfo.worker_id + 1)) % 100000)
    sleep_time = 2 + 10 * np.random.random()
    sleep(sleep_time)
    if np.random.random() < _DEBUG_ERROR_PROB:
      # For debugging, return an error code with small probability.
      return EVAL_ERROR_CODE
    return result

  def _eval_validation_score(self, nn, qinfo):
    """ Evaluates the validation score of nn. Must be implemented by a
        subclass for the specific application. It is called as
        self._eval_validation_score(nn, qinfo), and self.tmp_dir points to a
        fresh temporary directory that is removed after the call. Use
        self.train_params to store anything additional you need.
      Raises:
        NotImplementedError: always, in this base class.
    """
    # pylint: disable=unused-argument
    raise NotImplementedError('Implement this for specific application.')

  @classmethod
  def _write_result_to_file(cls, result, file_name):
    """ Writes str(result) to file_name, overwriting any existing content. """
    with open(file_name, 'w') as file_handle:
      file_handle.write(str(result))