-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
2209 lines (1816 loc) · 101 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import argparse
import math
import copy
import os
import sys
import itertools
import warnings
import wandb
import yaml
import torch
import torchvision
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
from pathlib import Path
from functools import partial
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle
from tqdm import tqdm
from typing import Any, List, Literal, Optional, Callable, Tuple, Union, Type
from loguru import logger
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
from torchvision.transforms import (ToTensor, RandomCrop, RandomResizedCrop, RandomHorizontalFlip, Normalize, Compose,
Resize)
from torchvision.transforms.functional import resize
from torchvision.datasets import ImageFolder
from pytorch_lightning import LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.trainer.states import RunningStage
from consts import LOGGER_FORMAT
from schemas.architecture import ArchitectureArgs
from schemas.data import DataArgs
from schemas.environment import EnvironmentArgs
from schemas.optimization import OptimizationArgs
from vgg import configs, get_model_kernel_size
def log_args(args):
    """Logs the given arguments to the logger's output.

    Every argument is written on its own line, as a dash-padded name followed by its value,
    so that all the values start at the same column.

    Args:
        args: An object exposing a `flattened_dict()` method, which returns
            a mapping from argument names to their values.
    """
    logger.info('Running with the following arguments:')
    # Flatten once and reuse the result for both the width calculation and the printing loop.
    flattened_args = args.flattened_dict()
    # `default=0` prevents `max` from raising a ValueError when there are no arguments at all.
    longest_arg_name_length = max((len(arg_name) for arg_name in flattened_args), default=0)
    pad_length = longest_arg_name_length + 4
    for arg_name, value in flattened_args.items():
        # The inner f-string appends a space to the name, the outer one pads it with dashes to `pad_length`.
        logger.info(f'{f"{arg_name} ":-<{pad_length}} {value}')
def parse_args():
    """Parses the command line, recognizing only the optional `--yaml_path` argument.

    Returns:
        A tuple `(known_args, unknown_args)`: the namespace holding `yaml_path`,
        and the list of all remaining command-line tokens which were not recognized.
    """
    description = 'Main script for running the experiments with arguments from the corresponding pydantic schema'
    yaml_path_help = '(Optional) path to a YAML file with the arguments'
    arguments_parser = argparse.ArgumentParser(description=description)
    arguments_parser.add_argument('--yaml_path', help=yaml_path_help)
    known_args, unknown_args = arguments_parser.parse_known_args()
    return known_args, unknown_args
def get_args(args_class):
    """Gets arguments as an instance of the given pydantic class,
    according to the argparse object (possibly including the yaml config).

    The arguments are first read from the YAML file given in `--yaml_path` (if any).
    Then every `--name value [value ...]` group in the rest of the command line overrides
    the argument `name` in every category of `args_class` declaring an argument with that name.
    A single value is stored as a string, several values are stored as a list of strings.

    Args:
        args_class: A pydantic class whose fields are the categories, where the default
            of each field is itself a pydantic object holding the arguments of that category.

    Returns:
        An instance of `args_class` created with `args_class.parse_obj`.

    Raises:
        ValueError: If a command-line argument has no value, or if it isn't declared in any category.
    """
    known_args, unknown_args = parse_args()
    args_dict = None
    if known_args.yaml_path is not None:
        with open(known_args.yaml_path, 'r') as f:
            # NOTE(review): if the YAML file may come from an untrusted source,
            # consider `yaml.safe_load` instead of `yaml.FullLoader`.
            args_dict = yaml.load(f, Loader=yaml.FullLoader)
    if args_dict is None:  # This happens when the yaml file is empty, or no yaml file was given.
        args_dict = dict()

    while unknown_args:
        token = unknown_args.pop(0)
        # Strip only the leading '--' (replacing every '--' would also mangle names containing it).
        arg_name = token[2:] if token.startswith('--') else token

        # Collect all the tokens up to the next `--name` as the values of this argument.
        values = list()
        while unknown_args and (not unknown_args[0].startswith('--')):
            values.append(unknown_args.pop(0))
        if not values:
            raise ValueError(f'Argument {arg_name} given in command line has no corresponding value.')
        value = values[0] if len(values) == 1 else values

        # The argument is set in every category which declares it (not only in the first one).
        found = False
        for category, category_field in args_class.__fields__.items():
            if arg_name in category_field.default.__fields__:
                if category not in args_dict:
                    args_dict[category] = dict()
                args_dict[category][arg_name] = value
                found = True
        if not found:
            raise ValueError(f'Argument {arg_name} is not recognized.')

    args = args_class.parse_obj(args_dict)
    return args
def get_possibly_sparse_linear_layer(in_features: int, out_features: int, sparse_fraction: float):
    """Builds an affine layer which is either fully-connected or randomly sparse.

    Args:
        in_features: The dimension of the layer's input.
        out_features: The dimension of the layer's output.
        sparse_fraction: If positive, a `RandomlySparseConnected` layer is created
            with this fraction. Otherwise, a regular `nn.Linear` is created.

    Returns:
        The constructed layer.
    """
    if not sparse_fraction > 0:
        return nn.Linear(in_features, out_features)
    return RandomlySparseConnected(in_features, out_features, sparse_fraction)
def get_mlp(input_dim: int,
            output_dim: int,
            n_hidden_layers: int = 0,
            hidden_dimensions: Union[int, List[int]] = 0,
            use_batch_norm: bool = False,
            organize_as_blocks: bool = True,
            shuffle_blocks_output: Union[bool, List[bool]] = False,
            fixed_permutation_per_block: bool = False,
            sparse_fractions: Optional[List[float]] = None) -> torch.nn.Sequential:
    """Builds a Multi-Layer-Perceptron as a PyTorch sequential model.

    Every hidden layer is Linear -> (BatchNorm) -> ReLU -> (ShuffleTensor), and the model ends
    with a final affine layer. The very first layer of the model is preceded by a `Flatten`.

    Args:
        input_dim: The dimension of the input tensor.
        output_dim: The dimension of the output tensor.
        n_hidden_layers: Number of hidden layers.
        hidden_dimensions: The dimension of each hidden layer. A single (non-list) value
            is used for all of the `n_hidden_layers` hidden layers.
        use_batch_norm: Whether to add a `BatchNorm1d` after each hidden affine layer.
        organize_as_blocks: If true, each hidden layer (and the final layer) is wrapped in its own
            `nn.Sequential`. Otherwise, all the modules are placed flat in the returned model.
        shuffle_blocks_output: Whether to add a `ShuffleTensor` at the end of each hidden layer.
            A list of booleans is reduced to a single value using `any`.
        fixed_permutation_per_block: Passed to each `ShuffleTensor` as `fixed_permutation`.
        sparse_fractions: Optional list with one fraction per affine layer (hidden layers and then
            the final layer). A positive fraction creates a randomly sparse connected layer,
            and zero creates a regular fully-connected layer. Defaults to zero for all layers.

    Returns:
        A sequential model which is the constructed MLP.
    """
    if isinstance(shuffle_blocks_output, list):
        shuffle_blocks_output = any(shuffle_blocks_output)
    if not isinstance(hidden_dimensions, list):
        hidden_dimensions = [hidden_dimensions] * n_hidden_layers
    assert len(hidden_dimensions) == n_hidden_layers

    n_hidden = len(hidden_dimensions)
    sparse_fractions = get_list_of_arguments(sparse_fractions, n_hidden + 1, default=0)

    def pack(modules: List[nn.Module]) -> List[nn.Module]:
        # Either a single `nn.Sequential` holding the modules, or the modules themselves.
        return [nn.Sequential(*modules)] if organize_as_blocks else list(modules)

    model_layers: List[nn.Module] = []
    previous_dim = input_dim
    for block_index, (hidden_dim, fraction) in enumerate(zip(hidden_dimensions, sparse_fractions)):
        # `Flatten` is useful when the input is 4D from a conv layer, and harmless otherwise.
        block: List[nn.Module] = [nn.Flatten()] if block_index == 0 else []
        block.append(get_possibly_sparse_linear_layer(previous_dim, hidden_dim, fraction))
        if use_batch_norm:
            block.append(nn.BatchNorm1d(hidden_dim))
        block.append(nn.ReLU())
        if shuffle_blocks_output:
            block.append(ShuffleTensor(spatial_size=1, channels=hidden_dim,
                                       fixed_permutation=fixed_permutation_per_block))
        model_layers.extend(pack(block))
        previous_dim = hidden_dim

    # When there are no hidden layers, the `Flatten` precedes the final layer instead.
    head: List[nn.Module] = [] if n_hidden > 0 else [nn.Flatten()]
    head.append(get_possibly_sparse_linear_layer(previous_dim, output_dim, sparse_fractions[-1]))
    model_layers.extend(pack(head))

    return nn.Sequential(*model_layers)
def get_list_of_arguments(arg, length, default=None):
    """Expands the given argument to a list of per-layer values.

    Args:
        arg: Either a list (which must be of the given length), a single value, or None.
        length: The expected length of the result.
        default: The value to use when `arg` is None. If the default is itself a list,
            a copy of it is returned as is (without verifying its length).

    Returns:
        A deep copy of `arg` if it's a list. Otherwise, a list of the given length
        containing `arg` (or `default`, if `arg` is None and a default was given).
    """
    if isinstance(arg, list):
        assert len(arg) == length
        return copy.deepcopy(arg)

    use_default = (arg is None) and (default is not None)
    if use_default and isinstance(default, list):
        return copy.deepcopy(default)

    fill_value = default if use_default else arg
    return [fill_value for _ in range(length)]
class View(nn.Module):
    """A module which views each sample in the batch in a fixed shape, keeping the batch dimension."""

    def __init__(self, shape: tuple):
        """
        Args:
            shape: The target shape of each sample (i.e. not including the batch dimension).
        """
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        # The first dimension (the batch) is kept, the rest are viewed as `self.shape`.
        batch_size = x.shape[0]
        return x.view(batch_size, *self.shape)

    def extra_repr(self) -> str:
        return 'shape={}'.format(self.shape)
# # For debugging purposes, e.g. to verify the gradients of the weights
# # in the locations where mask is 0 remain unchanged.
# def verify_grad_is_zero_where_mask_is_0(grad: torch.Tensor, mask: torch.Tensor):
# # For debugging purposes, e.g. to verify the gradients of the weights
# # in the locations where mask is 0 remain unchanged.
# assert torch.all((grad == 0) | (mask == 1)).item()
class RandomlySparseConnected(nn.Module):
    """An affine layer in which each output neuron is connected to a random subset of the input neurons.

    The layer holds a dense weight matrix and a fixed 0/1 mask of the same shape.
    The forward pass multiplies the weight by the mask, so entries where the mask is 0 have no effect.
    Each output neuron is connected to `round(fraction * in_features)` input neurons.
    """

    def __init__(self, in_features: int, out_features: int, fraction: float,
                 bias: bool = True, device=None, dtype=None):
        """
        Args:
            in_features: The dimension of the layer's input.
            out_features: The dimension of the layer's output.
            fraction: The fraction of the input neurons each output neuron is connected to.
            bias: Whether to add a learnable bias to the output.
            device: The device of the layer's parameters.
            dtype: The dtype of the layer's parameters.

        Raises:
            ValueError: If the number of connections per output neuron,
                derived from `fraction`, isn't between 0 and `in_features`.
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.fraction = fraction
        self.num_nonzero_weights_per_output_neuron = int(round(fraction * in_features))
        # Fail early with a clear message, instead of a shape mismatch deep inside the initialization.
        if not 0 <= self.num_nonzero_weights_per_output_neuron <= in_features:
            raise ValueError(f'fraction={fraction} gives {self.num_nonzero_weights_per_output_neuron} connections '
                             f'per output neuron, which is not between 0 and in_features={in_features}.')
        # The mask is stored as a non-trainable parameter, so it's saved and moved together with the module.
        self.mask = nn.Parameter(torch.zeros((out_features, in_features), **factory_kwargs), requires_grad=False)
        self.weight = nn.Parameter(torch.zeros((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        # # For debugging purposes, e.g. to verify the gradients of the weights
        # # in the locations where mask is 0 remain unchanged.
        # self.weight.register_hook(functools.partial(verify_grad_is_zero_where_mask_is_0, mask=self.mask))

    def reset_parameters(self) -> None:
        """Initializes the parameters of the module, which are
        the random mask, and the weight and bias of the linear layer.
        It's safe to call this method again later, in order to re-initialize the module.
        """
        self.init_random_mask()
        self.init_weight_and_bias()

    @torch.no_grad()
    def init_random_mask(self):
        """Initializes the mask indicating the (sparse) connections between the input and output neurons.
        The mask is a tensor of shape (self.out_features, self.in_features),
        where a value of 1 in the coordinate ij means that the i-th output neuron
        is connected to the j-th input neuron.
        The randomness comes from PyTorch's random generator, so it follows `torch.manual_seed`.
        """
        # Sorting i.i.d. uniform scores gives an independent random permutation of the inputs in each row,
        # and its prefix is the random subset of inputs connected to the corresponding output neuron.
        random_scores = torch.rand(self.out_features, self.in_features, device=self.mask.device)
        connected_inputs = random_scores.argsort(dim=1)[:, :self.num_nonzero_weights_per_output_neuron]
        # Clear the previous connections first, otherwise re-initialization would accumulate them.
        self.mask.zero_()
        self.mask.scatter_(1, connected_inputs, 1.0)

    @torch.no_grad()
    def init_weight_and_bias(self):
        """Initializes the weight matrix and bias vector of the linear layer.
        The weight matrix and bias vector of the linear layer are initialized as if they were
        a linear layer from `self.num_nonzero_weights_per_output_neuron` input neurons to
        `self.out_features` output neurons (since de facto this is what's going to happen).
        This is done under `no_grad` since the parameters are modified in-place,
        which otherwise raises the RuntimeError:
        "RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation."
        """
        # The temporary layer uses the same device and dtype, so the copied values need no conversion.
        tmp_linear = nn.Linear(in_features=self.num_nonzero_weights_per_output_neuron,
                               out_features=self.out_features,
                               bias=(self.bias is not None),
                               device=self.weight.device,
                               dtype=self.weight.dtype)
        # Weights outside the mask are zeroed (this matters when the module is re-initialized).
        self.weight.zero_()
        # Boolean indexing goes over the entries row by row, and each row of the mask has exactly
        # `self.num_nonzero_weights_per_output_neuron` ones, so row i receives `tmp_linear.weight[i]`.
        self.weight[self.mask.bool()] = tmp_linear.weight.reshape(-1)
        if self.bias is not None:
            # Copy in-place (and not re-assign) so references to the bias parameter remain valid.
            self.bias.copy_(tmp_linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Multiplying by the mask zeroes the non-existing connections (and also their gradients).
        return F.linear(x, self.weight * self.mask, self.bias)

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, ' + \
               f'out_features={self.out_features}, ' + \
               f'bias={self.bias is not None}, ' + \
               f'fraction={self.fraction:.3f}'
def get_cnn(conv_channels: List[int],
            linear_channels: List[int],
            kernel_sizes: Optional[List[int]] = None,
            strides: Optional[List[int]] = None,
            use_max_pool: Optional[List[bool]] = None,
            paddings: Optional[List[int]] = None,
            shuffle_outputs: Optional[List[bool]] = None,
            spatial_only: Optional[List[bool]] = None,
            fixed_permutation: Optional[List[bool]] = None,
            replace_with_linear: Optional[List[bool]] = None,
            replace_with_bottleneck: Optional[List[int]] = None,
            sparse_connected_fractions: Optional[List[float]] = None,
            adaptive_avg_pool_before_mlp: bool = False,
            max_pool_after_first_conv: bool = False,
            in_spatial_size: int = 32,
            in_channels: int = 3,
            n_classes: int = 10) -> tuple[nn.Sequential, nn.Sequential]:
    """This function builds a CNN and return it as two PyTorch sequential models (convolutions followed by mlp).

    Args:
        conv_channels: A list of integers indicating the number of channels in the convolutional layers.
        linear_channels: A list of integers indicating the number of channels in the linear layers.
        kernel_sizes: A list of integers indicating the kernel size of the convolutional layers (default 3).
        strides: A list of integers indicating the stride of the convolutional layers (default 1).
        use_max_pool: A list of booleans indicating whether to use max pooling in the convolutional layers.
        paddings: A list of integers indicating the padding of the convolutional layers
            (default is half of the corresponding kernel size).
        shuffle_outputs: A list of booleans indicating whether to shuffle the outputs of the convolutional layers.
        spatial_only: A list of booleans indicating whether to shuffle spatial-only (see doc in `ShuffleTensor`).
        fixed_permutation: A list of booleans indicating whether to use fixed permutations (see doc in `ShuffleTensor`).
        replace_with_linear: A list of booleans indicating whether to replace the convolutional layers
            with linear layers of the same expressiveness.
        replace_with_bottleneck: Whether to replace each conv layer with a "bottleneck" linear layer
            of the same expressiveness, meaning a linear layer of low rank constraint
            (e.g. 100,000 -> 1,000 -> 100,000). The number represent the middle linear layer dimensionality.
        sparse_connected_fractions: A list of fractions (i.e. floats between 0 and 1),
            one per affine layer (the convolutional layers, then the linear layers, then the final layer),
            where positive values indicate to replace the corresponding layer
            with a randomly sparse connected layer with the given fraction of connections.
        adaptive_avg_pool_before_mlp: Whether to use adaptive average pooling to 1x1 before the final mlp.
            (This is done in ResNet architectures).
        max_pool_after_first_conv: Whether to use 3x3 max pool (padding=1, stride=2) after the first convolution layer.
        in_spatial_size: Will be used to infer input dimension for the first affine layer.
        in_channels: Number of channels in the input tensor.
        n_classes: Number of classes (i.e. determines the size of the prediction vector containing the classes' scores).

    Returns:
        A tuple of two sequential models: the convolutional blocks, and the mlp operating on their output.
    """
    conv_blocks: List[nn.Sequential] = list()

    # Expand every per-layer argument to a list with one value per convolutional layer.
    n_convs = len(conv_channels)
    use_max_pool = get_list_of_arguments(use_max_pool, n_convs, default=False)
    shuffle_outputs = get_list_of_arguments(shuffle_outputs, n_convs, default=False)
    strides = get_list_of_arguments(strides, n_convs, default=1)
    kernel_sizes = get_list_of_arguments(kernel_sizes, n_convs, default=3)
    paddings = get_list_of_arguments(paddings, n_convs, default=[k // 2 for k in kernel_sizes])
    spatial_only = get_list_of_arguments(spatial_only, n_convs, default=True)
    fixed_permutation = get_list_of_arguments(fixed_permutation, n_convs, default=True)
    replace_with_linear = get_list_of_arguments(replace_with_linear, n_convs, default=False)
    replace_with_bottleneck = get_list_of_arguments(replace_with_bottleneck, n_convs, default=0)
    # The sparse fractions cover the conv layers, the hidden linear layers and the final layer.
    sparse_connected_fractions = get_list_of_arguments(sparse_connected_fractions, n_convs + len(linear_channels) + 1,
                                                       default=0)

    zipped_args = zip(conv_channels, paddings, strides, kernel_sizes, use_max_pool,
                      shuffle_outputs, spatial_only, fixed_permutation,
                      replace_with_linear, replace_with_bottleneck, sparse_connected_fractions[:n_convs])
    for i, (out_channels, padding, stride, kernel_size, pool,
            shuf, spatial, fixed,
            linear, bottleneck, sparse_fraction) in enumerate(zipped_args):
        block_layers: List[nn.Module] = list()

        # The spatial size after the convolution, and after the (optional) 2x2 max-pool.
        out_spatial_size = int(math.floor((in_spatial_size + 2 * padding - kernel_size) / stride + 1))
        if pool:
            out_spatial_size = int(math.floor(out_spatial_size / 2))

        if linear or (sparse_fraction > 0) or (bottleneck > 0):
            assert int(linear) + int(sparse_fraction > 0) + int(bottleneck > 0) == 1, \
                'Only one of linear, sparse_fraction, bottleneck can be true'
            # NOTE(review): when `pool` is true `out_spatial_size` is already halved here, yet a 2x2
            # max-pool is still appended below - confirm this combination is intended / supported.
            in_features = in_channels * (in_spatial_size ** 2)
            out_features = out_channels * (out_spatial_size ** 2)
            block_layers.append(nn.Flatten())
            if bottleneck > 0:
                block_layers.append(nn.Linear(in_features, bottleneck))
                block_layers.append(nn.Linear(bottleneck, out_features))
            else:
                block_layers.append(get_possibly_sparse_linear_layer(in_features, out_features, sparse_fraction))
            # Go back to a 4D tensor, as if it was the output of a convolutional layer.
            block_layers.append(View(shape=(out_channels, out_spatial_size, out_spatial_size)))
        else:
            block_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))

        block_layers.append(nn.BatchNorm2d(out_channels))  # TODO make an argument for using BatchNorm
        block_layers.append(nn.ReLU())

        if pool:
            block_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        if shuf:
            block_layers.append(ShuffleTensor(out_spatial_size, out_channels, spatial, fixed))
        if (i == 0) and max_pool_after_first_conv:
            block_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            # This pooling changes the spatial size as well (kernel_size=3, stride=2, padding=1),
            # so it must be tracked in order to infer the input dimensions of the following layers.
            out_spatial_size = int(math.floor((out_spatial_size + 2 * 1 - 3) / 2 + 1))
        if (i == n_convs - 1) and adaptive_avg_pool_before_mlp:
            block_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            out_spatial_size = 1

        conv_blocks.append(nn.Sequential(*block_layers))
        in_channels = out_channels
        in_spatial_size = out_spatial_size

    features = torch.nn.Sequential(*conv_blocks)
    # At this point `in_channels` and `in_spatial_size` describe the output of the last conv block.
    mlp = get_mlp(input_dim=in_channels * (in_spatial_size ** 2),
                  output_dim=n_classes,
                  n_hidden_layers=len(linear_channels),
                  hidden_dimensions=linear_channels,
                  use_batch_norm=True,  # TODO make this an argument
                  organize_as_blocks=True,
                  sparse_fractions=sparse_connected_fractions[n_convs:])

    return features, mlp
@torch.no_grad()
def calc_aggregated_patch(dataloader,
                          patch_size,
                          agg_func: Callable,
                          existing_model: Optional[nn.Module] = None):
    """Calculates the average of an aggregation of patches, over all the mini-batches in the dataloader.

    The aggregation function is applied on the patches of each mini-batch separately, and the results
    are combined using a running average weighted by the number of images in each mini-batch.

    Args:
        dataloader: The dataloader to iterate on, yielding (inputs, labels) pairs.
        patch_size: The size of the patches to extract from each input (using `F.unfold`).
        agg_func: The aggregation function. It gets a single 2-dimensional tensor (in double precision)
            whose rows are the flattened patches, and returns a single tensor.
        existing_model: An (optionally) existing model to call on each mini-batch before extracting the patches.

    Returns:
        The averaged aggregation result, or None if the dataloader yields no mini-batches.
    """
    device = get_model_device(existing_model)
    running_aggregate = None
    n_seen_images = 0

    progress_bar = tqdm(dataloader, total=len(dataloader), desc='Calculating mean patch')
    for batch, _ in progress_bar:
        batch = batch.to(device)
        if existing_model is not None:
            batch = existing_model(batch)

        # Unfolding gives one column per patch for each image. Moving the patches axis next to the
        # batch axis and merging them gives a matrix in which each row is a single flattened patch.
        unfolded = F.unfold(batch, patch_size)
        patch_vectors = unfolded.transpose(1, 2).flatten(0, 1).contiguous().double()

        # For example, the mean gives a vector (one value per patch coordinate),
        # and the covariance gives a square matrix.
        batch_aggregate = agg_func(patch_vectors)
        if running_aggregate is None:
            running_aggregate = torch.zeros_like(batch_aggregate)

        # Update the running average, weighting the old and new values by their number of images.
        n_batch_images = batch.size(0)
        n_total_images = n_seen_images + n_batch_images
        old_weight = n_seen_images / n_total_images
        new_weight = n_batch_images / n_total_images
        running_aggregate = old_weight * running_aggregate + new_weight * batch_aggregate
        n_seen_images = n_total_images

    return running_aggregate
def calc_covariance(data, mean=None):
    """Calculates the covariance-matrix of the given data.

    The data matrix is assumed to be ordered as rows-vectors,
    i.e. its shape is (n,d) for n data-points in d dimensions.
    The result is normalized by n (and not by n-1).

    Args:
        data: The given data, a 2-dimensional array ordered as rows-vectors.
        mean: The mean to center the data with. If it's not given, the mean of the data itself is used.
            Giving it is useful when the mean is of some larger distribution the data is a part of
            (as done when calculating the covariance matrix of the whole patches distribution).

    Returns:
        The covariance-matrix of the given data, of shape (d,d).
    """
    n_samples = data.shape[0]
    center = data.mean(axis=0) if mean is None else mean
    deviations = data - center
    scatter_matrix = deviations.T @ deviations
    return (1 / n_samples) * scatter_matrix
def calc_whitening_from_dataloader(dataloader: DataLoader,
                                   patch_size: int,
                                   whitening_regularization_factor: float,
                                   zca_whitening: bool = False,
                                   existing_model: Optional[nn.Module] = None) -> np.ndarray:
    """Calculates the whitening matrix of the patches in the given data.

    Denote the matrix of all patches by X, of shape N x D, where N is the number of patches
    and D is the dimension of each patch (channels * spatial_size ** 2).
    The returned whitening operator W is a columns-vectors matrix of shape D x D, so in order
    to whiten a data matrix X' of shape N' x D it should be multiplied from the right, i.e. X' @ W
    [and NOT from the left, i.e. NOT W @ X'].

    The dataset is traversed twice: once to get the mean patch,
    and once more to get the covariance matrix around that mean.

    Args:
        dataloader: The given data to iterate on.
        patch_size: The size of the patches to calculate the whitening on.
        whitening_regularization_factor: The regularization factor used when calculating the whitening,
            which is some small constant positive float added to the denominator.
        zca_whitening: Whether it's ZCA whitening (or PCA whitening).
        existing_model: An (optionally) existing model to call on each image in the data.

    Returns:
        The whitening matrix.
    """
    # Both passes over the dataset share everything except for the aggregation function.
    aggregate_over_patches = partial(calc_aggregated_patch, dataloader, patch_size, existing_model=existing_model)

    logger.debug('Performing a first pass over the dataset to calculate the mean patch...')
    mean_patch = aggregate_over_patches(agg_func=partial(torch.mean, dim=0))

    logger.debug('Performing a second pass over the dataset to calculate the covariance...')
    covariance_matrix = aggregate_over_patches(agg_func=partial(calc_covariance, mean=mean_patch))

    logger.debug('Calculating eigenvalues decomposition to get the whitening matrix...')
    covariance_on_cpu = covariance_matrix.cpu()
    whitening_matrix = get_whitening_matrix_from_covariance_matrix(covariance_on_cpu,
                                                                   whitening_regularization_factor,
                                                                   zca_whitening)

    logger.debug('Done.')
    return whitening_matrix
def configure_logger(out_dir: str, level='INFO', print_sink=sys.stdout):
    """
    Configure the logger:
    (1) Remove the default logger (to stdout) and use a one with a custom format.
    (2) Adds a log file named `run.log` in the given output directory.

    Args:
        out_dir: The directory in which the log file `run.log` is written.
        level: The minimal level of the messages to log (used for both sinks).
        print_sink: The sink to print the messages to, in addition to the log file.
    """
    # Calling `remove()` with no handler id removes all the handlers added so far
    # (including the default one), so a single call is enough.
    logger.remove()
    logger.add(sink=print_sink, format=LOGGER_FORMAT, level=level)
    logger.add(sink=os.path.join(out_dir, 'run.log'), format=LOGGER_FORMAT, level=level)
def get_dataloaders(batch_size: int = 64,
                    normalize_to_unit_gaussian: bool = False,
                    normalize_to_plus_minus_one: bool = False,
                    random_crop: bool = False,
                    random_horizontal_flip: bool = False,
                    random_erasing: bool = False,
                    random_resized_crop: bool = False):
    """Deprecated - this function always raises `NotImplementedError`.

    It used to create dataloaders for the CIFAR10 dataset (with optional data augmentations and
    normalization), returning a dictionary mapping "train"/"test" to its dataloader.
    The signature is kept only so existing callers fail with a clear message;
    the previous implementation was unreachable after the `raise` and was therefore removed.

    Args:
        batch_size: Ignored (formerly the size of the mini-batches).
        normalize_to_unit_gaussian: Ignored (formerly normalized the values to be a unit gaussian).
        normalize_to_plus_minus_one: Ignored (formerly normalized the values to the range [-1,1]).
        random_crop: Ignored (formerly performed padding of 4 followed by random crop).
        random_horizontal_flip: Ignored (formerly performed random horizontal flip).
        random_erasing: Ignored (formerly erased a random rectangle in the image).
        random_resized_crop: Ignored (formerly performed random resized crop).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('This function is deprecated and will be removed in the future.')
def get_model_device(model: Optional[torch.nn.Module]):
    """Returns the device of the given model, which is the device of its first parameter.

    Args:
        model: The model to get its device, possibly None.

    Returns:
        The device of the model's first parameter.
        If the model is None or has no parameters, the CPU device is returned.
    """
    cpu_device = torch.device('cpu')
    if model is None:
        return cpu_device
    # The default value is returned when the model has no parameters at all.
    first_parameter = next(model.parameters(), None)
    return cpu_device if first_parameter is None else first_parameter.device
def power_minus_1(a: torch.Tensor):
    """Returns the element-wise reciprocal of the input tensor, i.e. raises it to the power of minus 1.
    """
    ones = torch.ones_like(a)
    return ones / a
@torch.no_grad()
def get_model_output_shape(model: nn.Module, dataloader: Optional[DataLoader] = None):
    """Gets the output shape of the given model, on images from the given dataloader.

    Args:
        model: The model to get its output shape.
        dataloader: The dataloader to take a single mini-batch from, yielding (inputs, labels) pairs.
            If it's None, the "train" dataloader of `get_dataloaders(batch_size=1)` is used.

    Returns:
        A tuple of integers which is the shape of the model's output, without the batch dimension.
    """
    if dataloader is None:
        clean_dataloaders = get_dataloaders(batch_size=1)
        dataloader = clean_dataloaders["train"]

    inputs, _ = next(iter(dataloader))
    inputs = inputs.to(get_model_device(model))
    outputs = model(inputs)

    # The shape is read directly from the tensor - there is no need to copy it to the CPU as a NumPy array.
    return tuple(outputs.shape[1:])  # Remove the first dimension corresponding to the batch
def get_whitening_matrix_from_covariance_matrix(covariance_matrix: np.ndarray,
                                                whitening_regularization_factor: float,
                                                zca_whitening: bool = False) -> np.ndarray:
    """Calculates the whitening matrix from the given covariance matrix.

    The covariance matrix is decomposed using SVD (treating it as a hermitian matrix).
    The PCA whitening matrix is the left singular vectors, where each column is divided by
    the square root of the corresponding singular value (plus the regularization factor).
    The ZCA whitening matrix is the PCA whitening matrix multiplied
    from the right by the transposed singular vectors.

    Args:
        covariance_matrix: The covariance matrix.
        whitening_regularization_factor: The regularization factor used when calculating the whitening,
            which is some small constant positive float added to the denominator.
        zca_whitening: Whether it's ZCA whitening (or PCA whitening).

    Returns:
        The whitening matrix, as a float32 array.
    """
    basis, spectrum, _ = np.linalg.svd(covariance_matrix, hermitian=True)
    scales = 1. / (np.sqrt(spectrum) + whitening_regularization_factor)
    pca_whitening_matrix = basis.dot(np.diag(scales))
    result = pca_whitening_matrix @ basis.T if zca_whitening else pca_whitening_matrix
    return result.astype(np.float32)
def whiten_data(data, whitening_regularization_factor=1e-05, zca_whitening=False):
    """Whitens the given data, using the whitening matrix calculated from the data's own covariance.

    Note that the data is assumed to be of shape (n_samples, n_features), meaning it's a collection of row-vectors.

    Args:
        data: The given data to whiten.
        whitening_regularization_factor: The regularization factor used when calculating the whitening,
            which is some small constant positive float added to the denominator.
        zca_whitening: Whether it's ZCA whitening (or PCA whitening).

    Returns:
        The whitened data, i.e. the centered data multiplied from the right by the whitening matrix.
    """
    whitening_matrix = get_whitening_matrix_from_covariance_matrix(calc_covariance(data),
                                                                   whitening_regularization_factor,
                                                                   zca_whitening)
    data_mean = data.mean(axis=0)
    return (data - data_mean) @ whitening_matrix
def normalize_data(data, epsilon=1e-05):
    """Normalizes the given data, so each feature is centered (zero mean) and has (about) unit variance.

    Note that the data is assumed to be of shape (n_samples, n_features), meaning it's a collection of row-vectors.

    Args:
        data: The data to normalize.
        epsilon: Some small positive number added to the standard deviation in the denominator,
            to avoid getting NANs (if a feature has a tiny standard deviation).

    Returns:
        The normalized data.
    """
    feature_means = data.mean(axis=0)
    zero_mean_data = data - feature_means
    feature_stds = zero_mean_data.std(axis=0)
    return zero_mean_data / (feature_stds + epsilon)
def get_random_initialized_conv_kernel_and_bias(in_channels: int,
                                                out_channels: int,
                                                kernel_size: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return a freshly initialized kernel and bias of a conv layer, using PyTorch's default initialization.

    A temporary `nn.Conv2d` is constructed only to borrow its initial parameters.
    NOTE(review): this was previously described as "Xavier" initialization; the actual scheme is
    whatever `nn.Conv2d` applies by default - confirm against the PyTorch documentation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: The kernel size.

    Returns:
        A tuple (kernel, bias) of NumPy arrays, copied from the layer's weight and bias.
    """
    throwaway_layer = nn.Conv2d(in_channels, out_channels, kernel_size)

    # Copy, so the returned arrays do not share memory with the temporary layer's tensors.
    kernel, bias = (parameter.data.cpu().numpy().copy()
                    for parameter in (throwaway_layer.weight, throwaway_layer.bias))
    return kernel, bias
@torch.no_grad()
def sample_random_patches(dataloader,
                          n_patches,
                          patch_size,
                          existing_model: Optional[nn.Module] = None,
                          visualize: bool = False,
                          random_uniform_patches: bool = False,
                          random_gaussian_patches: bool = False,
                          verbose: bool = False):
    """Sample random patches from the data.

    Args:
        dataloader: The dataloader to sample patches from.
        n_patches: Number of patches to sample. If it exceeds the number of distinct patches
            in the dataset, it is reduced to that number.
        patch_size: The size of the patches to sample (-1 means the whole spatial size).
        existing_model: (Possibly) an existing model to transform each image from the dataloader.
        visualize: Whether to visualize the sampled patches (for debugging purposes).
        random_uniform_patches: Whether to avoid sampling and simply return patches drawn
            uniformly from [-1, 1). Takes precedence over `random_gaussian_patches`.
        random_gaussian_patches: Whether to avoid sampling and simply return patches whose
            entries are i.i.d. standard Gaussian.
        verbose: Whether to print progress using tqdm.

    Returns:
        The sampled patches as a float32 NumPy array, whose first dimension is the number of patches.
    """
    batch_size = dataloader.batch_size
    n_images = len(dataloader.dataset)

    # We need the shape of the images in the data.
    # In relatively small datasets (CIFAR, MNIST) the data itself is stored in `dataloader.dataset.data`
    # and in ImageNet it's not the case since the data is too large.
    # This is why the shape of ImageNet images is hard-coded.
    images_shape = dataloader.dataset.data.shape[1:] if hasattr(dataloader.dataset, 'data') else (224, 224, 3)
    if len(images_shape) == 2:  # When the dataset contains grayscale images,
        images_shape += (1,)    # add dimension of channels which will be 1.
    images_shape = np.roll(images_shape, shift=1)  # In the dataset it's H x W x C but in the model it's C x H x W

    device = None
    if existing_model is not None:
        # Patches are taken from the model's output, so its output shape replaces the raw images shape.
        device = get_model_device(existing_model)
        images_shape = get_model_output_shape(existing_model, dataloader)

    if len(images_shape) > 1:
        assert len(images_shape) == 3 and (images_shape[1] == images_shape[2]), "Should be C x H x W where H = W"
        spatial_size = int(images_shape[-1])
        if patch_size == -1:  # -1 means the patch size is the whole size of the image.
            patch_size = spatial_size
        n_patches_per_row_or_col = spatial_size - patch_size + 1
        patch_shape = (int(images_shape[0]), patch_size, patch_size)
    else:
        assert patch_size == -1, "When working with fully-connected the patch 'size' must be -1 i.e. the whole size."
        n_patches_per_row_or_col = 1
        patch_shape = tuple(images_shape)

    n_patches_per_image = n_patches_per_row_or_col ** 2
    n_patches_in_dataset = n_images * n_patches_per_image
    n_patches = min(n_patches, n_patches_in_dataset)
    patches_shape = (n_patches,) + patch_shape

    rng = np.random.default_rng()

    # Purely random patches need neither the dataset nor any index bookkeeping, so return early.
    if random_uniform_patches:
        return rng.uniform(low=-1, high=+1, size=patches_shape).astype(np.float32)
    if random_gaussian_patches:
        # I.i.d. standard-normal entries are exactly a multivariate normal with identity covariance,
        # without materializing a (patch_dim x patch_dim) covariance matrix.
        return rng.standard_normal(size=patches_shape, dtype=np.float32)

    # Draw distinct flat patch indices and decode each one to (image, x-offset, y-offset).
    patches_indices_in_dataset = rng.choice(n_patches_in_dataset, size=n_patches, replace=False)
    images_indices = patches_indices_in_dataset % n_images
    patches_indices_in_images = patches_indices_in_dataset // n_images
    patches_x_indices_in_images = patches_indices_in_images % n_patches_per_row_or_col
    patches_y_indices_in_images = patches_indices_in_images // n_patches_per_row_or_col

    # Map every image index to its position in the dataloader iteration. This arithmetic assumes
    # every batch (except possibly the last one) contains exactly `batch_size` samples.
    batches_indices = images_indices // batch_size
    images_indices_in_batches = images_indices % batch_size

    patches = np.empty(shape=patches_shape, dtype=np.float32)

    iterator = enumerate(dataloader)
    if verbose:
        iterator = tqdm(iterator, total=len(dataloader), desc='Sampling patches from the dataset')

    for batch_index, (inputs, _) in iterator:
        # One pass over the indices both detects irrelevant batches and collects the relevant patches.
        relevant_patches_indices = np.flatnonzero(batches_indices == batch_index)
        if relevant_patches_indices.size == 0:
            continue

        if existing_model is not None:
            inputs = inputs.to(device)
            inputs = existing_model(inputs)
        inputs = inputs.cpu().numpy()

        for i in relevant_patches_indices:
            image_index_in_batch = images_indices_in_batches[i]
            if len(patch_shape) > 1:
                patch_x_start = patches_x_indices_in_images[i]
                patch_y_start = patches_y_indices_in_images[i]
                patch_x_slice = slice(patch_x_start, patch_x_start + patch_size)
                patch_y_slice = slice(patch_y_start, patch_y_start + patch_size)

                patches[i] = inputs[image_index_in_batch, :, patch_x_slice, patch_y_slice]

                if visualize:
                    visualize_image_patch_pair(image=inputs[image_index_in_batch], patch=patches[i],
                                               patch_x_start=patch_x_start, patch_y_start=patch_y_start)
            else:
                # Flat (fully-connected) representation: the "patch" is the whole vector.
                patches[i] = inputs[image_index_in_batch]

    return patches
def visualize_image_patch_pair(image, patch, patch_x_start, patch_y_start):
    """Show the given image with a red rectangle around the patch, and the patch itself below it.

    Both `image` and `patch` are transposed with axes (1, 2, 0) before being displayed,
    i.e. they are expected channels-first.
    """
    side_length = patch.shape[-1]

    plt.figure()

    # Top panel: the full image, with the patch location outlined.
    image_axes = plt.subplot(2, 1, 1)
    image_axes.imshow(np.transpose(image, axes=(1, 2, 0)))
    outline = Rectangle(xy=(patch_y_start, patch_x_start),  # x and y are reversed on purpose...
                        width=side_length, height=side_length,
                        linewidth=1, edgecolor='red', facecolor='none')
    image_axes.add_patch(outline)

    # Bottom panel: the patch alone.
    patch_axes = plt.subplot(2, 1, 2)
    patch_axes.imshow(np.transpose(patch, axes=(1, 2, 0)))

    plt.show()
def cross_entropy_gradient(logits, labels):
    """
    Compute the gradient of the cross-entropy loss with respect to the input logits.
    The cross-entropy loss in PyTorch is log-softmax followed by negative log-likelihood loss,
    so the gradient equals the softmax of the logits with 1 subtracted at the true-class entry of each row.
    Note that the result is per-sample: no averaging over the batch is applied here.
    Inspiration from http://machinelearningmechanic.com/deep_learning/2019/09/04/cross-entropy-loss-derivative.html

    :param logits: The raw scores which are the input to the cross-entropy-loss (classes along dim 1).
    :param labels: The labels (for each i the index of the true class of this training-sample).
    :return: The gradient of the cross-entropy loss, a new tensor shaped like the logits.
    """
    # Class probabilities for every sample.
    probabilities = torch.softmax(logits, dim=1)

    # One column index per row: the true class of that sample.
    true_class_columns = labels.unsqueeze(-1)

    # Adding -1 at the true-class position of each row yields (softmax - one_hot).
    minus_ones = -torch.ones_like(probabilities)
    return probabilities.scatter_add(1, true_class_columns, minus_ones)
def evaluate_model_with_last_gradient(model, criterion, dataloader, device, training_step=None,
                                      log_to_wandb: bool = True):
    """
    Deprecated: calling this function always raises a DeprecationWarning.

    The legacy implementation (kept below for reference, but unreachable) evaluated the model on the test set,
    tracked each local module's performance and optionally logged it to wandb.

    :param model: The model
    :param criterion: The criterion.
    :param dataloader: The test set data-loader.
    :param device: The device to use.
    :param training_step: The training-step (integer), important to wandb logging.
    :param log_to_wandb: Whether to log to wandb or not.
    :return: Never returns (the legacy code returned the test set loss and accuracy).
    :raises DeprecationWarning: Always.
    """
    raise DeprecationWarning('This function is deprecated, and will be removed in the future.')

    # NOTE: everything below is unreachable legacy code.
    model.eval()
    modules_accumulators = [None if aux_net is None else Accumulator() for aux_net in model.auxiliary_nets]

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            aux_nets_outputs = model(inputs)
            aux_nets_losses = [None if outputs is None else criterion(outputs, labels)
                               for outputs in aux_nets_outputs]
            aux_nets_predictions = [None if outputs is None else torch.max(outputs, dim=1)[1]
                                    for outputs in aux_nets_outputs]

        # Update the corresponding accumulators to visualize the performance of each module.
        for i in range(len(model.blocks)):
            if model.auxiliary_nets[i] is None:
                continue
            n_correct = torch.sum(torch.eq(aux_nets_predictions[i], labels.data)).item()
            modules_accumulators[i].update(mean_loss=aux_nets_losses[i].item(),
                                           num_corrects=n_correct,
                                           n_samples=inputs.size(0))

    if log_to_wandb:
        assert training_step is not None
        for i, modules_accumulator in enumerate(modules_accumulators):
            if modules_accumulator is None:
                continue
            wandb.log(data=modules_accumulator.get_dict(prefix=f'module#{i}_test'), step=training_step)

    # Last one is None because last block is MaxPool with no aux-net.
    final_accumulator = modules_accumulators[-2]
    return final_accumulator.get_mean_loss(), final_accumulator.get_accuracy()
def evaluate_local_model(model, criterion, dataloader, device, training_step=None, log_to_wandb: bool = True):
    """
    Deprecated: calling this function always raises a DeprecationWarning.

    The legacy implementation (kept below for reference, but unreachable) evaluated the model on the test set
    block-by-block, tracked each local module's performance and optionally logged it to wandb.

    :param model: The model
    :param criterion: The criterion.
    :param dataloader: The test set data-loader.
    :param device: The device to use.
    :param training_step: The training-step (integer), important to wandb logging.
    :param log_to_wandb: Whether to log to wandb or not.
    :return: Never returns (the legacy code returned the test set loss and accuracy).
    :raises DeprecationWarning: Always.
    """
    raise DeprecationWarning('This function is deprecated, and will be removed in the future.')

    # NOTE: everything below is unreachable legacy code.
    model.eval()
    modules_accumulators = [None if aux_net is None else Accumulator() for aux_net in model.auxiliary_nets]
    n_modules = len(model.blocks)

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            # Feed the representation through the blocks one at a time.
            inputs_representation = inputs
            for i in range(n_modules):
                result = model(inputs_representation, first_block_index=i, last_block_index=i)
                inputs_representation, outputs = result[0], result[1]
                if outputs is None:
                    continue

                loss = criterion(outputs, labels)
                predictions = torch.max(outputs, dim=1)[1]
                modules_accumulators[i].update(mean_loss=loss.item(),
                                               num_corrects=torch.sum(torch.eq(predictions, labels.data)).item(),
                                               n_samples=inputs.size(0))

    if log_to_wandb:
        assert training_step is not None
        for i, modules_accumulator in enumerate(modules_accumulators):
            if modules_accumulator is None:
                continue
            wandb.log(data=modules_accumulator.get_dict(prefix=f'module#{i}_test'), step=training_step)

    # Last one is None because last block is MaxPool with no aux-net.
    final_accumulator = modules_accumulators[-2]
    return final_accumulator.get_mean_loss(), final_accumulator.get_accuracy()
def perform_train_step_dgl(model, inputs, labels, criterion, optimizers, training_step,
modules_accumulators,
log_interval: int = 100):
"""
Perform a train-step for a model trained with DGL.
The difference between the regular train-step and this one is that the model forward pass
is done iteratively for each block in the model, performing backward pass and optimizer step for each block
(using its corresponding auxiliary network).
:param model: The model.
:param inputs: The inputs.
:param labels: The labels.
:param criterion: The criterion.
:param optimizers: The optimizers (one for each local module in the whole model).
:param training_step: The training-step (integer), important to wandb logging.
:param modules_accumulators: Accumulators for each local module.
:param log_interval: How many training/testing steps between each logging (to wandb).
:return: The loss of this train-step, as well as the predictions.
"""
raise DeprecationWarning('This function is deprecated, and will be removed in the future.')
inputs_representation = torch.clone(inputs)
loss, predictions = None, None
for i in range(len(model.blocks)):
inputs_representation, outputs = model(inputs_representation, first_block_index=i, last_block_index=i)
if outputs is not None:
assert optimizers[i] is not None, "If the module has outputs it means it has an auxiliary-network " \
"attached so it should has trainable parameters to optimize."