-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_catalyst.py
42 lines (33 loc) · 1.1 KB
/
train_catalyst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import catalyst
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import AccuracyCallback
from share_funcs import get_model, get_loaders, get_criterion, get_optimizer
def main(epochs=5, num_class=10, output_path='./output/catalyst', seed=None):
    """Train the shared model with Catalyst's ``SupervisedRunner``.

    The model, data loaders, loss and optimizer/LR scheduler are all built
    by the helpers imported from ``share_funcs``; this function only wires
    them into a Catalyst training run that tracks top-1 accuracy.

    Args:
        epochs: Number of training epochs (passed as ``num_epochs``).
        num_class: Number of target classes, forwarded to
            ``AccuracyCallback(num_classes=...)``.
        output_path: Directory passed to the runner as ``logdir``.
        seed: Optional integer seed. When not ``None``, the global RNG seed
            is fixed via ``catalyst.utils.set_global_seed`` and cuDNN is
            prepared in deterministic mode before the model and loaders
            are created. The default ``None`` leaves runs unseeded, which
            matches the script's previous behavior.
    """
    if seed is not None:
        # Seed before building the model/loaders so weight init and data
        # shuffling are covered by the fixed seed.
        catalyst.utils.set_global_seed(seed)
        catalyst.utils.prepare_cudnn(deterministic=True)

    model = get_model()
    train_loader, val_loader = get_loaders()
    # Catalyst looks loaders up by name; "valid" is the split the runner
    # evaluates ``main_metric`` on.
    loaders = {"train": train_loader, "valid": val_loader}
    optimizer, lr_scheduler = get_optimizer(model=model)
    criterion = get_criterion()

    runner = SupervisedRunner(device=catalyst.utils.get_device())
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=lr_scheduler,
        loaders=loaders,
        logdir=output_path,
        # accuracy_args=[1] requests top-1 accuracy, which the callback
        # reports as "accuracy01" -- the key used for main_metric below.
        callbacks=[AccuracyCallback(num_classes=num_class, accuracy_args=[1])],
        num_epochs=epochs,
        main_metric="accuracy01",
        # Higher accuracy is better, so the best checkpoint maximizes it.
        minimize_metric=False,
        fp16=None,  # no mixed-precision configuration is passed
        verbose=True,
    )
if __name__ == '__main__':
    # Launch training only when this file is executed as a script; importing
    # the module (e.g. to reuse ``main``) has no side effects.
    main()