-
Notifications
You must be signed in to change notification settings - Fork 0
/
hpo.py
74 lines (58 loc) · 2.11 KB
/
hpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import pickle
import mlflow
import numpy as np
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll import scope
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
# Module-level MLflow configuration; both calls execute as soon as this file
# is run or imported.
# Point the MLflow client at the tracking server expected on localhost:5000.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
# Select the experiment under which the runs started in `run` are recorded.
mlflow.set_experiment("random-forest-hyperopt")
def load_pickle(filename):
    """Return the object deserialized from the pickle file at *filename*.

    The file is opened in binary mode and closed again before returning.
    """
    with open(filename, "rb") as handle:
        payload = pickle.load(handle)
    return payload
def run(data_path, num_trials):
    """Tune a RandomForestRegressor with hyperopt, logging each trial to MLflow.

    Args:
        data_path: Directory holding ``train.pkl`` and ``valid.pkl``. Each
            pickle is unpacked as a two-element ``(X, y)`` pair.
        num_trials: Upper bound on objective evaluations, forwarded to
            ``fmin`` as ``max_evals``.
    """
    X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
    X_valid, y_valid = load_pickle(os.path.join(data_path, "valid.pkl"))

    def objective(params):
        """Fit one forest with *params* and return its validation RMSE as the loss."""
        # One MLflow run per trial, so parameters and metric stay paired.
        with mlflow.start_run():
            mlflow.set_tag("model", "random-forest")
            mlflow.log_params(params)

            rf = RandomForestRegressor(**params)
            rf.fit(X_train, y_train)
            y_pred = rf.predict(X_valid)

            # `squared=False` was deprecated in scikit-learn 1.4 and removed in
            # 1.6; taking the square root of the MSE works on every version.
            # NOTE(review): identical to the old result for a single-target y;
            # confirm y is 1-D if multi-output data is ever used.
            rmse = float(np.sqrt(mean_squared_error(y_valid, y_pred)))
            mlflow.log_metric("rmse", rmse)

        return {'loss': rmse, 'status': STATUS_OK}

    # quniform samples floats on a step-1 grid; scope.int casts them to the
    # integers RandomForestRegressor expects.
    search_space = {
        'max_depth': scope.int(hp.quniform('max_depth', 1, 20, 1)),
        'n_estimators': scope.int(hp.quniform('n_estimators', 10, 50, 1)),
        'min_samples_split': scope.int(hp.quniform('min_samples_split', 2, 10, 1)),
        'min_samples_leaf': scope.int(hp.quniform('min_samples_leaf', 1, 4, 1)),
        'random_state': 42
    }

    rstate = np.random.default_rng(42)  # for reproducible results
    fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=num_trials,
        trials=Trials(),
        rstate=rstate
    )
if __name__ == '__main__':
    # Command-line entry point: read the data directory and the evaluation
    # budget, then start the hyperparameter search.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_path",
        default="./output",
        help="the location where the processed NYC taxi trip data was saved.",
    )
    cli.add_argument(
        "--max_evals",
        type=int,
        default=50,
        help="the number of parameter evaluations for the optimizer to explore.",
    )
    options = cli.parse_args()

    # --max_evals is passed as run()'s `num_trials` argument.
    run(options.data_path, options.max_evals)