-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_train.py
45 lines (33 loc) · 1.64 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# @Time : 2020/7/20
# @Author : Shanlei Mu
# @Email : slmu@ruc.edu.cn
# UPDATE
# @Time : 2020/10/3, 2020/10/1
# @Author : Yupeng Hou, Zihan Lin
# @Email : houyupeng@ruc.edu.cn, zhlin@ruc.edu.cn
import argparse
from recbole.quick_start import run_recbole
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', type=str, default='SASRec', help='name of models')
parser.add_argument('--dataset', '-d', type=str, default='ml-100k', help='name of datasets')
parser.add_argument('--config_files', type=str, default='seq.yaml', help='config files')
parser.add_argument('--method', type=str, default='DuoRec_XAUG', \
help='None, CL4SRec, CL4SRec_XAUG, DuoRec, DuoRec_XAUG, ...')
parser.add_argument('--cl_loss_weight', type=float, default=0.1, help='weight for contrastive loss')
parser.add_argument('--temp_ratio', type=float, default=1.0, help='temperature ratio')
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
### ours
parser.add_argument('--xai_method', type=str, default='occlusion', help='saliency, occlusion')
args, _ = parser.parse_known_args()
config_dict = {
'neg_sampling': None,
'method': args.method,
'cl_loss_weight': args.cl_loss_weight,
'temp_ratio': args.temp_ratio,
'gpu_id': args.gpu_id,
'xai_method': args.xai_method,
}
config_file_list = args.config_files.strip().split(' ') if args.config_files else None
run_recbole(model=args.model, dataset=args.dataset, method=args.method,
config_file_list=config_file_list, config_dict=config_dict)