import dill
import json
import multiprocessing
import os
import random
import shutil
import sys
import time

# Let dill serialize functions together with the objects they reference, so
# model factories defined in the driving script survive pickling for workers.
dill.settings['recurse'] = True


class HookedStdout:
    """Tee-style stream wrapper: everything written goes both to the
    wrapped stream (the console by default) and to a log file."""

    original_stdout = sys.stdout  # captured once, before any hooking

    def __init__(self, filename, stdout=None) -> None:
        if stdout is None:
            self.stdout = self.original_stdout
        else:
            self.stdout = stdout
        self.file = open(filename, 'w')

    def write(self, data):
        self.stdout.write(data)
        self.file.write(data)

    def flush(self):
        self.stdout.flush()
        self.file.flush()
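
# Usage sketch (illustrative only, not invoked anywhere in this module):
# hooking sys.stdout tees every print to both the console and a file.
#
#     sys.stdout = HookedStdout("log.txt")
#     print("hello")  # appears on screen and in log.txt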


def train_process(data, save_path, device, seed):
    """Worker entry point for one training run in a spawned subprocess."""
    # Mirror this worker's stdout/stderr into per-run log files.
    hooked = HookedStdout(f"{save_path}/log.txt")
    sys.stdout = hooked
    sys.stderr = HookedStdout(f"{save_path}/logerr.txt", sys.stderr)
    # Import torch/deepxde inside the worker so each spawned process
    # initializes CUDA on its own assigned device.
    import torch
    import deepxde as dde
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    dde.config.set_default_float('float32')
    dde.config.set_random_seed(seed)
    # Rebuild the model from the dill-serialized factory and train it.
    get_model, train_args = dill.loads(data)
    model = get_model()
    model.train(**train_args, model_save_path=save_path)


class Trainer:
    """Queues training tasks and runs them serially or across multiple GPUs."""

    def __init__(self, exp_name, device) -> None:
        self.exp_name = exp_name
        self.device = device.split(",")  # e.g. "0,1,2" -> ["0", "1", "2"]
        self.repeat = 1
        self.tasks = []

    def set_repeat(self, repeat):
        self.repeat = repeat

    def add_task(self, get_model, train_args):
        # Serialize the factory now so it can be shipped to spawned workers.
        data = dill.dumps((get_model, train_args))
        self.tasks.append((data, train_args))

    def setup(self, filename, seed):
        os.makedirs(f"runs/{self.exp_name}", exist_ok=True)
        # Snapshot the driving script alongside the run outputs.
        shutil.copy(filename, f"runs/{self.exp_name}/script.py.bak")
        with open(f"runs/{self.exp_name}/config.json", 'w') as f:
            # `default` replaces unserializable objects (the dill payloads)
            # with a placeholder so the config still dumps cleanly.
            json.dump({"seed": seed, "task": self.tasks}, f,
                      indent=4, default=lambda _: "...")

    def train_all(self):
        if len(self.device) > 1:
            return self.train_all_parallel()
        # No multiprocessing when only one device is available.
        import torch
        import deepxde as dde
        if self.device[0] != 'cpu':
            device = "cuda:" + self.device[0]
            torch.cuda.set_device(device)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
        else:
            torch.set_default_tensor_type(torch.FloatTensor)
        dde.config.set_default_float('float32')
        for j in range(self.repeat):
            for i, (data, _) in enumerate(self.tasks):
                seed = random.randint(0, 10**9)
                save_path = f"runs/{self.exp_name}/{i}-{j}"
                os.makedirs(save_path, exist_ok=True)
                # Tee this run's output into its own log files.
                hooked = HookedStdout(f"{save_path}/log.txt")
                sys.stdout = hooked
                sys.stderr = HookedStdout(f"{save_path}/logerr.txt", sys.stderr)
                dde.config.set_random_seed(seed)
                print(f"***** Begin #{i}-{j} *****")
                get_model, train_args = dill.loads(data)
                model = get_model()
                model.train(**train_args, model_save_path=save_path)
                print(f"***** End #{i}-{j} *****")

    def train_all_parallel(self):
        # Maintain a pool of worker processes, one per device: do not start
        # everything at once; whenever a worker finishes, launch the next
        # task on the freed device.
        multiprocessing.set_start_method('spawn')  # required for CUDA in subprocesses
        processes = [None] * len(self.device)
        for j in range(self.repeat):
            for i, (data, _) in enumerate(self.tasks):
                # Find a free device slot.
                for k, p in enumerate(processes):
                    if p is None:
                        device = "cuda:" + self.device[k]
                        seed = random.randint(0, 10**9)
                        save_path = f"runs/{self.exp_name}/{i}-{j}"
                        os.makedirs(save_path)
                        print(f"***** Start #{i}-{j} *****")
                        p = multiprocessing.Process(target=train_process,
                                                    args=(data, save_path, device, seed),
                                                    daemon=True)
                        p.start()
                        processes[k] = p
                        break
                else:
                    raise RuntimeError("No free device")
                # Wait until at least one slot is free again.
                while True:
                    for k, p in enumerate(processes):
                        if p is None or not p.is_alive():
                            if p is not None:
                                p.join()  # reap the finished worker
                            processes[k] = None  # free the device slot
                            break
                    else:
                        time.sleep(5)
                        continue
                    break
        # Wait for the remaining workers to finish.
        for p in processes:
            if p is not None:
                p.join()

    def summary(self):
        from src.utils import summary
        summary.summary(f"runs/{self.exp_name}", len(self.tasks), self.repeat,
                        list(map(lambda t: t[1]['iterations'], self.tasks)))
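

# --- Usage sketch (illustrative, not part of the original module) ----------
# `build_model` below is a hypothetical placeholder for a real deepxde model
# factory; the "iterations" key matches what summary() expects to find in
# each task's train_args.
if __name__ == "__main__":
    def build_model():
        # In practice: construct, compile, and return a deepxde Model here.
        raise NotImplementedError("supply a real deepxde model factory")

    trainer = Trainer("demo", device="0,1")  # comma-separated GPU ids, or "cpu"
    trainer.set_repeat(2)                    # run each task twice, fresh seed each time
    trainer.add_task(build_model, {"iterations": 10000})
    trainer.setup(__file__, seed=0)          # snapshot script + config under runs/demo/
    trainer.train_all()                      # parallel across devices when more than one
    trainer.summary()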