-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_solver.py
74 lines (60 loc) · 1.98 KB
/
train_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from __future__ import division
from __future__ import print_function
import datetime
import json
import logging
import os
import pickle
import time
import numpy as np
import optimizers
# import torch
from config import parser
from models.base_models import LPModel
from utils.data_utils import load_data
from utils.train_utils import get_dir_name, format_metrics, assign_gpus
from solver import Solver
def train(args):
    """Run one training/evaluation session from parsed CLI arguments.

    Seeds NumPy, configures logging (and an optional save directory under
    $LOG_DIR), loads the dataset from $DATAPATH, then hands everything to
    a Solver for fitting and evaluation.
    """
    np.random.seed(args.seed)
    logging.getLogger().setLevel(logging.INFO)

    if args.save:
        if not args.save_dir:
            # No explicit directory given: derive one as $LOG_DIR/<task>/<date>.
            now = datetime.datetime.now()
            stamp = f"{now.year}_{now.month}_{now.day}"
            task = 'nc' if args.node_cluster == 1 else 'lp'
            save_dir = get_dir_name(os.path.join(os.environ['LOG_DIR'], task, stamp))
        else:
            save_dir = args.save_dir
        # Mirror log output to both a file in the save dir and the console.
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])
        logging.info(f"Logging model in {save_dir}")
        args.save_dir = save_dir

    if args.node_cluster == 1:
        ### NOTE : node clustering use full edge
        args.val_prop = 0.0
        args.test_prop = 0.0

    import pprint
    logging.info(pprint.pformat(vars(args)))

    # Load data
    logging.info(f"Loading Data : {args.dataset}")
    started = time.time()
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    # Capture the post-load RNG state so the run can be reproduced later.
    args.np_seed = np.random.get_state()
    elapsed = time.time() - started
    logging.info(data['info'])
    logging.info(f'Loading data took time: {elapsed:.4f}s')

    sol = Solver(args, data)
    sol.fit()
    sol.eval()
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and launch a training run.
    train(parser.parse_args())