-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimizer.js
170 lines (127 loc) · 6.58 KB
/
optimizer.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
'use strict';
var equal = require('deep-equal');
var difference = require('lodash/difference');
var math = require('mathjs');
var util = require('./util');
/**
 * Optimizer over a finite domain of candidate points shared by several models.
 *
 * @constructor
 * @param {Array} domain - All candidate points. Other methods refer to points
 *   by their position in this array.
 * @param {Object} modelsDomains - Maps each model name to the array of domain
 *   entries belonging to that model (getNextPoint uses the entries as indices
 *   into `domain`).
 * @param {?(Array|Object)} [mean=null] - Prior mean, one entry per domain
 *   point. Plain arrays are wrapped with math.matrix; null yields
 *   math.zeros(domain.length).
 * @param {?(Array|Object)} [kernel=null] - Prior kernel over the domain. Plain
 *   arrays are wrapped with math.matrix; null yields math.eye(domain.length).
 * @param {?(Array|Object)} [delays=null] - Per-point delays used to rescale the
 *   expected improvement. Plain arrays are wrapped with math.matrix; null
 *   yields math.ones(domain.length).
 * @param {string} [strategy='ei'] - Model-selection strategy; getNextPoint
 *   accepts 'ei', 'rr' and 'rnd'.
 */
function Optimizer(domain, modelsDomains, mean = null, kernel = null, delays = null, strategy = 'ei') {
  this.domain = domain;
  this.modelsDomains = modelsDomains;

  // Normalise delays the same way as mean and kernel below: getNextPoint calls
  // matrix methods on this.delays, which a plain array does not provide.
  if (delays === null) {
    delays = math.ones(domain.length);
  } else if (Array.isArray(delays)) {
    delays = math.matrix(delays);
  }
  this.delays = delays;

  // One (initially empty) list of sampled domain indices per model.
  this.modelsSamples = {};
  for (const model in modelsDomains) {
    this.modelsSamples[model] = math.matrix([]);
  }

  this.allSamples = math.matrix([]);
  this.allSamplesDelays = math.matrix([]);
  this.observedValues = {};
  this.best = null;

  if (mean === null) {
    mean = math.zeros(this.domain.length);
  } else if (Array.isArray(mean)) {
    mean = math.matrix(mean);
  }

  if (kernel === null) {
    kernel = math.eye(this.domain.length);
  } else if (Array.isArray(kernel)) {
    kernel = math.matrix(kernel);
  }

  this.mean = mean;
  this.kernel = kernel;
  this.strategy = strategy;
}
/**
 * Records an observed value for one point of the domain.
 *
 * The point is located in `this.domain` by deep equality; its index is appended
 * to `allSamples` and to the sample list of every model whose domain contains
 * that index. `best` is updated when the value beats the value recorded for the
 * current best point.
 *
 * @param {*} point - A point deep-equal to one entry of `this.domain`.
 * @param {number} value - Observed value at `point`.
 * @param {number} [delay=1.0] - Delay recorded for this sample.
 * @throws {Error} If `point` is not part of the domain.
 */
Optimizer.prototype.addSample = function (point, value, delay = 1.0) {
  const pointIndex = this.domain.findIndex((candidate) => equal(candidate, point));
  if (pointIndex < 0) {
    // Recording -1 would corrupt every later index-based lookup in getNextPoint.
    throw new Error('Cannot add a sample for a point that is not part of the domain.');
  }

  // getNextPoint uses the entries of modelsDomains as indices into the domain,
  // so membership is decided on the point's index rather than on the point.
  for (const model in this.modelsDomains) {
    if (this.modelsDomains[model].includes(pointIndex)) {
      this.modelsSamples[model] = math.concat(this.modelsSamples[model], [pointIndex]);
    }
  }

  this.allSamples = math.concat(this.allSamples, [pointIndex]);
  this.allSamplesDelays = math.concat(this.allSamplesDelays, [delay]);

  // NOTE(review): observedValues is a plain object, so `point` is coerced to a
  // string key; distinct plain-object points would share one key. Confirm that
  // points are primitives or arrays.
  this.observedValues[point] = value;
  if (this.best === null || this.observedValues[this.best] < value) {
    this.best = point;
  }
};
/**
 * Chooses the next point of the domain to evaluate.
 *
 * Unless every domain index has already been sampled, the prior mean and
 * kernel are conditioned on the observed values, an expected improvement is
 * computed per model via util.expectedImprovement, divided element-wise by the
 * delays, and forced to zero on sampled points. The candidate set is then
 * chosen by `this.strategy`:
 *   - 'ei':  the domains of all models except `excludeModels`;
 *   - 'rr':  the domain of one model, cycling with the number of samples;
 *   - 'rnd': the domain of one model picked with math.pickRandom.
 * The candidate with the largest expected improvement is returned; when that
 * maximum is exactly zero, the candidate with the shortest delay is returned.
 *
 * Side effect: stores the posterior mean in `this.posteriorMean` whenever the
 * posterior is computed.
 *
 * @param {Array<string>} [excludeModels=[]] - Model names to leave out of the
 *   candidate set; only used by the 'ei' strategy.
 * @returns {*} An entry of `this.domain`.
 * @throws {Error} If `this.strategy` is not 'ei', 'rr' or 'rnd'.
 */
Optimizer.prototype.getNextPoint = function (excludeModels = []) {
  const domainIndices = Array.from(this.domain, (_, i) => i);
  // allSamples is a mathjs matrix: it has neither `length` nor `[i]` access and
  // lodash `difference` ignores it, so work on a plain-array copy.
  const sampledIndices = this.allSamples.toArray();

  let expectedImprov;
  if (difference(domainIndices, sampledIndices).length === 0) {
    // The whole domain has been sampled: skip the posterior calculation, no
    // point has any expected improvement left.
    expectedImprov = math.zeros(this.domain.length);
  } else {
    // Compute best rewards for each model.
    const modelsBestRewards = {};
    for (const model in this.modelsSamples) {
      const rewards = this.modelsSamples[model].toArray().map((x) => this.observedValues[this.domain[x]]);
      modelsBestRewards[model] = Math.max(...rewards);
    }

    // Compute posterior distribution (mean and standard deviation).
    const domainSize = this.mean.size()[0];
    const sampleSize = this.allSamples.size()[0];
    const sampleRewards = math.matrix(sampledIndices.map((x) => this.observedValues[this.domain[x]]));
    const samplePriorMean = this.mean.subset(math.index(this.allSamples));
    let sampleKernel = this.kernel.subset(math.index(this.allSamples, this.allSamples));
    const allToSampleKernel = this.kernel.subset(math.index(math.range(0, domainSize), this.allSamples));

    // Sample kernel is sometimes a scalar.
    if (typeof sampleKernel === 'number') {
      sampleKernel = math.matrix([[sampleKernel]]);
    }

    // Defend against singular matrix inversion.
    sampleKernel = math.add(sampleKernel, math.multiply(math.eye(sampleSize), 0.001));
    const sampleKernelInv = math.inv(sampleKernel);

    // posteriorMean = K_all,s * inv(K_s,s) * (rewards - priorMean_s) + priorMean
    const sampleRewardGain = math.reshape(math.subtract(sampleRewards, samplePriorMean), [sampleSize, 1]);
    const sampleKernelDotGain = math.multiply(sampleKernelInv, sampleRewardGain);
    const posteriorMean = math.add(math.multiply(allToSampleKernel, sampleKernelDotGain), math.reshape(this.mean, [domainSize, 1]));

    // posteriorKernel = K - K_all,s * inv(K_s,s) * K_all,s^T
    let posteriorKernel = math.multiply(allToSampleKernel, math.multiply(sampleKernelInv, math.transpose(allToSampleKernel)));
    posteriorKernel = math.subtract(this.kernel, posteriorKernel);
    this.posteriorMean = posteriorMean;
    const posteriorStd = math.sqrt(math.diag(posteriorKernel).reshape([domainSize, 1]));

    // Compute the expected improvement.
    expectedImprov = math.zeros(domainSize);
    for (const model in this.modelsDomains) {
      const modelPoints = difference(this.modelsDomains[model], this.modelsSamples[model].toArray());
      if (modelPoints.length === 0) {
        // Every point of this model is already sampled: nothing to add, and an
        // empty index must not reach subset().
        continue;
      }
      const modelPosteriorMean = posteriorMean.subset(math.index(modelPoints, 0));
      const modelPosteriorStd = posteriorStd.subset(math.index(modelPoints, 0));
      let modelExpectedImprov = util.expectedImprovement(modelsBestRewards[model], modelPosteriorMean, modelPosteriorStd);
      modelExpectedImprov = modelExpectedImprov.reshape([modelPoints.length]);
      expectedImprov = expectedImprov.subset(math.index(modelPoints), math.add(expectedImprov.subset(math.index(modelPoints)), modelExpectedImprov));
    }

    // Rescale EI with delays.
    expectedImprov = math.dotDivide(expectedImprov, this.delays);

    // Ensure the expected improvement is zero for all observed points.
    for (const sampleIndex of sampledIndices) {
      expectedImprov.set([sampleIndex], 0);
    }
  }

  // Determine the candidate domain indices based on the specified strategy.
  const allModels = Object.keys(this.modelsDomains);
  let candidates;
  if (this.strategy === 'ei') {
    // Exclude some models from the domain if specified.
    const activeModels = difference(allModels, excludeModels);
    candidates = [].concat(...activeModels.map((model) => this.modelsDomains[model]));
  } else if (this.strategy === 'rr') {
    candidates = this.modelsDomains[allModels[sampledIndices.length % allModels.length]];
  } else if (this.strategy === 'rnd') {
    candidates = this.modelsDomains[math.pickRandom(allModels)];
  } else {
    throw new Error('Accepted values for strategy are: ei, rr and rnd.');
  }

  // Sample the point with maximal expected improvement over the given domain.
  const improvByIndex = expectedImprov.toArray();
  const candidateImprov = candidates.map((index) => improvByIndex[index]);
  let idx = util.argmax(candidateImprov);
  if (candidateImprov[idx] === 0) {
    // Nothing left to gain among the candidates (e.g. they were all sampled):
    // just run the point with the shortest delay.
    const delaysByIndex = Array.isArray(this.delays) ? this.delays : this.delays.toArray();
    const candidateDelays = candidates.map((index) => delaysByIndex[index]);
    idx = candidateDelays.indexOf(Math.min(...candidateDelays));
  }
  return this.domain[candidates[idx]];
};
// Expose the Optimizer constructor as this module's export.
module.exports = Optimizer;