-
Notifications
You must be signed in to change notification settings - Fork 14
/
split.py
78 lines (65 loc) · 2.91 KB
/
split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
import os
import argparse
import datetime
from tqdm import tqdm
from dataset import normalizer
def get_args():
    """Parse command-line options controlling the train/valid/test split.

    Returns an argparse Namespace with a dataset tag and the six
    'YYYY-MM-DD HH:MM:SS' boundary strings (start/end per split).
    """
    # Option table: one --flag per entry, all plain strings.
    defaults = {
        'tag': 'full',
        'train_start': '2006-03-01 00:00:00',
        'train_end': '2006-05-18 00:00:00',
        'valid_start': '2006-05-18 00:00:00',
        'valid_end': '2006-05-25 00:00:00',
        'test_start': '2006-05-25 00:00:00',
        'test_end': '2006-06-01 00:00:00',
    }
    parser = argparse.ArgumentParser()
    for name, value in defaults.items():
        parser.add_argument(f"--{name}", type=str, default=value)
    return parser.parse_args()
def main(args):
    """Split the raw AOL query log into train/valid/test column files.

    Reads the 10 raw TSV files from data/aol/org, normalizes queries,
    filters short/duplicate entries, and writes one file per
    (split, column) pair under data/aol/{args.tag}.

    Raises:
        ValueError: if the three split intervals are not chronologically
            ordered and non-overlapping.
    """
    splits = ['train', 'valid', 'test']
    columns = ['uid', 'query', 'time']
    fmt = '%Y-%m-%d %H:%M:%S'
    print(f"Split original data into data/aol/{args.tag}")
    # Collect the (start, end) timestamp strings for each split from args.
    itv = {s: tuple(vars(args)[f"{s}_{i}"] for i in ['start', 'end']) for s in splits}
    for s in splits:
        print(f" {s:5s} data: from {itv[s][0]} until {itv[s][1]}")
    itv = {k: tuple(datetime.datetime.strptime(x, fmt) for x in v) for k, v in itv.items()}
    # Intervals must be ordered train < valid < test with no overlap.
    # Raise (not assert) so the check survives `python -O`.
    ordered = (itv['train'][0] < itv['train'][1] <= itv['valid'][0] < itv['valid'][1]
               <= itv['test'][0] < itv['test'][1])
    if not ordered:
        raise ValueError("Invalid time intervals")
    # Make the target directory and open one output file per split/column.
    target_dir = f"data/aol/{args.tag}"
    os.makedirs(target_dir, exist_ok=True)
    f = {s: {column: open(os.path.join(target_dir, f"{s}.{column}.txt"), 'w')
             for column in columns} for s in splits}
    cnt = {s: 0 for s in splits}
    try:
        # Read the original AOL query log and route rows to their split.
        print("")
        for i in range(1, 11):
            filename = f"user-ct-test-collection-{i:02d}.txt"
            # BUGFIX: original f-string had no placeholder ("(unknown)").
            print(f"Reading {filename}...")
            with open(os.path.join("data/aol/org", filename)) as f_org:
                f_org.readline()  # skip the TSV header row
                prev = {column: '' for column in columns}
                for line in tqdm(f_org):
                    data = {column: v for column, v in zip(columns, line.strip().split('\t')[:3])}
                    # Normalize queries.
                    data['query'] = normalizer(data['query'])
                    # Filter out too-short queries (this also drops the
                    # placeholder query '-') and consecutive duplicates
                    # from the same user.
                    if len(data['query']) < 3 or (data['uid'], data['query']) == (prev['uid'], prev['query']):
                        continue
                    t = datetime.datetime.strptime(data['time'], fmt)
                    for s in splits:
                        if itv[s][0] <= t < itv[s][1]:
                            cnt[s] += 1
                            for column in columns:
                                f[s][column].write(data[column] + '\n')
                    prev = data
    finally:
        # Close every output handle even if a read/parse error occurs.
        for per_split in f.values():
            for fh in per_split.values():
                fh.close()
    # Print total number of data in each split.
    print("")
    for s in splits:
        print(f"Number of {s:5s} data: {cnt[s]:8d}")
# Entry point: parse CLI arguments and run the split.
if __name__ == "__main__":
    main(get_args())