-
Notifications
You must be signed in to change notification settings - Fork 3
/
inference.py
129 lines (101 loc) · 4.92 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import os
import torch
from torch.utils.data import DataLoader, SequentialSampler
from datasets import SASRecDataset
from sasrec_models import S3RecModel
from trainers import FinetuneTrainer
from sasrec_utils import (
check_path,
generate_submission_file,
get_item2attribute_json,
get_user_seqs,
set_seed,
)
def _build_arg_parser():
    """Create the CLI parser for S3Rec submission inference.

    Options fall into three groups: data/output locations, transformer
    model hyperparameters, and training/runtime settings. Defaults are
    kept identical to the original script so the restored checkpoint's
    architecture matches.
    """
    parser = argparse.ArgumentParser()

    # data args
    parser.add_argument("--data_dir", default="../data/train/", type=str)
    parser.add_argument("--output_dir", default="output/", type=str)
    parser.add_argument("--data_name", default="Ml", type=str)
    parser.add_argument("--do_eval", action="store_true")

    # model args
    parser.add_argument(
        "--hidden_size", type=int, default=128, help="hidden size of transformer model"
    )
    parser.add_argument(
        "--num_hidden_layers", type=int, default=2, help="number of layers"
    )
    parser.add_argument("--num_attention_heads", default=2, type=int)
    parser.add_argument("--hidden_act", default="gelu", type=str)  # gelu relu
    parser.add_argument(
        "--attention_probs_dropout_prob",
        type=float,
        default=0.5,
        help="attention dropout p",
    )
    parser.add_argument(
        "--hidden_dropout_prob", type=float, default=0.5, help="hidden dropout p"
    )
    parser.add_argument("--initializer_range", type=float, default=0.01)
    parser.add_argument("--max_seq_length", default=200, type=int)

    # train args
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate of adam")
    parser.add_argument(
        "--batch_size", type=int, default=64, help="number of batch_size"
    )
    parser.add_argument("--epochs", type=int, default=200, help="number of epochs")
    parser.add_argument("--no_cuda", action="store_true")
    parser.add_argument("--log_freq", type=int, default=1, help="per epoch print res")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--mask_p", type=float, default=0.2, help="mask probability")
    parser.add_argument("--aap_weight", type=float, default=0.2, help="aap loss weight")
    parser.add_argument("--mip_weight", type=float, default=2.0, help="mip loss weight")
    parser.add_argument("--map_weight", type=float, default=1.5, help="map loss weight")
    parser.add_argument("--sp_weight", type=float, default=0.5, help="sp loss weight")
    parser.add_argument(
        "--weight_decay", type=float, default=0.0, help="weight_decay of adam"
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.95, help="adam first beta value"
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.999, help="adam second beta value"
    )
    parser.add_argument("--gpu_id", type=str, default="0", help="gpu_id")
    parser.add_argument(
        "--scheduler",
        type=str,
        default="None",
        help="You don't need to handle this option.",
    )
    parser.add_argument("--model_name", default="Finetune_full", type=str)
    parser.add_argument("--tqdm", default=1, type=int, help="option for running tqdm")
    return parser


def main():
    """Generate a submission file from a pretrained S3Rec checkpoint.

    Loads the rating history and item-attribute mapping, restores the
    model weights from ``<output_dir>/pretrain_max_seq_len_<L>.pt``,
    scores the "submission" split sequentially, and writes the
    predictions via ``generate_submission_file``.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed)
    check_path(args.output_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    args.cuda_condition = torch.cuda.is_available() and not args.no_cuda

    # os.path.join tolerates a missing trailing slash in --data_dir,
    # unlike the raw string concatenation this replaces.
    args.data_file = os.path.join(args.data_dir, "train_ratings.csv")
    item2attribute_file = os.path.join(
        args.data_dir, f"{args.data_name}_item2attributes.json"
    )

    user_seq, max_item, _, _, submission_rating_matrix = get_user_seqs(args.data_file)
    item2attribute, attribute_size = get_item2attribute_json(item2attribute_file)

    # Reserve one extra id beyond max_item for the mask token; item_size
    # is one larger still (embedding-table size including padding id 0).
    args.item_size = max_item + 2
    args.mask_id = max_item + 1
    args.attribute_size = attribute_size + 1

    print(args)

    args.item2attribute = item2attribute
    args.train_matrix = submission_rating_matrix

    # Checkpoint name must match the one produced by the pretraining run.
    args_str = f"pretrain_max_seq_len_{args.max_seq_length}"
    args.checkpoint_path = os.path.join(args.output_dir, args_str + ".pt")

    submission_dataset = SASRecDataset(args, user_seq, data_type="submission")
    # SequentialSampler keeps users in file order so predictions line up
    # with the submission rows.
    submission_dataloader = DataLoader(
        submission_dataset,
        sampler=SequentialSampler(submission_dataset),
        batch_size=args.batch_size,
    )

    model = S3RecModel(args=args)
    # Only the submission loader is needed; train/eval/test loaders are None.
    trainer = FinetuneTrainer(model, None, None, None, submission_dataloader, args)
    trainer.load(args.checkpoint_path)
    print(f"Load model from {args.checkpoint_path} for submission!")

    preds = trainer.submission(0)
    generate_submission_file(args, args.data_file, preds)
# Script entry point: parse CLI args, run inference, write the submission file.
if __name__ == "__main__":
    main()