-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathArrayTree.py
134 lines (103 loc) · 4.36 KB
/
ArrayTree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import copy
class ArrayTree:
    """A complete k-ary tree stored level by level in two flat lists.

    Layout (see ``get_childs_tree_idx``): the root lives at index 0 and the
    children of the node at index ``p`` occupy indices
    ``p * num_branches + 1`` through ``p * num_branches + num_branches``.

    Attributes:
        nodes: node values, one per tree index.
        edges: ``edges[i]`` labels the edge from node ``i``'s parent to node
            ``i``. ``edges[0]`` is never read by ``print`` or
            ``get_graph_viz``.
        num_branches: the branching factor (children per internal node).

    Terminology used by the method names:
        child_idx -- given a parent p, which (0-based) child of p is c?
        idx       -- general index into the flat ``nodes``/``edges`` lists.
    """

    def __init__(self, num_branches, inp_nodes=None, inp_edges=None):
        """Create a tree with the given branching factor.

        Both ``ArrayTree(k)`` (empty tree) and
        ``ArrayTree(k, nodes, edges)`` are supported.

        Args:
            num_branches: number of children per internal node.
            inp_nodes: initial node list, or None for an empty list. Stored
                by reference (not copied), so ``add_level`` also extends the
                caller's list.
            inp_edges: initial edge list, or None for an empty list. Also
                stored by reference.
        """
        self.nodes = [] if inp_nodes is None else inp_nodes
        self.edges = [] if inp_edges is None else inp_edges
        self.num_branches = num_branches

    def child_idx_from_idx(self, idx):
        """Return which 0-based child of its parent the node at *idx* is.

        A multiple of ``num_branches`` is the last child
        (``num_branches - 1``); otherwise the result is
        ``idx % num_branches - 1``. Both cases reduce to the single
        expression below.
        """
        return (idx - 1) % self.num_branches

    def get_parent_idx(self, idx):
        """Return the index of the parent of the node at *idx*.

        Uses exact integer division so results stay correct for indices
        beyond float precision. For the root (idx 0) this returns -1.
        """
        return (idx - 1) // self.num_branches

    def get_children(self, parent_idx):
        """Return the values of the children of *parent_idx*.

        Returns an empty list when the node has no children stored. If the
        last level is only partially filled, only the stored children are
        returned.
        """
        first_child = parent_idx * self.num_branches + 1
        return self.nodes[first_child:first_child + self.num_branches]

    def get_children_indexes(self, parent_idx):
        """Return the tree indices of the stored children of *parent_idx*.

        Mirrors ``get_children``: an empty list when there are no children,
        and only existing indices when the last level is partial.
        """
        first_child = parent_idx * self.num_branches + 1
        last_child = min(first_child + self.num_branches, len(self.nodes))
        return list(range(first_child, last_child))

    def get_height(self):
        """Return the height of the tree (0 for a single node or empty tree).

        The height is the level of the last stored node.
        """
        return self.get_level(len(self.nodes) - 1)

    def get_childs_tree_idx(self, parent_idx, child_idx):
        """Return the tree index of child number *child_idx* of *parent_idx*."""
        return self.num_branches * parent_idx + child_idx + 1

    def has_child(self, parent_idx):
        """Return True if the first child of *parent_idx* is stored."""
        return (len(self.nodes) - 1) >= self.get_childs_tree_idx(parent_idx, 0)

    def get_leaves(self):
        """Return the node values on the deepest level.

        Follows first children down from the root, then returns every node
        from that index onward. This assumes the tree is filled level by
        level.
        """
        idx = 0
        while self.has_child(idx):
            idx = self.get_childs_tree_idx(idx, 0)
        return self.nodes[idx:]

    def add_level(self, inp_nodes, inp_edges):
        """Append a new level of nodes and edges in place.

        No validation is performed. The caller should pass
        ``len(self.get_leaves()) * num_branches`` entries in each list.
        """
        self.nodes.extend(inp_nodes)
        self.edges.extend(inp_edges)

    def copy_and_add_level(self, inp_nodes, inp_edges):
        """Return a new tree equal to this one plus one added level.

        This tree is left unchanged. The node/edge lists are shallow-copied,
        so the element objects themselves are shared. As with
        ``add_level``, input lengths are not validated.
        """
        tree = ArrayTree(self.num_branches, self.nodes.copy(), self.edges.copy())
        tree.add_level(inp_nodes, inp_edges)
        return tree

    def fully_explored(self, horizon):
        """Return True once the tree height reaches ``horizon - 1``.

        For example, with horizon 1 a tree of height 0 counts as fully
        explored.
        """
        return (horizon - 1) <= self.get_height()

    def get_level(self, idx):
        """Return the depth of the node at *idx* (root is level 0).

        Indices <= 0 are reported as level 0.
        """
        level = 0
        while idx > 0:
            level += 1
            idx = self.get_parent_idx(idx)
        return level

    def print(self):
        """Print the tree to stdout, one node per line.

        The root value is printed first. Every other node is printed
        indented by dashes proportional to its level, as
        ``Obs: <edge> -> Act: <node>``. Prints nothing for an empty tree.
        """
        # Inside this method, the name ``print`` resolves to the builtin,
        # not to this method (class scope is not searched).
        if not self.nodes:
            return
        print(self.nodes[0])
        for i in range(1, len(self.nodes)):
            print("-" * self.get_level(i) + " Obs: " + str(self.edges[i])
                  + " -> Act: " + str(self.nodes[i]))

    @staticmethod
    def _dot_escape(value):
        """Return str(value) escaped for use inside a DOT quoted string."""
        # In DOT quoted strings the only escaped character is the
        # double quote.
        return str(value).replace('"', '\\"')

    def get_graph_viz(self, tree_name):
        """Return the tree as a Graphviz DOT ``digraph`` named *tree_name*.

        Each node is labelled with its value. Each parent->child edge is
        labelled with ``edges[child]`` and drawn without arrowheads.
        """
        lines = ['digraph ' + tree_name + ' {', 'edge [dir=none];']
        for i, node in enumerate(self.nodes):
            lines.append('node' + str(i) + ' [ label = "'
                         + self._dot_escape(node) + '" ];')
        for i in range(1, len(self.nodes)):
            lines.append('node' + str(self.get_parent_idx(i)) + ' -> '
                         + 'node' + str(i) + ' [label="'
                         + self._dot_escape(self.edges[i]) + '"];')
        lines.append('}')
        return '\n'.join(lines)

    def get_child_edge_idx_with_value(self, parent_idx, value):
        """Return the index of the child of *parent_idx* whose edge equals *value*.

        If no child matches, prints an error message and returns the
        sentinel string ``" "``. The sentinel is kept for compatibility with
        existing callers.
        """
        for child_idx in self.get_children_indexes(parent_idx):
            if self.edges[child_idx] == value:
                return child_idx
        print("ERROR. Observation not found")
        return " "

    def write_tree(self):
        """Return a list ``[num_branches, nodes, edges]`` that describes the tree.

        The returned lists are the tree's own lists, not copies.
        """
        return [self.num_branches, self.nodes, self.edges]

    def set_nodes(self, node_list):
        """Replace the node list (stored by reference)."""
        self.nodes = node_list

    def set_edges(self, edge_list):
        """Replace the edge list (stored by reference)."""
        self.edges = edge_list