-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmixture.coco
137 lines (114 loc) · 4.77 KB
/
mixture.coco
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
The mixture backend. Lets you specify a distribution over different possible algorithms.
"""
from bbopt import constants
from bbopt.util import convert_match_errors
from bbopt.registry import alg_registry
from bbopt.backends.util import (
Backend,
get_backend,
get_cum_probs_for,
random_from_cum_probs,
)
# Backend:
class MixtureBackend(Backend):
    """Mixture backend. Takes in a distribution over different possible algorithms
    of the form [(algorithm, weight)]. The properties selected_alg and selected_backend
    can be used to retrieve which alg/backend is currently being used.

    Weights may be plain numbers or zero-argument callables; callables are
    evaluated each time the distribution is (re)applied. If remove_erroring_algs
    is set, an algorithm whose backend raises one of constants.erroring_backend_errs
    is dropped and another algorithm is drawn from the remaining ones."""
    backend_name = "mixture"
    request_backend_store = True

    # None (not a bool) so the first attempt_update call always sees a change
    # and forces use_distribution to compute cum_probs.
    remove_erroring_algs = None

    # Algorithms that errored and were dropped from the current distribution.
    # Cleared whenever use_distribution recomputes cum_probs. A frozenset so the
    # class-level default can never be mutated and shared between instances.
    removed_algs = frozenset()

    @override
    @convert_match_errors
    match def attempt_update(self, examples, params, distribution, remove_erroring_algs=False, *, _backend_store):
        """Special method that allows fast updating of the backend.

        Applies distribution (forcing cum_probs to be recomputed when
        remove_erroring_algs differs from the previous call), stores the new
        examples/params/backend store, then randomly selects a new backend.
        Always returns True.
        """
        self.use_distribution(distribution, force=remove_erroring_algs != self.remove_erroring_algs)

        self.examples = examples
        self.params = params
        self.remove_erroring_algs = remove_erroring_algs
        self.backend_store = _backend_store

        self.select_new_backend()
        return True

    def use_distribution(self, distribution, force=False):
        """Set the distribution to the given distribution.

        Callable weights are called to obtain their current value. cum_probs is
        only recomputed (and removed_algs cleared) when force is set or the
        evaluated distribution differs from the current one; otherwise any
        algorithms removed for erroring stay removed.
        """
        distribution = tuple(
            (alg, weight() if callable(weight) else weight)
            for alg, weight in distribution
        )

        if force or distribution != self.distribution:
            self.cum_probs = get_cum_probs_for(distribution)
            self.distribution = distribution
            self.removed_algs = frozenset()

    def select_new_backend(self):
        """Randomly select a new backend.

        Raises:
            ValueError: if no algorithm could be drawn from cum_probs.
            constants.erroring_backend_errs: if constructing the backend fails
                and remove_erroring_algs is not set.
        """
        # randomly select algorithm
        self.selected_alg = random_from_cum_probs(self.cum_probs)
        if self.selected_alg is None:
            raise ValueError(f"could not select backend from distribution: {self.distribution}")

        # initialize backend; the registry maps an alg name to (backend, options)
        self.selected_backend, options = alg_registry[self.selected_alg]
        try:
            self.current_backend = get_backend(
                self.backend_store,
                self.selected_backend,
                self.examples,
                self.params,
                **options,
            )
        except constants.erroring_backend_errs:
            if not self.remove_erroring_algs:
                raise
            self.reselect_backend()

    def reselect_backend(self):
        """Choose a new backend when the current one errors.

        The failing algorithm joins removed_algs and cum_probs is rebuilt from
        the distribution minus *every* removed algorithm, so an algorithm that
        already failed can't be drawn again. Each call removes one more
        algorithm, so the select/reselect recursion is bounded by the size of
        the distribution (presumably ending in select_new_backend's ValueError
        once nothing is left — assumes random_from_cum_probs returns None for
        an empty cum_probs; confirm).
        """
        self.removed_algs = self.removed_algs | {self.selected_alg}
        self.cum_probs = get_cum_probs_for([
            (alg, weight)
            for alg, weight in self.distribution
            if alg not in self.removed_algs
        ])
        self.select_new_backend()

    @override
    def param(self, name, func, *args, **kwargs):
        """Defer parameter selection to the selected backend.

        If the backend errors and remove_erroring_algs is set, a different
        backend is selected and the call is retried on it.
        """
        try:
            return self.current_backend.param(name, func, *args, **kwargs)
        except constants.erroring_backend_errs:
            if not self.remove_erroring_algs:
                raise
            self.reselect_backend()
        # only reached after a successful reselection: retry on the new backend
        return self.param(name, func, *args, **kwargs)

    @classmethod
    def register_safe_alg_for(cls, base_alg, new_alg_name=None, fallback_alg=None):
        """Register a version of base_alg that defaults to the fallback if base_alg fails.

        The new alg (default name "safe_" + base_alg) always picks base_alg via
        an infinite weight; fallback_alg (default constants.safe_fallback_alg)
        is only used once base_alg has been removed for erroring.
        """
        if new_alg_name is None:
            new_alg_name = "safe_" + base_alg
        if fallback_alg is None:
            fallback_alg = constants.safe_fallback_alg
        cls.register_alg(
            new_alg_name,
            distribution=(
                (base_alg, float("inf")),
                (fallback_alg, 1),
            ),
            remove_erroring_algs=True,
        )

    @classmethod
    def register_epsilon_exploration_alg_for(cls, base_alg, new_alg_name=None, eps=None):
        """Register a version of base_alg with epsilon greedy exploration.

        The new alg (default name "epsilon_" + base_alg) picks "random" with
        probability eps (default constants.eps_greedy_explore_prob, read lazily)
        and base_alg otherwise.
        """
        if new_alg_name is None:
            new_alg_name = "epsilon_" + base_alg
        cls.register_alg(
            new_alg_name,
            distribution=(
                # we defer evaluation here so that constants.eps_greedy_explore_prob
                # can be modified to change the epsilon
                (base_alg, -> 1 - (eps ?? constants.eps_greedy_explore_prob)),
                ("random", -> eps ?? constants.eps_greedy_explore_prob),
            ),
        )
# Registered names:

# the plain "mixture" backend, plus "epsilon_max_greedy"
MixtureBackend.register()
MixtureBackend.register_epsilon_exploration_alg_for("max_greedy")

# "safe_<alg>" variants that fall back to constants.safe_fallback_alg
# whenever the base algorithm errors
for _safe_base_alg in (
    "gaussian_process",
    "random_forest",
    "extra_trees",
    "gradient_boosted_regression_trees",
):
    MixtureBackend.register_safe_alg_for(_safe_base_alg)
del _safe_base_alg