# run.py (generated from ZIZUN/pytorch_lightning_template)

import copy
import os

import pytorch_lightning as pl

from util.config import ex
from util.dataset.lightning_dataset import IntentCLSDataModule
from util.model.IntentClsModule import IntentCLSModule


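# `ex` is expected to be a Sacred Experiment defined in util/config.py;
# `@ex.automain` makes `main` the command-line entry point and passes the
# parsed configuration in as the `_config` dict.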
@ex.automain
def main(_config):
    _config = copy.deepcopy(_config)

    # Print config
    for key, val in _config.items():
        key_str = "{}".format(key) + (" " * (30 - len(key)))
        print(f"{key_str} = {val}")

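    # Seed Python, NumPy, and PyTorch RNGs for reproducible runs.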
    pl.seed_everything(_config["seed"])

    exp_name = f'{_config["exp_name"]}'
    os.makedirs(_config["log_dir"], exist_ok=True)

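    # Keep the five checkpoints with the highest validation accuracy, plus the latest one.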
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        verbose=True,
        monitor="val/accuracy",
        filename='epoch={epoch}-step={step}-val_acc={val/accuracy:.5f}',
        mode="max",
        save_last=True,
        auto_insert_metric_name=False,
    )

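    # TensorBoard run name combines the experiment name, seed, and the loaded
    # checkpoint's filename ([:-5] presumably strips a ".ckpt" extension).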
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
    )

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, lr_callback]

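    # Accumulate gradients so the effective batch size matches _config["batch_size"]
    # given the per-GPU batch size, the number of GPUs, and the number of nodes.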
    accumulate_grad_batches = max(_config["batch_size"] // (
        _config["per_gpu_batch_size"] * len(_config["gpus"]) * _config["num_nodes"]
    ), 1)

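    # Build the data module first so the number of intent labels can be read
    # from its training label list.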
    dm = IntentCLSDataModule(_config=_config)
    model = IntentCLSModule(_config=_config, num_labels=len(dm.train_labels_li))

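    # Distributed data-parallel training on the configured GPUs; these are the
    # pre-2.0 PyTorch Lightning Trainer arguments (`gpus`, `accelerator="ddp"`).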
    trainer = pl.Trainer(
        gpus=_config["gpus"],
        max_steps=_config["max_steps"],
        accelerator="ddp",
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=10,
        accumulate_grad_batches=accumulate_grad_batches,
        val_check_interval=_config["val_check_interval"],
    )
    trainer.fit(model, datamodule=dm)