-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
34 lines (25 loc) · 925 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import numpy as np
from utils import prepare_env
class Model:
    """Thin wrapper that delegates training and inference to ``MetaModel``.

    ``train`` sets ``done_training`` to ``True`` either when the remaining time
    budget drops below ``MIN_TIME_BUDGET`` or when the underlying model reports
    it has finished. Presumably an external harness polls ``done_training`` to
    decide whether to keep calling ``train`` (TODO confirm against the caller).
    """

    # Seconds of remaining budget below which the current ``train`` call is
    # treated as the final round. Overridable per subclass/instance.
    MIN_TIME_BUDGET = 60

    def __init__(self, metadata):
        """Prepare the environment and build the underlying ``MetaModel``.

        Args:
            metadata: Dataset description mapping. ``'class_num'`` is read here;
                ``test`` additionally reads ``'test_num'``.
        """
        root_path, models_path = prepare_env()
        self.metadata = metadata
        self.done_training = False
        # Imported lazily, after prepare_env() has run -- presumably because
        # prepare_env sets up the paths/environment that ``sol`` needs to be
        # importable (TODO confirm).
        from sol.meta_model import MetaModel
        self.model = MetaModel(metadata['class_num'], models_path, root_path)
        print(metadata)

    def train(self, train_dataset, remaining_time_budget=None):
        """Run one training round on ``train_dataset``.

        Args:
            train_dataset: A pair ``(X, y)`` forwarded to
                ``MetaModel.train_for_budget``.
            remaining_time_budget: Seconds left, or ``None`` if unknown. ``None``
                is treated as "no limit" rather than compared numerically
                (comparing ``None < 60`` raises ``TypeError`` in Python 3).

        Side effects:
            Sets ``self.done_training = True`` if the budget is below
            ``MIN_TIME_BUDGET`` (before training, so this round is still run)
            or if ``train_for_budget`` returns a truthy value.
        """
        X, y = train_dataset
        budget_is_low = (
            remaining_time_budget is not None
            and remaining_time_budget < self.MIN_TIME_BUDGET
        )
        if budget_is_low:
            # Flag this as the last round; one more training pass still runs below.
            self.done_training = True
        finished = self.model.train_for_budget(X, y)
        if finished:
            self.done_training = True

    def test(self, X, remaining_time_budget=None):
        """Predict on ``X`` and return a ``(test_num, class_num)`` score array.

        Rows/columns not covered by the model's prediction are left as zeros.
        ``remaining_time_budget`` is accepted for interface compatibility and
        is not used.

        Returns:
            ``np.ndarray`` of shape ``(metadata['test_num'], metadata['class_num'])``.

        Raises:
            ValueError: if the prediction is larger than that shape (NumPy
                cannot broadcast it into the clipped slice).
        """
        results = np.zeros(
            (self.metadata['test_num'], self.metadata['class_num'])
        )
        # asarray lets predict() return nested lists as well as ndarrays;
        # a 2-D result is assumed (TODO confirm MetaModel.predict's contract).
        y_pred = np.asarray(self.model.predict(X))
        results[:y_pred.shape[0], :y_pred.shape[1]] = y_pred
        return results