forked from biomed-AI/CMPRY
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
36 lines (30 loc) · 1.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import torch
import numpy as np
from src import yp_data_preprocess, gnn_train
def print_setting(args):
    """Print every setting stored on ``args`` between two separator rules.

    Each attribute found in ``args.__dict__`` is written to stdout as one
    ``name: value`` line, in the attribute insertion order.

    Args:
        args: Object exposing its settings through ``__dict__``, such as
            the ``argparse.Namespace`` returned by ``parse_args()``.
    """
    print('\n===========================')
    for name, value in vars(args).items():
        # ``!s`` keeps the str() conversion that ``%s`` formatting performs.
        print(f'{name!s}: {value!s}')
    print('===========================\n')
def _str2bool(value):
    """Convert a command-line token into a ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``.
    This converter interprets the usual spellings explicitly instead.

    Args:
        value: The raw argument (a string from the command line, or an
            actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def main():
    """Parse the command-line options and run training or external testing.

    The parsed settings are printed, the NumPy and PyTorch random number
    generators are seeded with ``--seed``, and the data is loaded through
    ``yp_data_preprocess.load_data(args, 'test')``. With ``--test 0`` the
    data is handed to ``gnn_train.train``; with any other value it is
    handed to ``gnn_train.test_external``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=5, help='the index of gpu device')
    parser.add_argument('--test', type=int, default=1, help='0: 5-fold cross validation; 1: test on external')
    parser.add_argument('--pretrained', type=str, default='./pretrain/saved/cmpnn_1024', help='pretrained model path')
    parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--decay', type=float, default=0.08, help='decay')
    parser.add_argument('--seed', type=int, default=0, help='seed')
    # ``type=bool`` would turn ``--save False`` into True; parse it explicitly.
    parser.add_argument('--save', type=_str2bool, default=False, help='save model')
    args = parser.parse_args()

    print_setting(args)

    np.random.seed(args.seed)
    # Seed torch as well so ``--seed`` covers both RNGs.
    # NOTE(review): assumes the downstream code draws from torch's global
    # RNG; this is harmless if it reseeds on its own -- confirm in src.
    torch.manual_seed(args.seed)

    data = yp_data_preprocess.load_data(args, 'test')

    if args.test == 0:
        gnn_train.train(args, data)
    else:
        gnn_train.test_external(args, data)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()