-
Notifications
You must be signed in to change notification settings - Fork 1.8k
DNGO tuner #3479
DNGO tuner #3479
Changes from 17 commits
a216727
b6e9501
5ec623b
02e3e0b
5ac8213
eada38d
d12c1cb
2a8483e
b906af8
6089a0e
f770cb6
74a5d7b
833f45e
d454288
943e79c
4fc9ad9
4103dca
4ad1966
932fd25
05625b0
2cd3c78
441d6e2
b881f2c
a821c23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .dngo_tuner import DngoTuner, DNGOClassArgsValidator |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
authorName: Ayan Mao | ||
experimentName: arch2vec_dngo_general_mnist | ||
trialConcurrency: 1 | ||
maxExecDuration: 1h | ||
maxTrialNum: 100 | ||
trainingServicePlatform: local # choices: local, remote, pai | ||
searchSpacePath: search_space.json | ||
useAnnotation: false | ||
tuner: | ||
codeDir: nni/examples/tuners/dngo_tuner | ||
classFileName: dngo_tuner.py | ||
className: DngoTuner | ||
# Any parameter need to pass to your tuner class __init__ constructor | ||
# can be specified in this optional classArgs field, for example | ||
trial: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this config should be put into example folder, and should be written in V2 format. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree |
||
codeDir: nni/examples/trials/mnist-pytorch | ||
command: python mnist.py | ||
gpuNum: 1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import random | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed it. |
||
import torch | ||
import numpy as np | ||
|
||
from nni import ClassArgsValidator | ||
from nni.tuner import Tuner | ||
import nni.parameter_expressions as parameter_expressions | ||
from torch.distributions import Normal | ||
from pybnn import DNGO | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, will change it this weekend |
||
|
||
def random_archi_generator(nas_ss, random_state): | ||
'''random | ||
''' | ||
chosen_arch = {} | ||
for key, val in nas_ss.items(): | ||
if val['_type'] == 'choice': | ||
choices = val['_value'] | ||
index = random_state.randint(len(choices)) | ||
# check values type | ||
if type(choices[0]) == int or type(choices[0]) == float: | ||
chosen_arch[key] = choices[index] | ||
else: | ||
chosen_arch[key] = index | ||
elif val['_type'] == 'uniform': | ||
chosen_arch[key] = random.uniform(val['_value'][0], val['_value'][1]) | ||
elif val['_type'] == 'randint': | ||
chosen_arch[key] = random_state.randint( | ||
val['_value'][0], val['_value'][1]) | ||
elif val['_type'] == 'quniform': | ||
chosen_arch[key] = parameter_expressions.quniform( | ||
val['_value'][0], val['_value'][1], val['_value'][2], random_state) | ||
elif val['_type'] == 'loguniform': | ||
chosen_arch[key] = parameter_expressions.loguniform( | ||
val['_value'][0], val['_value'][1], random_state) | ||
elif val['_type'] == 'qloguniform': | ||
chosen_arch[key] = parameter_expressions.qloguniform( | ||
val['_value'][0], val['_value'][1], val['_value'][2], random_state) | ||
|
||
else: | ||
raise ValueError('Unknown key %s and value %s' % (key, val)) | ||
return chosen_arch | ||
|
||
class DngoTuner(Tuner): | ||
|
||
def __init__(self): | ||
|
||
self.searchspace_json = None | ||
self.random_state = None | ||
self.model = DNGO(do_mcmc=False) | ||
self.first_flag = True | ||
self.x = [] | ||
self.y = [] | ||
|
||
|
||
def receive_trial_result(self, parameter_id, parameters, value, **kwargs): | ||
''' | ||
Receive trial's final result. | ||
parameter_id: int | ||
parameters: object created by 'generate_parameters()' | ||
value: final metrics of the trial, including default metric | ||
''' | ||
# update DNGO model | ||
self.y.append(value) | ||
|
||
def generate_parameters(self, parameter_id, **kwargs): | ||
''' | ||
Returns a set of trial (hyper-)parameters, as a serializable object | ||
parameter_id: int | ||
''' | ||
if self.first_flag: | ||
self.first_flag = False | ||
first_x = random_archi_generator(self.searchspace_json, self.random_state) | ||
self.x.append(list(first_x.values())) | ||
return first_x | ||
|
||
self.model.train(np.array(self.x), np.array(self.y), do_optimize=True) | ||
# random samples | ||
candidate_x = [] | ||
for _ in range(1000): | ||
a = random_archi_generator(self.searchspace_json, self.random_state) | ||
candidate_x.append(a) | ||
|
||
x_test = np.array([np.array(list(xi.values())) for xi in candidate_x]) | ||
m, v = self.model.predict(x_test) | ||
mean = torch.Tensor(m) | ||
sigma = torch.Tensor(v) | ||
# u = (mean - torch.Tensor([args.objective]).expand_as(mean)) / sigma | ||
u = (mean - torch.Tensor([0.95]).expand_as(mean)) / sigma | ||
normal = Normal(torch.zeros_like(u), torch.ones_like(u)) | ||
ucdf = normal.cdf(u) | ||
updf = torch.exp(normal.log_prob(u)) | ||
ei = sigma * (updf + u * ucdf) | ||
|
||
indices = torch.argsort(ei) | ||
rev_indices = reversed(indices) | ||
ind = rev_indices[0].item() | ||
new_x = candidate_x[ind] | ||
self.x.append(list(new_x.values())) | ||
|
||
return new_x | ||
|
||
|
||
def update_search_space(self, search_space): | ||
''' | ||
Tuners are advised to support updating search space at run-time. | ||
If a tuner can only set search space once before generating first hyper-parameters, | ||
it should explicitly document this behaviour. | ||
search_space: JSON object created by experiment owner | ||
''' | ||
# your code implements here. | ||
self.searchspace_json = search_space | ||
self.random_state = np.random.RandomState() | ||
|
||
# DNGO tuner do not have much input arg, so the validation is actually harly used | ||
class DNGOClassArgsValidator(ClassArgsValidator): | ||
def validate_class_args(self, **kwargs): | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file should also be put into examples. |
||
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]}, | ||
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]}, | ||
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]}, | ||
"momentum":{"_type":"uniform","_value":[0, 1]} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does DNGO tuner works on string choices for example There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, technically it supports any string choices. But it's quite meaningless. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But I don't think it's handled in your code. Have you tested? |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,3 +76,7 @@ tuners: | |
classArgsValidator: nni.algorithms.hpo.regularized_evolution_tuner.EvolutionClassArgsValidator | ||
className: nni.algorithms.hpo.regularized_evolution_tuner.RegularizedEvolutionTuner | ||
source: nni | ||
- builtinName: DNGOuner | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrong spelling. |
||
classArgsValidator: nni.algorithms.hpo.dngo_tuner.DNGOClassArgsValidator | ||
className: nni.algorithms.hpo.dngo_tuner.DNGOTuner | ||
source: nni |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove your name.