-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathutils.py
121 lines (94 loc) · 2.92 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import matplotlib.pyplot as plt
import numpy as np
def concatenate_dict(main_dict, new_dict):
    """Append one value per key from ``new_dict`` onto the lists in ``main_dict``.

    Only the keys already present in ``main_dict`` are visited; any extra
    keys in ``new_dict`` are ignored. ``main_dict`` is modified in place and
    nothing is returned.

    Args:
        main_dict: Mapping whose values accumulate entries (extended with
            ``+=``, so existing list objects are mutated in place).
        new_dict: Mapping supplying the value to add for each key.

    Raises:
        KeyError: If ``new_dict`` lacks a key that ``main_dict`` has.
    """
    for name in main_dict:
        # Wrap the value in a one-element list so it is added as a single
        # entry even when the value itself is a sequence.
        main_dict[name] += [new_dict[name]]
def plot_image(arr):
    """Show ``arr`` as an image with a colorbar and return the figure.

    The figure is built from the ``Figure`` class directly rather than via
    ``plt.figure()``; the caller receives it and decides what to do with it.

    Args:
        arr: Image data accepted by ``Axes.imshow``.

    Returns:
        The figure holding the image and its colorbar.
    """
    figure = plt.Figure()
    axes = figure.add_subplot(1, 1, 1)
    image = axes.imshow(
        arr,
        origin='lower',  # first row is drawn at the bottom
        aspect='auto',
        interpolation='nearest',
    )
    figure.colorbar(image)
    return figure
def plot_lines(arr):
    """Plot every row of ``arr`` as a separate labelled line.

    Args:
        arr: Object exposing ``shape`` and integer indexing along its first
            axis; ``arr[i]`` is plotted for each ``i`` below ``arr.shape[0]``.

    Returns:
        The figure containing one line per row plus a legend whose entries
        are the row indices.
    """
    figure = plt.Figure()
    axes = figure.add_subplot(1, 1, 1)
    row_count = arr.shape[0]
    for row in range(row_count):
        axes.plot(arr[row], label='%d' % row)
    axes.legend()
    return figure
def _offsets_to_strokes(offsets):
    """Turn ``(eos, dx, dy)`` rows into ``(eos, x, y)`` rows via a cumulative sum."""
    offsets = np.asarray(offsets)
    return np.concatenate(
        [offsets[:, 0:1], np.cumsum(offsets[:, 1:], axis=0)],
        axis=1,
    )


def _plot_strokes(ax, strokes, color='blue'):
    """Draw each stroke (a run of points ending where ``eos == 1``) as one line."""
    # Cut *after* every end-of-stroke row so that row stays in its own stroke.
    cut_points = np.flatnonzero(strokes[:, 0] == 1) + 1
    for segment in np.split(strokes[:, 1:], cut_points):
        # The trailing piece is empty when the last row closes a stroke.
        if len(segment):
            ax.plot(segment[:, 0], segment[:, 1], color=color)


def _finish_figure(fig, ax, ascii_seq, save_file):
    """Hide the axes decorations, then optionally add a title and save ``fig``."""
    ax.axis('off')
    ax.set_aspect('equal')
    # ``axis`` must be 'x', 'y' or 'both'; the previous value 'off' is
    # rejected with ValueError by current Matplotlib releases.
    ax.tick_params(
        axis='both', left=False, right=False,
        top=False, bottom=False,
        labelleft=False, labeltop=False,
        labelright=False, labelbottom=False,
    )
    if ascii_seq is not None:
        if not isinstance(ascii_seq, str):
            # Non-string input is treated as an iterable of character codes.
            ascii_seq = ''.join(map(chr, ascii_seq))
        ax.set_title(ascii_seq)
    if save_file is not None:
        fig.savefig(save_file)


def draw(offsets, ascii_seq=None, save_file=None):
    """Render one sequence of pen offsets as blue strokes.

    Args:
        offsets: 2-D array-like with three columns ``(eos, dx, dy)``. Column 0
            flags the last point of a stroke with ``1``; columns 1 and 2 are
            offsets that are cumulatively summed into absolute coordinates.
        ascii_seq: Optional title, either a ``str`` or an iterable of integer
            character codes.
        save_file: Optional destination handed to ``Figure.savefig``.

    Returns:
        The created figure (12x3 inches, x range -50..600, y range -40..40).
    """
    strokes = _offsets_to_strokes(offsets)

    fig, ax = plt.subplots(figsize=(12, 3))
    _plot_strokes(ax, strokes)

    ax.set_xlim(-50, 600)
    ax.set_ylim(-40, 40)
    _finish_figure(fig, ax, ascii_seq, save_file)
    return fig


def draw_multiple(list_of_offsets, ascii_seq=None, save_file=None):
    """Render several offset sequences on one figure, stacked vertically.

    Sequence ``i`` is shifted down by ``30 * i`` units so that consecutive
    sequences do not overlap. Axis limits are left to Matplotlib's autoscaling.

    Args:
        list_of_offsets: Iterable of 2-D array-likes, each with three columns
            ``(eos, dx, dy)`` as accepted by ``draw``.
        ascii_seq: Optional title, either a ``str`` or an iterable of integer
            character codes.
        save_file: Optional destination handed to ``Figure.savefig``.

    Returns:
        The created figure (12x9 inches).
    """
    line_gap = 30
    # Convert everything first so bad input fails before a figure is opened.
    list_of_strokes = [
        _offsets_to_strokes(offsets) for offsets in list_of_offsets
    ]

    fig, ax = plt.subplots(figsize=(12, 9))
    for row, strokes in enumerate(list_of_strokes):
        # ``strokes`` is a fresh array, so the caller's data is not modified.
        strokes[:, -1] -= line_gap * row
        _plot_strokes(ax, strokes)

    _finish_figure(fig, ax, ascii_seq, save_file)
    return fig