forked from JasonLee-22/TR-Rules
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrapher.py
143 lines (111 loc) · 4.68 KB
/
grapher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import json
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm
import os
class Grapher:
    """
    Load a temporal knowledge graph dataset and store it as index arrays.

    The dataset directory must contain "entity2id.json", "relation2id.json",
    "ts2id.json" and the quadruple files "train.txt", "valid.txt", "test.txt".
    Every quadruple (s, r, o, t) is complemented by its inverse (o, r_inv, s, t).
    """

    def __init__(self, dataset_dir):
        """
        Store information about the graph (train/valid/test set).
        Add corresponding inverse quadruples to the data.

        Parameters:
            dataset_dir (str): path to the graph dataset directory
                (with or without a trailing path separator)

        Returns:
            None
        """

        self.dataset_dir = dataset_dir
        self.entity2id = self._load_json("entity2id.json")
        self.relation2id_old = self._load_json("relation2id.json")
        self.ts2id = self._load_json("ts2id.json")

        num_relations = len(self.relation2id_old)

        # The inverse of relation id i is i + num_relations. Derive the id of
        # the inverse *name* from the relation's own id (not from the dict
        # iteration order) so that relation2id and inv_relation_id always agree.
        self.relation2id = self.relation2id_old.copy()
        for relation, rel_id in self.relation2id_old.items():
            self.relation2id["_" + relation] = rel_id + num_relations  # Inverse relation

        self.id2entity = {v: k for k, v in self.entity2id.items()}
        self.id2relation = {v: k for k, v in self.relation2id.items()}
        self.id2ts = {v: k for k, v in self.ts2id.items()}

        # Maps a relation id to its inverse id and an inverse id back to the
        # original relation id.
        self.inv_relation_id = dict()
        for i in range(num_relations):
            self.inv_relation_id[i] = i + num_relations
        for i in range(num_relations, num_relations * 2):
            self.inv_relation_id[i] = i % num_relations

        self.train_idx = self.create_store("train.txt")
        self.valid_idx = self.create_store("valid.txt")
        self.test_idx = self.create_store("test.txt")
        self.all_idx = np.vstack((self.train_idx, self.valid_idx, self.test_idx))

        self.num_e = len(self.entity2id)
        self.num_r = len(self.inv_relation_id)  # counts inverse relations too
        self.train_times = list(set(self.train_idx[:, 3]))

        print("Grapher initialized.")

    def _load_json(self, file):
        """
        Load a JSON file from the dataset directory.

        The file is read as UTF-8 (the same encoding used for the quadruple
        files) and the handle is closed as soon as parsing is done.

        Parameters:
            file (str): file name

        Returns:
            content: the decoded JSON content
        """

        with open(os.path.join(self.dataset_dir, file), "r", encoding="utf-8") as f:
            return json.load(f)

    def create_store(self, file):
        """
        Store the quadruples from the file as indices.
        The quadruples in the file should be in the format "subject\trelation\tobject\ttimestamp\n".

        Parameters:
            file (str): file name

        Returns:
            store_idx (np.ndarray): indices of quadruples
        """

        with open(os.path.join(self.dataset_dir, file), "r", encoding="utf-8") as f:
            quads = f.readlines()
        store = self.split_quads(quads)
        store_idx = self.map_to_idx(store)
        store_idx = self.add_inverses(store_idx)

        return store_idx

    def split_quads(self, quads):
        """
        Split quadruples into a list of strings.
        Blank lines are skipped.

        Parameters:
            quads (list): list of quadruples
                          Each quadruple has the form "subject\trelation\tobject\ttimestamp\n".

        Returns:
            split_q (list): list of quadruples
                            Each quadruple has the form [subject, relation, object, timestamp].
        """

        split_q = []
        for quad in quads:
            # Strip only the line terminator: the last line of a file may have
            # none, and slicing off one character would corrupt its timestamp.
            line = quad.rstrip("\r\n")
            if not line.strip():
                continue
            split_q.append(line.split("\t"))

        return split_q

    def map_to_idx(self, quads):
        """
        Map quadruples to their indices.

        Parameters:
            quads (list): list of quadruples
                          Each quadruple has the form [subject, relation, object, timestamp].

        Returns:
            quads (np.ndarray): indices of quadruples
                                (shape (0, 4) with an integer dtype if there are no quadruples)

        Raises:
            KeyError: if an entity, relation or timestamp is not in the vocabularies
        """

        if len(quads) == 0:
            # column_stack on empty lists would produce a float array, which
            # would turn the stacked train/valid/test indices into floats.
            return np.empty((0, 4), dtype=int)

        subs = [self.entity2id[x[0]] for x in quads]
        rels = [self.relation2id[x[1]] for x in quads]
        objs = [self.entity2id[x[2]] for x in quads]
        tss = [self.ts2id[x[3]] for x in quads]
        quads = np.column_stack((subs, rels, objs, tss))

        return quads

    def add_inverses(self, quads_idx):
        """
        Add the inverses of the quadruples as indices.

        Parameters:
            quads_idx (np.ndarray): indices of quadruples

        Returns:
            quads_idx (np.ndarray): indices of quadruples along with the indices of their inverses
                                    (original quadruples first, then their inverses in the same order)
        """

        if quads_idx.shape[0] == 0:
            # Nothing to invert; keep the shape and dtype of the input.
            return quads_idx.copy()

        subs = quads_idx[:, 2]
        rels = [self.inv_relation_id[x] for x in quads_idx[:, 1]]
        objs = quads_idx[:, 0]
        tss = quads_idx[:, 3]
        inv_quads_idx = np.column_stack((subs, rels, objs, tss))
        quads_idx = np.vstack((quads_idx, inv_quads_idx))

        return quads_idx

    def get_total_number(self, inPath, fileName):
        """
        Read the first two integers from the first line of a file.

        Parameters:
            inPath (str): directory containing the file
            fileName (str): file name

        Returns:
            (int, int): the first two whitespace-separated integers of the first line,
                        or None if the file is empty
        """

        with open(os.path.join(inPath, fileName), 'r') as fr:
            first_line = next(fr, None)
        if first_line is None:
            return None
        line_split = first_line.split()

        return int(line_split[0]), int(line_split[1])

    def get_data_with_t(self, data, tim):
        """
        Select the triples of all quadruples that have the given timestamp.

        Parameters:
            data (iterable): quadruples of the form [subject, relation, object, timestamp]
            tim: timestamp to select

        Returns:
            triples (np.ndarray): [subject, relation, object] of the matching quadruples
        """

        triples = [[quad[0], quad[1], quad[2]] for quad in data if quad[3] == tim]

        return np.array(triples)

    def mkdirs(self, path):
        """
        Create the directory (and missing parents) if the path does not exist yet.

        Parameters:
            path (str): directory path

        Returns:
            None
        """

        if not os.path.exists(path):
            # exist_ok guards against another process creating the directory
            # between the check and the call.
            os.makedirs(path, exist_ok=True)