-
Notifications
You must be signed in to change notification settings - Fork 0
/
optuna_optim.py
74 lines (58 loc) · 1.97 KB
/
optuna_optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import numpy as np
from argparse import Namespace
import optuna
from optuna.integration import TensorBoardCallback
from main import run_model
# Dataset used by every trial. An earlier configuration (left here for
# reference) used 'data/SwitchableStarPUF/12bit_enumerate_0.csv' with
# _N_SAMPLES = 4096.
_DATA_PATH = 'data/SwitchableStarPUF/64bit_rand100k.csv'
# Number of row ids split into train/val/test. Presumably the row count of
# _DATA_PATH ("rand100k") -- TODO confirm against the file.
_N_SAMPLES = 100000


def _split_ids(length, seed=0):
    """Shuffle ``range(length)`` and cut it into an 80/10/10 id split.

    Seeds NumPy's *global* legacy RNG with ``seed`` (exactly as the script
    always did), so any later ``np.random`` use in this process is affected.

    Returns:
        ``(train_ids, val_ids, test_ids)``: the first ``int(length * 0.8)``
        shuffled ids, the ids up to ``int(length * 0.9)``, and the rest.
    """
    np.random.seed(seed)
    ids = list(range(length))
    np.random.shuffle(ids)
    train_end = int(length * 0.8)
    val_end = int(length * 0.9)
    return ids[:train_end], ids[train_end:val_end], ids[val_end:]


def objective(trial):
    """Optuna objective: train one model and return its best validation score.

    Samples batch size, hidden size, layer count and learning rate from
    ``trial``, trains with ``run_model`` for 100 epochs on a fixed, seeded
    split of ``_DATA_PATH``, and returns ``max(model.val_accs)`` (the study
    maximises this value).

    Args:
        trial: the ``optuna.trial.Trial`` supplying hyper-parameters.

    Returns:
        The maximum entry of ``model.val_accs`` as a Python float.
    """
    # Keep the suggestion order and parameter names unchanged: they are
    # persisted in optim.db and a resumed study must keep seeing the same
    # search space.
    bs = trial.suggest_categorical("bs", [32, 64, 128, 256])
    ns = trial.suggest_int("hidden_size", 50, 5000)
    n_layer = trial.suggest_int("n_layer", 1, 5)
    lr = trial.suggest_categorical("lr", [1e-1, 1e-2, 1e-3, 1e-4])
    hparams = {
        "bs": bs,
        "lr": lr,
        "ns": ns,
        "n_layer": n_layer,
    }
    args = Namespace(
        store_results=False,
        epochs=100,
        cbits=64,
        hparams=hparams,
    )

    # Same seed for every trial, so all trials are compared on the same split.
    splits = _split_ids(_N_SAMPLES, seed=0)
    model = run_model(args, _DATA_PATH, splits)
    # Assumes model.val_accs is a non-empty sequence of numbers -- confirm in
    # main.run_model; np.max raises ValueError on an empty sequence.
    return float(np.max(model.val_accs))
if __name__ == "__main__":
    # Create the study on the first run, or resume it from optim.db. Using
    # load_if_exists replaces the racy "file exists?" check, which also failed
    # when optim.db existed without an "optim" study in it.
    study = optuna.create_study(
        study_name="optim",
        storage="sqlite:///optim.db",
        direction="maximize",
        load_if_exists=True,
    )
    # Log each trial's objective value to TensorBoard under the name "value".
    tensorboard_callback = TensorBoardCallback("logs/", metric_name="value")
    study.optimize(objective, n_trials=100, callbacks=[tensorboard_callback])

    # study.trials also contains failed/running trials; report only the ones
    # that actually finished.
    completed = study.get_trials(
        deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
    )
    print("Study statistics: ")
    print(" Number of finished trials: ", len(completed))

    if not completed:
        # best_trials needs at least one completed trial.
        print("No completed trials; nothing to report.")
    else:
        trials = study.best_trials
        print(f"Best trials: ({len(trials)})")
        for trial in trials:
            print(trial.number)
            # The objective returns the maximum of model.val_accs, not a loss.
            print(f" Value (max val acc): {trial.value}")
            print(" Params: ")
            for key, value in trial.params.items():
                print(f"    {key}: {value}")