-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_dataset.py
185 lines (148 loc) · 8.26 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List
import torch
from classes.diagram_offline import DiagramOffline
from classes.diagram_ndot_offline import DiagramOfflineNDot
from classes.qdsd import DATA_DIR
from plot.data import plot_patch_sample
from plot.lines_visualisation import create_multiplots
from utils.angle_operations import angles_from_list, normalize_angle, get_angle_stat
from utils.logger import logger
from utils.misc import save_list_to_file, renorm_array
from utils.settings import settings
from utils.output import init_out_directory, ExistingRunName
run_name = settings.run_name

# Create the output directory that will receive results and plots.
# An already-existing run directory is logged as critical; this block itself does not stop execution afterwards
# (unless the project logger aborts on critical messages, which is not visible here).
try:
    init_out_directory()
except ExistingRunName:
    logger.critical(f'Existing run directory: "{run_name}"', exc_info=True)

# Render matplotlib text through LaTeX, with a serif font family
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
def load_diagram() -> List["DiagramOffline"]:
    """
    Load the experimental stability diagrams with their line and area labels, then normalize them.

    Diagrams are read from ``interpolated_csv.zip`` and labels from ``labels.json``, both inside ``DATA_DIR``.
    The pixel size, the research group and the single/multi-dot mode come from the settings.

    :return: The list of loaded diagrams, after ``DiagramOfflineNDot.normalize_diagrams`` has been applied to it
        (presumably in place, since its return value is not used).
    """
    diagrams = DiagramOfflineNDot.load_diagrams(pixel_size=settings.pixel_size,
                                                research_group=settings.research_group,
                                                diagrams_path=Path(DATA_DIR, 'interpolated_csv.zip'),
                                                labels_path=Path(DATA_DIR, 'labels.json'),
                                                single_dot=settings.dot_number == 1,
                                                load_lines=True,
                                                load_areas=True,
                                                white_list=None)
    # NOTE(review): an earlier comment said diagrams must be normalized with the same min/max values used during
    # training, fetched via the "normalization_values_path" setting or from the current run directory.
    # Confirm that DiagramOfflineNDot.normalize_diagrams follows that rule.
    DiagramOfflineNDot.normalize_diagrams(diagrams)
    return diagrams
def load_patches(diagrams, *, patch_size=None, overlap=(6, 6), label_offset=(0, 0)):
    """
    Cut every diagram into patches and collect, for each patch, the lines that intersect it.

    Every patch is returned, whatever the number of lines crossing it. Filtering (e.g. keeping only the patches
    crossed by exactly one line) is left to the caller.

    :param diagrams: Iterable of diagrams exposing ``get_patches(patch_size, overlap, label_offset)``.
    :param patch_size: (size_x, size_y) of the patches.
        Defaults to ``(settings.patch_size_x, settings.patch_size_y)``.
    :param overlap: Second positional argument forwarded to ``get_patches``. The default (6, 6) is the previously
        hard-coded value. Presumably the overlap between neighbouring patches; TODO confirm against get_patches.
    :param label_offset: Third positional argument forwarded to ``get_patches``. The default (0, 0) is the
        previously hard-coded value. Presumably an offset of the labelled area; TODO confirm.
    :return: A tuple ``(patches, lines)`` of two flat lists built by concatenating the per-diagram results of
        ``get_patches``. ``lines[i]`` holds the lines associated with ``patches[i]``.
    """
    if patch_size is None:
        patch_size = (settings.patch_size_x, settings.patch_size_y)

    patches = []
    lines = []
    for diagram in diagrams:
        diagram_patches, diagram_lines = diagram.get_patches(patch_size, overlap, label_offset)
        patches.extend(diagram_patches)
        lines.extend(diagram_lines)
    return patches, lines
def stack_patches(patches, patch_size):
    """
    Stack 2D patches into a single float32 tensor with one patch per entry along the first dimension.

    NumPy arrays and torch tensors may be mixed in ``patches``. Every patch is converted to float32, so the output
    dtype does not depend on the type of the first patch.

    :param patches: Sequence of same-shaped 2D patches (NumPy arrays, torch tensors, or anything torch.as_tensor
        accepts).
    :param patch_size: (size_x, size_y), only used to shape the result when ``patches`` is empty.
    :return: A float32 tensor of shape (len(patches), *patch.shape), or (0, size_x, size_y) for no patches.
    """
    if len(patches) == 0:
        return torch.empty((0, *patch_size), dtype=torch.float32)
    tensors = []
    for patch in patches:
        if isinstance(patch, np.ndarray):
            # copy() yields a C-contiguous array: torch.from_numpy rejects negative strides (e.g. flipped views)
            patch = torch.from_numpy(patch.copy())
        tensors.append(torch.as_tensor(patch, dtype=torch.float32))
    return torch.stack(tensors)


if __name__ == '__main__':
    # Build a dataset of patches crossed by exactly one line, together with that line's angle.
    # Save the patches as a torch tensor (.pt) and the angles as a text file (.txt) under ./saved/.

    # Base file names; option-specific suffixes are appended below
    path_torch = f'./saved/double_dot_{settings.research_group}_patches_normalized_{settings.patch_size_x}_{settings.patch_size_y}'
    path_angle = f'./saved/double_dot_{settings.research_group}_angles_{settings.patch_size_x}_{settings.patch_size_y}'

    diagrams_exp = load_diagram()
    patches_list, lines_list = load_patches(diagrams_exp)

    # Keep only the patches crossed by exactly one line
    selected_patches = []
    selected_lines = []
    for patch, line_list in zip(patches_list, lines_list):
        if len(line_list) != 1:
            continue
        if settings.dx:
            # np.gradient returns one derivative per axis; index 0 is the derivative along the first axis.
            # The result is a NumPy array; stack_patches converts it to a tensor later.
            selected_patches.append(np.gradient(patch)[0])
        else:
            selected_patches.append(patch)
        selected_lines.append(line_list)

    angles_lines = angles_from_list(selected_lines, normalize=True)
    # get_angle_stat(angles_lines)  # un-comment this line to see the angle statistical distribution

    if settings.dx:
        # Mark the file names when the derivative of the patches is used
        path_torch += "_Dx"
        path_angle += "_Dx"

    if settings.rotate_patch:
        path_torch += "_rotate"
        path_angle += "_rotate"
        from utils.rotation import rotate_patches
        # Keep the rotated lines and angles too, so the labels match the rotated patches
        selected_patches, selected_lines, angles_lines = rotate_patches(selected_patches,
                                                                         selected_lines,
                                                                         angles_lines)
        get_angle_stat(angles_lines)

    if settings.include_synthetic:
        # Mark the file names when the dataset is populated with synthetic patches
        mean_tag = str(settings.mean_gaussian).replace(".", "")
        scale_tag = str(settings.scale_gaussian).replace(".", "")
        path_torch += f'_populated{mean_tag}_{scale_tag}'
        path_angle += f'_populated{mean_tag}_{scale_tag}'
        # Only load the module if necessary
        from utils.populate import populate_angles
        populated_patches, populated_lines_list, populated_angle_list = populate_angles(
            selected_patches,
            selected_lines,
            angles_lines,
            percentage=0.9,
            size=(settings.patch_size_x, settings.patch_size_y),
            background=settings.background,
            sigma=settings.sigma,
            aa=settings.anti_alias)
        get_angle_stat(populated_angle_list)  # plot the angle statistical distribution of the populated dataset
        # Plot a sample of populated patches to see an example of lines
        plot_patch_sample(populated_patches,
                          populated_lines_list,
                          sample_number=16,
                          show_offset=False,
                          name='one_line_populated_DQD')
    else:
        # Without synthetic data, the saved dataset is just the selected experimental patches and their angles
        populated_patches = selected_patches
        populated_angle_list = angles_lines

    # Plot a sample of experimental patches with their line highlighted
    plot_patch_sample(selected_patches, selected_lines, sample_number=16, show_offset=False, name='one_line_DQD')

    # Stack every patch (not only the first len(selected_patches)) into one float32 tensor; no channel dim is added
    stacked_patches = stack_patches(populated_patches, (settings.patch_size_x, settings.patch_size_y))

    if settings.full_circle:
        path_torch += "_fullcircle"
        path_angle += "_fullcircle"
    # Add extension to file path
    path_torch += ".pt"
    path_angle += ".txt"

    # Save tensor
    torch.save(renorm_array(stacked_patches), path_torch)

    # fig, axes = create_multiplots(stacked_patches, angles_lines, number_sample=16)  # un-comment to see sample patches
    plt.tight_layout()
    plt.show()

    # Save the angle list; comment this out when the patches are already saved and you only need to apply Dx to them
    save_list_to_file(populated_angle_list, path_angle)