-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtracks.py
85 lines (69 loc) · 2.21 KB
/
tracks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Track implementation with path finder
"""
import heapq
import itertools
from collections import namedtuple
# One directed connection: travelling from ``src`` to ``dst`` costs ``cost``.
Track = namedtuple(typename='Track', field_names=('src', 'dst', 'cost'))

# A candidate route (``path`` is a list of Track records) together with its
# accumulated ``cost``. Note the generated class is named 'PathCost' even
# though the module-level binding is ``Pcost``.
Pcost = namedtuple(typename='PathCost', field_names=('path', 'cost'))
def ucs(tracks, src, dst):
    """Uniform-cost search for the cheapest route from ``src`` to ``dst``.

    Args:
        tracks: Mapping of station -> list of outgoing track records. Each
            record must expose ``dst`` and ``cost`` attributes.
        src: Station to start from.
        dst: Station to reach.

    Returns:
        The list of track records forming the cheapest route, in travel
        order. An empty list is returned when ``src == dst``, when either
        station is not a key of ``tracks``, or when no route exists.

    Note:
        Each station is expanded at most once, which guarantees termination
        on cyclic graphs. As with any uniform-cost search this assumes track
        costs are non-negative; with negative costs the result may not be
        the cheapest route.
    """
    # Unknown or trivial endpoints yield no route. ``dst`` must be a key of
    # the map (i.e. have outgoing tracks) exactly as before.
    if src == dst or src not in tracks or dst not in tracks:
        return []

    # Heap entries are (cost, sequence, path). The monotonically increasing
    # sequence number breaks cost ties by insertion order (matching the old
    # "first minimum wins" scan) and keeps the heap from ever comparing
    # paths against each other.
    sequence = itertools.count()
    frontier = []
    for track in tracks[src]:
        heapq.heappush(frontier, (track.cost, next(sequence), [track]))

    # Stations whose outgoing tracks were already enqueued. Re-expanding a
    # station can never produce a cheaper route and is what previously made
    # the search spin forever when ``dst`` was unreachable inside a cycle.
    expanded = {src}

    while frontier:
        cost, _, path = heapq.heappop(frontier)
        last = path[-1].dst
        if last == dst:
            return path
        if last in expanded:
            continue
        expanded.add(last)
        # A station without outgoing tracks is not a key of the map; treat
        # it as a dead end instead of raising KeyError.
        for track in tracks.get(last, ()):
            heapq.heappush(
                frontier,
                (cost + track.cost, next(sequence), path + [track]),
            )

    # Frontier exhausted: ``dst`` is not reachable from ``src``.
    return []
class Tracks:
    """Collection of tracks indexed by the station they depart from."""

    def __init__(self, tracks):
        # Index: departure station -> list of track records leaving it,
        # kept in the order the records were supplied.
        self.tracks = {}
        for record in tracks:
            self.tracks.setdefault(record.src, []).append(record)

    def get_path(self, src, dst):
        """
        Look up a route from src to dst via the module-level ``ucs`` search.

        The result is whatever ``ucs`` returns for this track map; it yields
        [] when src equals dst or either one is not a departure station.
        Behaviour for destinations that cannot be reached is likewise
        determined by ``ucs``.
        """
        return ucs(self.tracks, src, dst)

    def stations(self):
        """Return a view of the departure stations. Treat it as read only."""
        return self.tracks.keys()

    def export(self, filename: str):
        """Write the track map to ``filename`` as a Graphviz digraph."""
        edge_lines = (
            ' %s -> %s [label="%s"];\n' % (edge.src, edge.dst, edge.cost)
            for outgoing in self.tracks.values()
            for edge in outgoing
        )
        with open(filename, "w") as output:
            output.write("digraph track {\n")
            output.writelines(edge_lines)
            output.write('}\n')