-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathlattice.py
91 lines (78 loc) · 2.73 KB
/
lattice.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""`lattice.py` defines:
* Lattice object containing node and edge information
* An in-place reverse function (`Lattice.reverse`).
* Target object containing the target, one-best path indices and reference.
"""
import sys
import numpy as np
class Lattice:
    """Lattice object holding the node and edge information of one lattice."""

    def __init__(self, path, mean=None, std=None):
        """Lattice object.

        Arguments:
            path {string} -- absolute path of a pre-processed lattice
                (read with ``numpy.load``; the string keys accessed in
                ``load`` imply an ``.npz`` archive)

        Keyword Arguments:
            mean {numpy array} -- mean vector of the dataset; its last
                dimension must equal the edge feature dimension, e.g. shape
                ``(feature_dim,)`` or ``(1, feature_dim)`` (default: {None})
            std {numpy array} -- standard deviation of the dataset, same
                shape rule as ``mean`` (default: {None})
        """
        self.path = path
        self.mean = mean
        self.std = std
        self.child_dict = None
        self.parent_dict = None
        self.ignore = []
        self.load()

    def load(self):
        """Load the pre-processed lattice.

        Populates ``nodes`` (list, topological order), ``edges`` (edge feature
        array), ``child_dict``, ``parent_dict``, ``ignore`` (only when the
        archive contains it), ``node_num``, ``edge_num`` and ``feature_dim``
        (``None`` when the lattice has no edges).

        Normalise to zero mean and unit variance if mean and std are provided,
        i.e. ``edges = (edges - mean) / std``; this is skipped for lattices
        without edges. If the last dimension of ``mean`` or ``std`` differs
        from the edge feature dimension, a message is printed and the process
        exits with status 1.
        """
        # allow_pickle is required since NumPy 1.16.3 to read object arrays
        # such as the ones wrapping the child/parent dictionaries. Unpickling
        # can execute arbitrary code: only load lattices produced by a trusted
        # pre-processing step. The context manager closes the archive's file
        # handle; every array is fully read into memory before it exits.
        with np.load(self.path, allow_pickle=True) as data:
            self.nodes = list(data['topo_order'])
            self.edges = data['edge_data']
            # .item() unwraps the single object stored in each array
            # (presumably the dict that was saved).
            self.child_dict = data['child_2_parent'].item()
            self.parent_dict = data['parent_2_child'].item()
            # Backward compatibility: older archives have no 'ignore' entry.
            if 'ignore' in data.files:
                self.ignore = list(data['ignore'])

        self.node_num = len(self.nodes)
        self.edge_num = self.edges.shape[0]
        if self.edge_num == 0:
            self.feature_dim = None
            return

        self.feature_dim = self.edges.shape[1]
        # Compare the LAST dimension so both 1-D (D,) and 2-D (1, D)
        # statistics are accepted; broadcasting handles either shape.
        if self.mean is not None:
            if self.mean.shape[-1] != self.feature_dim:
                print("Dimension of mean vector is inconsistent with data.")
                sys.exit(1)
            self.edges = self.edges - self.mean
        if self.std is not None:
            if self.std.shape[-1] != self.feature_dim:
                print("Dimension of std vector is inconsistent with data.")
                sys.exit(1)
            self.edges = self.edges / self.std

    def reverse(self):
        """Reverse the graph in place.

        Reverses the node order and swaps ``child_dict`` with
        ``parent_dict``. ``edges`` is left untouched.
        """
        self.nodes.reverse()
        self.child_dict, self.parent_dict = self.parent_dict, self.child_dict
class Target:
    """Target object holding the target, one-best path indices and reference."""

    def __init__(self, path):
        """Target constructor.

        Arguments:
            path {str} -- absolute path to target file (read with
                ``numpy.load``; the string keys accessed in ``load`` imply
                an ``.npz`` archive).
        """
        self.path = path
        self.target = None
        self.indices = None
        self.ref = None
        self.load()

    def load(self):
        """Load target, one-best path indices and reference.

        ``target`` is kept as the stored array; ``indices`` and ``ref`` are
        converted to Python lists.
        """
        # The context manager closes the archive's file handle; each array
        # is fully read into memory before the block exits.
        with np.load(self.path) as data:
            self.target = data['target']
            self.indices = list(data['indices'])
            self.ref = list(data['ref'])