-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
47 lines (36 loc) · 1.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import torch
from omegaconf import OmegaConf
import wandb
import argparse
from train import *
from inference import test
from utils.util import set_seed
import os
import torch
from inference import *
import wandb
from arguments import cfg, training_args
from train import *
# Process-wide environment flags, applied when this entry script is loaded.
# - TOKENIZERS_PARALLELISM="false": disables the HuggingFace `tokenizers`
#   library's internal parallelism (commonly set to silence its
#   fork-after-parallelism warning).
# - WANDB_DISABLED="false": explicitly leaves Weights & Biases logging enabled
#   (presumably consulted by a library integration such as the HF Trainer —
#   TODO confirm).
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB_DISABLED": "false",
    }
)
def main():
    """Run the training and/or inference stages selected by the config.

    Behaviour is driven by the module-level ``cfg`` and ``training_args``
    objects (imported from ``arguments``):

    * ``training_args.train_args.seed`` is passed to ``set_seed``.
    * If ``cfg.train.train_mode`` is truthy, it logs in to Weights & Biases.
      It then opens a run using ``cfg.wandb.project_name``,
      ``cfg.wandb.entity`` and ``cfg.wandb.exp_name``, calls ``train()``,
      and finishes the run.
    * If ``cfg.test.test_mode`` is truthy, ``test()`` is called.

    Raises:
        Any exception raised by ``train()`` or ``test()`` propagates. If
        ``train()`` fails, the W&B run is first finished with a non-zero
        exit code. The run is therefore closed and recorded as failed,
        rather than left open.
    """
    # Release unoccupied memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()

    # Set seed
    set_seed(training_args.train_args.seed)

    # Train
    if cfg.train.train_mode:
        # W&B login and run setup
        wandb.login()
        wandb.init(
            project=cfg.wandb.project_name,
            entity=cfg.wandb.entity,
            name=cfg.wandb.exp_name,
        )
        print('---------------------- train start -------------------------')
        try:
            train()
        except BaseException:
            # Close the run as failed (also on Ctrl-C), so it is neither left
            # dangling nor reported as successful. Then propagate the error.
            wandb.finish(exit_code=1)
            raise
        # W&B finish (success path)
        wandb.finish()

    # Inference
    if cfg.test.test_mode:
        print('--------------------- test start ----------------------')
        test()
    print('----------------- Finish! ---------------------')
# Entry point: run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()