-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
29 lines (25 loc) · 910 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from utils import *
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from model.Model import *
from Dataset import *
import json
from collections import namedtuple
def train(args):
    """Fit an ``NMTTranslator`` on a ``LangDataModule`` with early stopping.

    Args:
        args: Attribute-access configuration object. This function reads
            ``patience``, ``max_epochs`` and ``val_check_interval`` from it
            and passes it unchanged to ``NMTTranslator`` and
            ``LangDataModule``.

    Returns:
        None. Training runs for its side effects via ``Trainer.fit``.
    """
    model = NMTTranslator(args)
    datamodule = LangDataModule(args)

    # Stop when val_loss has not decreased for `patience` validation checks.
    # strict=False means a missing val_loss metric does not abort training.
    stopper = EarlyStopping(
        monitor='val_loss',
        patience=args.patience,
        mode='min',
        strict=False,
        verbose=True,
    )

    trainer_kwargs = dict(
        max_epochs=args.max_epochs,
        val_check_interval=args.val_check_interval,
        callbacks=[stopper],
    )
    if torch.cuda.is_available():
        # -1 requests every visible GPU.
        # NOTE(review): `gpus` was deprecated in Lightning 1.7 and removed in
        # 2.0; on newer versions this must become accelerator='gpu',
        # devices=-1. Confirm the pinned pytorch_lightning version.
        trainer_kwargs.update(gpus=-1)

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, datamodule)
def load_config(path='config.json'):
    """Load a JSON config file into an immutable attribute-access object.

    Args:
        path: Location of the JSON file. Defaults to ``config.json``,
            resolved against the current working directory.

    Returns:
        A ``namedtuple`` instance (type name ``args``) whose fields are the
        top-level keys of the JSON object, in file order, so callers can
        write ``cfg.max_epochs`` instead of ``cfg['max_epochs']``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If a key is not a valid namedtuple field name
            (e.g. a Python keyword or a non-identifier string).
    """
    # The context manager closes the handle deterministically instead of
    # leaving it to the garbage collector. JSON text is UTF-8 by spec, so
    # do not depend on the platform's locale encoding.
    with open(path, encoding='utf-8') as config_file:
        config = json.load(config_file)
    return namedtuple("args", config.keys())(*config.values())


if __name__ == '__main__':
    train(load_config())