-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathclassify_lines.py
111 lines (95 loc) · 4.46 KB
/
classify_lines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Methods using ExactLine to understand classification of a network on a line.
"""
import numpy as np
from pysyrenn.frontend.network import Network
from pysyrenn.frontend.argmax_layer import ArgMaxLayer
class LinesClassifier:
    """Handles classifying a set of lines using ExactLine.

    Results are cached: the ExactLine of the network is computed at most once
    (see partial_compute) and the final classification at most once (see
    compute).
    """

    def __init__(self, network, lines, preimages=True):
        """Creates a new LinesClassifier for the given @network and @lines.

        @lines should be a list of (startpoint, endpoint) tuples. If
        preimages=True is set, preimages of the endpoints of each
        classification region will be returned (otherwise, only the ratio
        along the line from startpoint to endpoint will be).
        """
        self.network = network
        self.lines = lines
        self.preimages = preimages
        # Result of network.exactlines(..., include_post=True): one
        # (pre, post) pair per line. Filled by partial_compute() or
        # from_exactlines().
        self.partially_computed = False
        self.transformed_lines = None
        # Cached return value of compute().
        self.computed = False
        self.classifications = None

    def partial_compute(self):
        """Computes the relevant ExactLine and stores it for analysis.

        Calls self.network.exactlines with include_post=True so that every
        entry of self.transformed_lines is a (pre, post) pair, which is what
        compute() iterates over. Does nothing if this was already done (or
        the instance came from from_exactlines).
        """
        if self.partially_computed:
            return
        self.transformed_lines = self.network.exactlines(
            self.lines, compute_preimages=self.preimages, include_post=True)
        self.partially_computed = True

    @classmethod
    def from_exactlines(cls, transformed_lines):
        """Constructs a partially-computed LinesClassifier from ExactLines.

        This is useful, for example, if you need ExactLines for some other
        analysis and then want to determine classification regions. We use it
        to determine the class break-point when generating Figure 4 from [1]
        (experiments/linearity_hypothesis.py).

        @transformed_lines must be a list of (pre, post) pairs, i.e. the
        result of exactlines(..., include_post=True). An empty list is
        accepted; compute() then returns an empty list.

        Raises TypeError if the first entry is not a (pre, post) pair.
        """
        # Only inspect the first entry when there is one; len() rather than
        # truthiness so that an ndarray argument does not raise on bool().
        if len(transformed_lines) > 0 and len(transformed_lines[0]) != 2:
            error = ("ExactLine must be called with include_post=True " +
                     "to use from_exactlines.")
            if len(transformed_lines) == 2:
                # Likely a bare (pre, post) tuple from exactline (singular).
                error += ("\nIf you called exactline (singular), you must " +
                          "pass a singleton list instead.")
            raise TypeError(error)
        # No network/lines are needed: the ExactLine is already available.
        classifier = cls(None, None, None)
        classifier.transformed_lines = transformed_lines
        classifier.partially_computed = True
        return classifier

    def compute(self):
        """Returns the classification regions of network restricted to line.

        Returns a list with one tuple (pre_regions, corresponding_labels) for
        each entry of self.transformed_lines (one per input line).
        pre_regions is a list of (start, end) tuples that partition the line;
        the endpoints are input-space preimages if preimages were computed,
        otherwise ratios along the line. corresponding_labels[i] is the
        argmax label on pre_regions[i]; adjacent regions always have
        different labels.
        """
        if self.computed:
            return self.classifications
        self.partial_compute()
        classify_network = Network([ArgMaxLayer()])
        self.classifications = []
        for pre, post in self.transformed_lines:
            split_pre, split_post = self._split_at_argmax_changes(
                classify_network, pre, post)
            self.classifications.append(
                self._label_and_merge(split_pre, split_post))
        self.computed = True
        return self.classifications

    @staticmethod
    def _split_at_argmax_changes(classify_network, pre, post):
        """Refines one line's linear regions at every argmax change point.

        @pre and @post are the region endpoints (pre- and post-images) of a
        single line. Each consecutive segment of @post is passed through
        @classify_network (an ArgMax-only network) to get the ratios at which
        the argmax changes; those ratios are mapped back onto both @pre and
        @post. Returns (split_pre, split_post).
        """
        segments = list(zip(post[:-1], post[1:]))
        segment_ratios = classify_network.exactlines(
            segments, compute_preimages=False, include_post=False)
        split_pre = []
        split_post = []
        for i, ratios in enumerate(segment_ratios):
            pre_delta = pre[i + 1] - pre[i]
            post_delta = post[i + 1] - post[i]
            for ratio in ratios:
                # Ratio 0.0 on segment i > 0 is post[i], the endpoint shared
                # with segment i - 1, which that segment already contributed
                # (assuming its ratios end at 1.0). Skip the duplicate.
                if i > 0 and ratio == 0.0:
                    continue
                split_pre.append(pre[i] + (ratio * pre_delta))
                split_post.append(post[i] + (ratio * post_delta))
        return split_pre, split_post

    @staticmethod
    def _label_and_merge(split_pre, split_post):
        """Labels each refined region, then merges same-label neighbours.

        Region i spans split_pre[i] to split_pre[i + 1]. Returns
        (regions, labels) where regions is a list of (start, end) tuples
        taken from split_pre and labels[i] is the label of regions[i].
        """
        # The argmax only changes at split points, so the midpoint of each
        # region (away from boundary ties) determines the whole region.
        region_labels = [
            np.argmax(0.5 * (start_post + end_post))
            for start_post, end_post in zip(split_post[:-1], split_post[1:])]
        merged_pre = []
        merged_labels = []
        for start, label in zip(split_pre, region_labels):
            if not merged_labels or label != merged_labels[-1]:
                merged_pre.append(start)
                merged_labels.append(label)
        merged_pre.append(split_pre[-1])
        regions = list(zip(merged_pre[:-1], merged_pre[1:]))
        return regions, merged_labels