#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 17 12:34:11 2021
@author: John S. Hyatt
"""
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import (Model,
metrics)
from tensorflow.keras.layers import (Input,
Activation,
LeakyReLU,
LayerNormalization,
Convolution2D,
Concatenate)
from tensorflow.keras.backend import int_shape
from conv_cINN_base_functions import dilated_residual_block
import numpy as np
###############################################################################
"""
With a naive implementation, training starts at NaN loss for too-large networks (really, anything that isn't unrealistically small) due to log_prob(z) = log_prob(f(x)) blowing up when the number of dimensions is not very small. This can be fixed by using the Orthogonal kernel initializer (using layer normalization provides some benefit as well).
For datasets such as MNIST, where many elements are always the same (e.g. the black pixels at the edges), dim(X) << dim(Z) (even without taking into account the fact that a particular real dataset is a low-dimensional embedding in R^N), and the map can't be bijective. Training completely destabilizes once annealing finishes. However, including a small amount of noise in X resolves this problem: even though most of the dimensions are still meaningless, they are no longer unchanging (effectively, it replaces the "delta distributions" of unchanging black background pixels with narrow, but non-delta, distributions of pixel values). The gradients the model receives from those dimensions are still meaningless, but they no longer destabilize training, and the noise can be removed during inference. This step is not needed when the dataset does not have pixels of fixed value (like a background).
In the original RealNVP paper, the authors replace x -> logit(a + (1-a)* x/x_max), where a is a small floor (say 0.01), x_max is the highest allowed value of x, and logit(p) = log(p / (1-p)). This apparently serves two purposes:
1) Probability is defined over an unbounded domain, but many x (e.g., images) have a hard boundary, for example pixel values in the range [0,255]. The logit maps the bounded domain (0,1) to (-inf, inf), making the domain unbounded.
2) An unrealistically high prediction of logit-intensity will still yield a realistic intensity (either the minimum or maximum) when de-logit-ized. That is, the generated distribution of intensity should more closely match the true intensity distribution than the generated distribution of logit-intensity matches the true logit-intensity distribution.
I have experimented with modeling x directly, as there's "not much" difference between a hard bounded domain and an unbounded one whose probability density outside the boundary is very low. Because the majority of values logit(p) takes are closely associated with either p~0 or p~1, this may overemphasize high/low pixel intensities.
On the other hand, experimentally, this seems to allow unnaturally high/low pixel values in latent space and, correspondingly, pixel values in x space outside of the allowed boundaries, due to the nonzero density that in practice is modeled outside of the bounded regions. (This can be ambiguous depending on the image preprocessing: generally, it is impossible to assign bounds to the true data distribution with perfect certainty.) Logit-izing the intensity hides this problem; an equivalent solution is just to set intensities below or above the allowed values to the min or max.
They also applied L2 weight normalization in addition to (modified) batch normalization. As pointed out in [T. van Laarhoven, "L2 Regularization versus Batch and Weight Normalization" (2017)], it doesn't make sense to do both (or to use weight normalization with layer norm, for that matter).
As a result I have NOT used the L2 weight normalization. Instead I have just replaced their modified batch norm with regular layer norm.
Some of this code is based on the https://github.com/taesungp/real-nvp/blob/master/real_nvp/nn.py implementation of the architecture described in https://arxiv.org/pdf/1605.08803.pdf.
"""
"""
NOTE: On some hardware I have observed a speedup when @tf.function-decorating the forward_and_Jacobian() and backward() methods on the new layers defined below. On other hardware I have observed no speedup or even a slight slowdown. I've had the same experience when decorating the call() method in cFlow() instead. I've left both sets of decorators in this code, but commented; uncomment them if you want to experiment.
"""
###############################################################################
"""
GENERIC LAYER CLASS
"""
# Adding some default methods to the Layer class. These will be called by the Layers used to build the model.
# This does NOT remove anything from the Layer class, only adds to it.
class Layer(tf.keras.layers.Layer):
"""
Used to define other Layers later. This is an abstract class that can propagate in both directions.
In the forward direction, it passes along u (the input), the log determinant of the Jacobian, and z (specifically, the already-factored-out z components).
In the backward direction, it passes along v (the input) and z (specifically, the factored-out z components that haven't yet been factored back in).
"""
# =============================================================================
# @tf.function
# =============================================================================
def forward_and_Jacobian(self,
u, # the input in a chain x -> z
sum_log_det_J, # the total Jacobian so far
z): # the already factored-out z so far
# It doesn't do anything (this is just an abstract class)
raise NotImplementedError(str(type(self)))
# =============================================================================
# @tf.function
# =============================================================================
def backward(self,
v, # the input in a chain z -> x
z): # the factored-out z that hasn't yet been reintegrated
# It doesn't do anything (this is just an abstract class)
raise NotImplementedError(str(type(self)))
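# A minimal, hedged sketch of the contract above (not used by the model): a
# pass-through Layer implementing both abstract methods without changing its
# inputs. The class name is illustrative only.
class _identity_layer_example(Layer):
    def forward_and_Jacobian(self, u, sum_log_det_J, z):
        # Identity map: log|det J| = 0, so sum_log_det_J is unchanged and
        # nothing is factored out into z.
        return u, sum_log_det_J, z
    def backward(self, v, z):
        return v, z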
###############################################################################
"""
A LAYER FOR PERFORMING SCALAR MULTIPLICATION OF THE MODEL INPUTS
"""
class tanh_scaling_layer(Layer):
"""
Multiplies the inputs by a trainable scalar variable.
"""
# Layer class takes kwargs (name and dtype). Best practice is to pass these to the parent class.
def __init__(self, **kwargs):
super(tanh_scaling_layer, self).__init__(**kwargs)
# Add the tanh scale, initialized at 1.
# Not defining a shape defaults it to a scalar.
def build(self, input_shape):
initializer = tf.keras.initializers.Ones()
self.w = self.add_weight(initializer=initializer,
trainable=True)
def call(self, inputs):
return tf.math.scalar_mul(self.w,
inputs)
# Because this Layer defines __init__(), get_config() needs to be redefined to serialize the model.
# This will include the kwargs passed to the parent class in __init__().
def get_config(self):
config = super(tanh_scaling_layer, self).get_config()
return config
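# Hedged usage sketch (not called anywhere in this file): the layer holds a
# single trainable scalar, initialized to 1, so it starts out as an identity.
def _tanh_scaling_example():
    scale = tanh_scaling_layer()
    x = tf.ones((2, 8, 8, 3))
    return scale(x)  # equal to x until self.w is trained away from 1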
###############################################################################
"""
THE SPATIAL SQUEEZING AND z FACTORING LAYERS
"""
class squeeze_layer(Layer):
"""
In the forward direction, this Layer spatially halves (in both dimensions, quadrupling the channel depth) both the input and any already-factored-out zy, while passing along log_det_J.
In the backward direction, this Layer spatially doubles (in both dimensions, quartering the channel depth) both the input and any not-yet-refactored-in zy.
Note that, unlike the original RealNVP, the "output" side is zy, not z! This is a conditional model, so y has to be passed through as well.
"""
# Layer class takes kwargs (name and dtype). Best practice is to pass these to the parent class.
def __init__(self, **kwargs):
super(squeeze_layer, self).__init__(**kwargs)
# Because this Layer defines __init__(), get_config() needs to be redefined to serialize the model.
# This will include the kwargs passed to the parent class in __init__().
def get_config(self):
config = super(squeeze_layer, self).get_config()
return config
# From XY -> ZY (U -> V)
# =============================================================================
# @tf.function
# =============================================================================
def forward_and_Jacobian(self,
u,
sum_log_det_J,
zy):
"""
Args:
u: the input, going in the xy -> zy direction.
sum_log_det_J: the total sum of the Jacobian contributions to the loss so far.
zy: the already-factored-out zy so far.
Returns:
v: the output, going in the xy -> zy direction, reshaped but not otherwise changed from u by this Layer.
sum_log_det_J: the total sum of the Jacobian contributions to the loss so far (not changed by this Layer).
zy: the already-factored-out zy so far, reshaped but not otherwise changed by this Layer.
"""
u_shape = int_shape(u)
# In order to be halved spatially, the spatial dimensions must be divisible by 2.
assert u_shape[1] % 2 == 0 \
and u_shape[2] % 2 == 0, \
'u must have spatial dimensions divisible by 2.'
v = tf.nn.space_to_depth(u, 2) # halving the spatial dimensions
# Need to keep zy spatial dims the same so new zy factors can be concatenated on.
if zy is not None:
zy = tf.nn.space_to_depth(zy, 2) # halving each spatial dimension
return v, sum_log_det_J, zy
# From ZY -> XY (V -> U)
# =============================================================================
# @tf.function
# =============================================================================
def backward(self,
v, # the input in a chain zy -> xy
zy): # the factored-out zy that hasn't yet been reintegrated
"""
Args:
v: the input, going in the zy -> xy direction.
zy: the factored-out zy that hasn't yet been reintegrated.
Returns:
u: the output, going in the zy -> xy direction, reshaped but not otherwise changed from v by this Layer.
zy: the factored-out zy so far, reshaped but not otherwise changed by this Layer.
"""
v_shape = int_shape(v)
# Doubling both spatial dimensions means quartering the channel depth.
assert v_shape[3] % 4 == 0, \
'v must have channel dimensions divisible by 4.'
u = tf.nn.depth_to_space(v, 2) # doubling the spatial dimensions
# Need to "undo" the compression of zy as it propagates back through
if zy is not None:
zy = tf.nn.depth_to_space(zy, 2)
return u, zy
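# Hedged shape sketch (not used by the model below): the squeeze halves both
# spatial dimensions and quadruples the channel depth; backward() inverts it
# exactly. The function name is illustrative only.
def _squeeze_shape_example():
    layer = squeeze_layer()
    u = tf.random.normal((1, 8, 8, 3))
    v, _, _ = layer.forward_and_Jacobian(u, 0.0, None)  # v: (1, 4, 4, 12)
    u_back, _ = layer.backward(v, None)                 # u_back: (1, 8, 8, 3)
    return v, u_back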
class factor_out_zy_layer(Layer):
"""
In the forward direction, this Layer factors out half of the non-zy channels into zy.
In the backward direction, this Layer factors as many channels back in from zy.
Note that the backward direction DOES perform the first unfactor, while the forward direction DOES NOT perform the final concatenation to zy.
"""
# Layer class takes kwargs (name and dtype). Best practice is to pass these to the parent class.
def __init__(self,
num_prev_factors,
**kwargs):
"""
Args:
num_prev_factors: how many times the model has already had part of zy factored out, going in the XY -> ZY direction.
"""
super(factor_out_zy_layer, self).__init__(**kwargs)
self.num_prev_factors = num_prev_factors
# Because this Layer defines __init__(), get_config() needs to be redefined to serialize the model.
# This will include the kwargs passed to the parent class in __init__().
def get_config(self):
config = super(factor_out_zy_layer, self).get_config()
# Add the new arguments into config.
config.update({'num_prev_factors' : self.num_prev_factors})
return config
# From XY -> ZY (U -> V)
# =============================================================================
# @tf.function
# =============================================================================
def forward_and_Jacobian(self,
u,
sum_log_det_J,
zy):
"""
Args:
u: the input, going in the xy -> zy direction.
sum_log_det_J: the total sum of the Jacobian contributions to the loss so far.
zy: the already-factored-out zy so far.
Returns:
v: the output, going in the xy -> zy direction, now with additional zy factored out from u.
sum_log_det_J: the total sum of the Jacobian contributions to the loss so far (not changed by this Layer).
zy: the already-factored-out zy so far concatenated with the zy factored out by this Layer.
"""
u_shape = int_shape(u)
# Half of the channels will be factored into zy.
split = u_shape[3] // 2
factored_zy = u[..., :split]
v = u[..., split:]
# If this is the first time, there's no prior zy to concatenate to. Otherwise, concatenate onto the already-factored-out zy.
if zy is not None:
zy = tf.concat([zy, factored_zy],
axis=3)
else:
zy = factored_zy
return v, sum_log_det_J, zy
# From ZY -> XY (V -> U)
# =============================================================================
# @tf.function
# =============================================================================
def backward(self,
v,
zy):
"""
Args:
v: the input, going in the zy -> xy direction.
zy: the factored-out zy that hasn't yet been reintegrated.
Returns:
u: the output, going in the zy -> xy direction, now with additional components reintegrated from zy.
zy: the factored-out zy minus the part that was reintegrated by this Layer.
"""
# At scale s, (1/2)^(s+1) of the original dimensions have been factored out.
# For a backwards pass at scale s, (1/2)^s of zy should be factored back in.
zy_shape = int_shape(zy)
if v is None: # the last layer is all zy
split = zy_shape[3] // (2**self.num_prev_factors)
else: # the input has some v and some not-yet-refactored zy
split = int_shape(v)[3]
reintegrated_v = zy[..., -split:]
zy = zy[..., :-split]
# Redundant assertion.
assert int_shape(reintegrated_v)[3] == split
if v is not None:
u = tf.concat([reintegrated_v, v], 3)
else:
u = reintegrated_v
return u, zy
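# Hedged shape sketch (not used by the model below): the forward pass splits
# the first half of the channels off into zy; the backward pass takes the same
# number of channels back off the end of zy and concatenates them in front of
# v, reconstructing the original tensor. The function name is illustrative.
def _factor_out_example():
    layer = factor_out_zy_layer(num_prev_factors=0)
    u = tf.random.normal((1, 4, 4, 8))
    v, _, zy = layer.forward_and_Jacobian(u, 0.0, None)  # v: (1,4,4,4), zy: (1,4,4,4)
    u_back, zy_rest = layer.backward(v, zy)              # u_back == u, zy_rest: (1,4,4,0)
    return u_back, zy_rest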
###############################################################################
"""
DEFINE A SCALING COUPLING LAYER
"""
class coupling_layer(Layer):
"""
RealNVP coupling layer, with modifications:
The most important one is that instead of x and z, the end states of the total model are xy and zy. As a result, additional dimensions propagate through each layer!
Almost as important, because xy and zy have more than one dimension, channelwise masking and coupling is performed BEFORE the squeeze/factor operation in each block. I also only include one of each mask (in order 0, 1, 2, 3) rather than multiples (in order 0, 1, 0, 2, 3, 2). This does not directly affect anything INSIDE each coupling layer, but it does affect the way they are structured in the stack.
The masks are constructed differently, although conceptually they are the same. In the original paper, the mask was applied to the input (zeroing half of its elements), which was then fed into the coupling function. The mask was then re-applied to the output of the coupling function to "manually" re-zero those elements, before combining it with the other half of the input (obtained via complementary mask). This results in many useless computations (they are just zeroed out by the second, post-coupling mask) AND makes inefficient use of the convolutional kernels, since either 50% of spatial dimensions or 50% of channel dimensions are zero. It might also lead to the kernels learning weird behavior (which would interfere with LayerNorm, since that would be applied before re-masking), complicate analysis of intermediate representations, and for the checkerboard masks in particular, would lead to kernels covering a smaller spatial dimension than necessary. As a result, I compress uv1 (the inputs to A and b) and uv2 so that the calculations v2 = A(u1)*u2 + b(u1) and u2 = A^-1(v1)*(v2 - b(v1)) can be performed without all those extraneous zeros. Note that this does make the MASKING process more computationally intensive, AND memory intensive: only the compressed version of uv2 is required, but both the compressed AND uncompressed versions of uv1 are required (and must be stored simultaneously). vu2 must then be decompressed and added to the uncompressed vu1 = uv1.
A side benefit from this is that the checkerboard and channelwise masks have different spatial scales, meaning that this implementation probably has somewhat better inductive biases than the original.
Because the model maps XY -> ZY and the log-likelihood-of-Y-equivalent loss term requires access to Y, I don't want to just apply squeezing and factoring to generate a 1D vector Z: the original spatial relationships between pixels are important. Although that loss term does not REQUIRE Y to preserve the spatial relationships for the choice of metric I use (L norm), it might for other choices of metric. As a result, after fully squeezing and factoring ZY, I reshape it to match the original shape of XY. (This might be a target for efficiency improvements in an actual applied implementation of this model.)
I use LayerNorm instead of the original paper's modified BatchNorm.
I do not use L2 weight normalization due to already using LayerNorm. L2 weight normalization and normalization schemes like batch or layer normalization are redundant; see van Laarhoven, "L2 Regularization versus Batch and Weight Normalization" (2017).
I do not use tanh plus a learnable scaling parameter.
Logits on f_X are included by default, but they can be replaced with plain f_X by commenting out the labeled lines in the code below.
I include dilated convolutions in the ResNeXt blocks that define A and b.
I use LeakyReLU instead of ReLU.
"""
# Layer class takes kwargs (name and dtype). Best practice is to pass these to the parent class.
def __init__(self,
in_shape,
which_mask,
num_res_blocks,
cardinality,
num_kernels,
kernel_size,
init,
LAYER_NORM=False,
which_dilations=[1,2,4],
**kwargs):
"""
Args:
in_shape: a list with the shape of the input tensor, [height, width, channel_depth]. Height and width are specified so that the model can be built before calling the coupling function. Depth is specified because if the number of channels in the input is odd (only the case before the first squeeze/factoring layer, if at all), then masks 2 and 3 split the extra channel differently. (The depth doesn't affect anything if the number of channels in the input is even.)
which_mask: either 0, 1, 2, or 3
if 0: mask = [1,1,...,1] [0,0,...,0] [1,1,...,1]...
[0,0,...,0] [1,1,...,1] [0,0,...,0]...
[1,1,...,1] [0,0,...,0] [1,1,...,1]...
... ... ...
elif 1: mask = [0,0,...,0] [1,1,...,1] [0,0,...,0]...
[1,1,...,1] [0,0,...,0] [1,1,...,1]...
[0,0,...,0] [1,1,...,1] [0,0,...,0]...
... ... ...
elif 2: mask = [1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
[1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
[1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
... ... ...
elif 3: mask = [0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
[0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
[0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
... ... ...
num_res_blocks: number of residual blocks in each of the neural networks defining A and b.
cardinality: cardinality of the residual blocks.
num_kernels: number of kernels per convolutional layer. Note that this number is halved for masks 0 and 1, whose compressed forms have twice as many channels/half as much spatial extent. I therefore assign half as many kernels to these layers as the corresponding channelwise-masked inputs.
kernel_size: size of the convolutional kernels.
init: the kernel initializer.
LAYER_NORM: Boolean. Whether or not the residual blocks apply layer normalization.
which_dilations: A list of the dilation factors to be used in parallel in the residual block. For example, [1,2,4] has, in parallel: no dilation, dilation with 1 zero between nonzero elements, and dilation with 3 zeros. With 3x3 kernels, [1,2,4] gives overlapping receptive fields up to 9x9 and [1,2,4,8] gives receptive fields up to 17x17. With 4x4 kernels, [1,3] gives overlapping receptive fields up to 10x10.
Note that A and b have the same hyperparameters within a given coupling layer (for a given model, profiling might show that this is not necessary/efficient!) and larger dilations within a layer are given progressively fewer kernels (the assumption being that the larger-length-scale correlations between successive layers can be well incorporated with relatively few kernels; moreover, a model's receptive field typically grows as it gets deeper, and the intermediate representations' spatial scales decrease every time the model is squeezed/factored).
"""
super(coupling_layer, self).__init__(**kwargs)
# Hyperparameters for defining A and b.
# In this code, these are the same for both A and b, and for all coupling blocks, except that fewer factors listed in which_dilations will be used for smaller-spatial-dimension inputs. A specific implementation would probably benefit from more customized tuning.
self.input_height = in_shape[0]
self.input_width = in_shape[1]
self.input_depth = in_shape[2]
self.which_mask = which_mask
self.num_res_blocks = num_res_blocks
self.cardinality = cardinality
self.kernel_size = kernel_size
self.init = init
self.LAYER_NORM = LAYER_NORM
self.which_dilations = which_dilations
assert self.input_height % 2 == 0 \
and self.input_width % 2 == 0, \
'u/v must have spatial dimensions divisible by 2.'
# Masks 0 and 1 correspond to checkerboard masks, whose compressed forms have twice as many channels/half as much spatial extent, and therefore half as many kernels, as the channelwise-masked inputs.
if self.which_mask in [0,1]:
self.num_kernels = int(num_kernels / 2)
elif self.which_mask in [2,3]:
self.num_kernels = num_kernels
# self.which_mask defines the mask used to obtain uv1. The mask used to obtain uv2 is its complement.
if self.which_mask==0:
self.which_mask_complement = 1
elif self.which_mask==1:
self.which_mask_complement = 0
elif self.which_mask==2:
self.which_mask_complement = 3
elif self.which_mask==3:
self.which_mask_complement = 2
# Get the shape of the MASKED, COMPRESSED input to the coupling function.
self.get_masked_compressed_shape()
# Build the neural networks A and b used in the coupling function.
self.model_A, self.model_b = self.coupling_function()
# Because this Layer defines __init__(), get_config() needs to be redefined to serialize the model.
# This will include the kwargs passed to the parent class in __init__().
def get_config(self):
config = super(coupling_layer, self).get_config()
# Add the new arguments into config.
config.update({'input_height' : self.input_height,
'input_width' : self.input_width,
'input_depth' : self.input_depth,
'which_mask' : self.which_mask,
'num_res_blocks' : self.num_res_blocks,
'cardinality' : self.cardinality,
'kernel_size' : self.kernel_size,
'init' : self.init,
'LAYER_NORM' : self.LAYER_NORM,
'which_dilations' : self.which_dilations})
return config
def A_wrapper(self,
A_input):
"""
A wrapper for model A to allow it to be used inside of a TF graph.
"""
return self.model_A(A_input)
def b_wrapper(self,
b_input):
"""
A wrapper for model b to allow it to be used inside of a TF graph.
"""
return self.model_b(b_input)
def get_masked_compressed_shape(self):
"""
Function for obtaining the shape of the (masked, compressed) u1/v1 input to the coupling function, given the full input shape and the mask type.
"""
if self.which_mask in [0,1]:
self.compressed_height = int(self.input_height / 2)
self.compressed_width = int(self.input_width / 2)
self.compressed_depth = 2 * self.input_depth
elif self.which_mask in [2,3]:
self.compressed_height = self.input_height
self.compressed_width = self.input_width
# Masks 2 and 3 have different depths if the total number of channels is odd.
if self.which_mask == 2:
self.compressed_depth = int(np.ceil(self.input_depth/2))
elif self.which_mask == 3:
self.compressed_depth = int(np.floor(self.input_depth/2))
def mask(self,
uv,
which_mask_index,
compress):
"""
Args:
uv = input of shape (batch_size, uv_h, uv_w, uv_d). Going in the forward direction, uv = u; going in the backward direction, uv = v.
batch_size: number of elements in the batch
uv_h: the height of the input tensor (currently, must be even).
uv_w: the width of the input tensor (currently, must be even).
uv_d: the depth of the input tensor (can be even or odd).
which_mask_index: one of 0,1,2,3
if 0: mask = [1,1,...,1] [0,0,...,0] [1,1,...,1]...
[0,0,...,0] [1,1,...,1] [0,0,...,0]...
[1,1,...,1] [0,0,...,0] [1,1,...,1]...
... ... ...
elif 1: mask = [0,0,...,0] [1,1,...,1] [0,0,...,0]...
[1,1,...,1] [0,0,...,0] [1,1,...,1]...
[0,0,...,0] [1,1,...,1] [0,0,...,0]...
... ... ...
elif 2: mask = [1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
[1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
[1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
... ... ...
elif 3: mask = [0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
[0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
[0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
... ... ...
compress: Boolean. Removes the masked (zeroed) entries if True.
Returns:
uv_masked: Denoting specific elements in uv by
A1 B1 A1 B1... A2 B2 A2 B2... A3 B3 A3 B3... ...
C1 D1 C1 D1... C2 D2 C2 D2... C3 D3 C3 D3... ...
A1 B1 A1 B1... A2 B2 A2 B2... A3 B3 A3 B3... ...
C1 D1 C1 D1... C2 D2 C2 D2... C3 D3 C3 D3... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
with 1, 2, 3, ... denoting channels:
if compress==True:
if which_mask==0:
this produces an output tensor with shape
(batch, uv_h/2, uv_w/2, uv_d*2):
A1 A1... A2 A2... ... D1 D1... D2 D2... ...
A1 A1... A2 A2... ... D1 D1... D2 D2... ...
.. .. .. .. ... .. .. .. .. ...
elif which_mask==1:
this produces an output tensor with shape
(batch, uv_h/2, uv_w/2, uv_d*2):
B1 B1... B2 B2... ... C1 C1... C2 C2... ...
B1 B1... B2 B2... ... C1 C1... C2 C2... ...
.. .. .. .. ... .. .. .. .. ...
elif which_mask==2:
this produces an output tensor with shape
(batch, uv_h, uv_w, ceil(uv_d/2)):
A1 B1 A1 B1... A3 B3 A3 B3... A5 B5 A5 B5... ...
C1 D1 C1 D1... C3 D3 C3 D3... C5 D5 C5 D5... ...
A1 B1 A1 B1... A3 B3 A3 B3... A5 B5 A5 B5... ...
C1 D1 C1 D1... C3 D3 C3 D3... C5 D5 C5 D5... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
elif which_mask==3:
this produces an output tensor with shape
(batch, uv_h, uv_w, floor(uv_d/2)):
A2 B2 A2 B2... A4 B4 A4 B4... ...
C2 D2 C2 D2... C4 D4 C4 D4... ...
A2 B2 A2 B2... A4 B4 A4 B4... ...
C2 D2 C2 D2... C4 D4 C4 D4... ...
.. .. .. .. .. .. .. .. ...
(The smallest possible channel depth is 2, with mask==2 corresponding to channel 1, and mask==3 corresponding to channel 2.)
elif compress==False:
this produces an output tensor with shape (batch, uv_h, uv_w, uv_d):
if which_mask==0:
A1 00 A1 00... A2 00 A2 00... A3 00 A3 00... ...
00 D1 00 D1... 00 D2 00 D2... 00 D3 00 D3... ...
A1 00 A1 00... A2 00 A2 00... A3 00 A3 00... ...
00 D1 00 D1... 00 D2 00 D2... 00 D3 00 D3... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
elif which_mask==1:
00 B1 00 B1... 00 B2 00 B2... 00 B3 00 B3... ...
C1 00 C1 00... C2 00 C2 00... C3 00 C3 00... ...
00 B1 00 B1... 00 B2 00 B2... 00 B3 00 B3... ...
C1 00 C1 00... C2 00 C2 00... C3 00 C3 00... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
elif which_mask==2:
A1 B1 A1 B1... 00 00 00 00... A3 B3 A3 B3... ...
C1 D1 C1 D1... 00 00 00 00... C3 D3 C3 D3... ...
A1 B1 A1 B1... 00 00 00 00... A3 B3 A3 B3... ...
C1 D1 C1 D1... 00 00 00 00... C3 D3 C3 D3... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
elif which_mask==3:
00 00 00 00... A2 B2 A2 B2... 00 00 00 00... ...
00 00 00 00... C2 D2 C2 D2... 00 00 00 00... ...
00 00 00 00... A2 B2 A2 B2... 00 00 00 00... ...
00 00 00 00... C2 D2 C2 D2... 00 00 00 00... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
"""
# The input tensor must have the correct spatial and channel dimensions.
uv = tf.ensure_shape(uv,shape=[None,
self.input_height,
self.input_width,
self.input_depth])
(batch_size,
uv_h,
uv_w,
uv_d) = int_shape(uv)
# Only need to explicitly include the checkerboard mask zeros if the output is not compressed, since if it is compressed it just skips over the zeroed elements.
if not compress:
if which_mask_index==0:
ones = tf.ones(uv_d,)
zeros = tf.zeros(uv_d,)
mask = tf.stack([[
[ones,
zeros],
[zeros,
ones]
]])
elif which_mask_index==1:
ones = tf.ones(uv_d,)
zeros = tf.zeros(uv_d,)
mask = tf.stack([[
[zeros,
ones],
[ones,
zeros]
]])
elif which_mask_index==2:
one_indices = tf.range(0,
uv_d,
2)
one_indices = tf.expand_dims(one_indices,
axis=-1)
# mask==2 has the extra channel if there are an odd number, so use the ceiling function.
ones = tf.ones(
tf.cast(
tf.math.ceil(uv_d/2),
dtype=tf.int32))
mask_shape = tf.constant([uv_d])
mask = tf.scatter_nd(one_indices,
ones,
mask_shape)
mask = tf.stack([[
[mask,
mask],
[mask,
mask]
]])
elif which_mask_index==3:
one_indices = tf.range(1,
uv_d,
2)
one_indices = tf.expand_dims(one_indices,
axis=-1)
# mask==3 doesn't have the extra channel if there are an odd number, so use the floor function.
ones = tf.ones(
tf.cast(
tf.math.floor(uv_d/2),
dtype=tf.int32))
mask_shape = tf.constant([uv_d])
mask = tf.scatter_nd(one_indices,
ones,
mask_shape)
mask = tf.stack([[
[mask,
mask],
[mask,
mask]
]])
# Tile to the correct height/width
mask = tf.tile(mask,
[1,
int(uv_h/2),
int(uv_w/2),
1])
# The mask must go over the entire batch.
# Using the einsum method is much more efficient than explicitly repeating the mask `batch_size` times.
mask = mask[0]
uv_masked = tf.einsum('jkl,ijkl->ijkl',
mask,
uv)
# If we are compressing it, just skip over the unwanted components.
elif compress:
# For checkerboard masking, we need to skip every other element AND stack the offset remaining elements in the checkerboard.
if which_mask_index in [0,1]:
# Note the repeating indices.
if which_mask_index==0:
uv_c0 = uv[:,
0::2,
0::2,
:]
uv_c1 = uv[:,
1::2,
1::2,
:]
# Note the alternating indices.
elif which_mask_index==1:
uv_c0 = uv[:,
0::2,
1::2,
:]
uv_c1 = uv[:,
1::2,
0::2,
:]
uv_masked = Concatenate(axis=-1)([uv_c0,
uv_c1])
# For channel-wise masking, we just drop the unwanted channels.
elif which_mask_index in [2,3]:
if which_mask_index==2:
uv_masked = uv[...,
0::2]
elif which_mask_index==3:
uv_masked = uv[...,
1::2]
return uv_masked
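# A hedged, concrete illustration of the compressed checkerboard masking above
# (comments only, nothing executed): for a 4x4, single-channel input and
# mask 0, the two strided slices pick out the two checkerboard offsets and
# stack them as channels, halving each spatial dimension and doubling the
# depth:
#
#   uv[:, 0::2, 0::2, :]  ->  the "A" sites, shape (batch, 2, 2, 1)
#   uv[:, 1::2, 1::2, :]  ->  the "D" sites, shape (batch, 2, 2, 1)
#   Concatenate(axis=-1)  ->  shape (batch, 2, 2, 2)
#
# Mask 1 uses the complementary offsets ([0::2, 1::2] and [1::2, 0::2]), and
# masks 2/3 simply keep the even/odd channels at full spatial resolution.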
def decompress_mask(self,
uv_masked_compressed,
which_mask_index,
uv_shape_OUTPUT):
"""
Args:
uv_masked_compressed = compressed tensor input of shape (batch_size, uv_h_c, uv_w_c, uv_d_c). Going in the forward direction, uv = u; going in the backward direction, uv = v.
batch_size: number of elements in the batch.
uv_h_c: the height of the COMPRESSED input tensor.
uv_w_c: the width of the COMPRESSED input tensor.
uv_d_c: the depth of the COMPRESSED input tensor.
which_mask_index: one of 0,1,2,3
if 0: mask = [1,1,...,1] [0,0,...,0] [1,1,...,1]...
[0,0,...,0] [1,1,...,1] [0,0,...,0]...
[1,1,...,1] [0,0,...,0] [1,1,...,1]...
... ... ...
elif 1: mask = [0,0,...,0] [1,1,...,1] [0,0,...,0]...
[1,1,...,1] [0,0,...,0] [1,1,...,1]...
[0,0,...,0] [1,1,...,1] [0,0,...,0]...
... ... ...
elif 2: mask = [1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
[1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
[1,0,1,0,...] [1,0,1,0,...] [1,0,1,0,...]...
... ... ...
elif 3: mask = [0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
[0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
[0,1,0,1,...] [0,1,0,1,...] [0,1,0,1,...]...
... ... ...
uv_shape_OUTPUT = tuple of shape (batch_size, uv_h, uv_w, uv_d) defining the UNCOMPRESSED output shape. This is necessary because the depth of the uncompressed tensor (although not the height or width) can be odd. If this is the case, masks 2 and 3 will have different output depths; otherwise they will have the same output depth.
Returns:
uv_masked_uncompressed: the uncompressed version of uv_masked_compressed.
if mask in [0,1]: uv_masked_compressed has shape
(batch, uv_h/2, uv_w/2, uv_d*2)
elif mask==2: uv_masked_compressed has shape
(batch, uv_h, uv_w, ceil(uv_d/2))
elif mask==3: uv_masked_compressed has shape
(batch, uv_h, uv_w, floor(uv_d/2))
uv_masked_uncompressed has shape
(batch, uv_h, uv_w, uv_d).
NOTE: uv_h and uv_w are both currently required to be even. uv_d may be odd, in which case uv_masked_compressed has 1 more channel for mask==2, or 1 less channel for mask==3.
Denoting specific elements in uv by
A1 B1 A1 B1... A2 B2 A2 B2... A3 B3 A3 B3... ...
C1 D1 C1 D1... C2 D2 C2 D2... C3 D3 C3 D3... ...
A1 B1 A1 B1... A2 B2 A2 B2... A3 B3 A3 B3... ...
C1 D1 C1 D1... C2 D2 C2 D2... C3 D3 C3 D3... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
with 1, 2, 3, ... denoting channels:
if which_mask==0:
this converts an input tensor
A1 A1... A2 A2... ... D1 D1... D2 D2... ...
A1 A1... A2 A2... ... D1 D1... D2 D2... ...
.. .. .. .. ... .. .. .. .. ...
to output tensor
A1 00 A1 00... A2 00 A2 00... ...
00 D1 00 D1... 00 D2 00 D2... ...
A1 00 A1 00... A2 00 A2 00... ...
00 D1 00 D1... 00 D2 00 D2... ...
.. .. .. .. .. .. .. .. ...
elif which_mask==1:
this converts an input tensor
B1 B1... B2 B2... ... C1 C1... C2 C2... ...
B1 B1... B2 B2... ... C1 C1... C2 C2... ...
.. .. .. .. ... .. .. .. .. ...
to output tensor
00 B1 00 B1... 00 B2 00 B2... ...
C1 00 C1 00... C2 00 C2 00... ...
00 B1 00 B1... 00 B2 00 B2... ...
C1 00 C1 00... C2 00 C2 00... ...
.. .. .. .. .. .. .. .. ...
elif which_mask==2:
this converts an input tensor
A1 B1 A1 B1... A3 B3 A3 B3... ...
C1 D1 C1 D1... C3 D3 C3 D3... ...
A1 B1 A1 B1... A3 B3 A3 B3... ...
C1 D1 C1 D1... C3 D3 C3 D3... ...
.. .. .. .. .. .. .. .. ...
to output tensor
A1 B1 A1 B1... 00 00 00 00... A3 B3 A3 B3... ...
C1 D1 C1 D1... 00 00 00 00... C3 D3 C3 D3... ...
A1 B1 A1 B1... 00 00 00 00... A3 B3 A3 B3... ...
C1 D1 C1 D1... 00 00 00 00... C3 D3 C3 D3... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
elif which_mask==3:
this converts an input tensor
A2 B2 A2 B2... ...
C2 D2 C2 D2... ...
A2 B2 A2 B2... ...
C2 D2 C2 D2... ...
.. .. .. .. ...
to output tensor
00 00 00 00... A2 B2 A2 B2... 00 00 00 00... ...
00 00 00 00... C2 D2 C2 D2... 00 00 00 00... ...
00 00 00 00... A2 B2 A2 B2... 00 00 00 00... ...
00 00 00 00... C2 D2 C2 D2... 00 00 00 00... ...
.. .. .. .. .. .. .. .. .. .. .. .. ...
"""
(batch_size,
uv_h_c,
uv_w_c,
uv_d_c) = int_shape(uv_masked_compressed)
(batch_size,
uv_h,
uv_w,
uv_d) = uv_shape_OUTPUT
# Expanding the checkerboard mask.
if which_mask_index in [0,1]:
# Input shape is (batch, u_h_c, u_w_c, u_d_c) = (batch, u_h/2, u_w/2, u_d*2).
assert uv_d_c % 2 == 0, \
'The compressed, checkerboard-masked u/v should always have an even number of channels.'
# Each of the two offset checkerboards takes up half of uv_d_c.
uv_c0 = uv_masked_compressed[...,
:uv_d]
uv_c1 = uv_masked_compressed[...,
uv_d:]
# These indices determine the location of the compressed elements and the inserted zeros during expansion. Note that for expanding mask_0-compressed tensors, both expansions per channel are performed with the same index applied twice in sequence, while for the mask_1-compressed tensors, they are applied alternately.
# For non-square tensors, each dimension needs its own set of 0 and 1 checkerboard indices.
# NOTE: the "0" and "1" in h0, h1, w0, and w1 do NOT correspond to masks 0 and 1! They are indices for the two possible checkerboard offsets:
# 1 0 1 0 0 1 0 1
# 0 1 0 1 and 1 0 1 0
# 1 0 1 0 0 1 0 1
# 0 1 0 1 1 0 1 0
# The indices 0 and 1 go like:
# 0 1 0 1
# 0 X X X X
# 1 X X X X
# 0 X X X X
# 1 X X X X
# with h going across and w going down.
# In the compressed masked representation, zero entries in the masked input have been removed in favor of splitting alternating checkerboards into channels:
# mask 0 h0, w0 h1, w1
# 1 0 1 0 1 0 1 0 0 0 0 0
# 0 1 0 1 -> 0 0 0 0 + 0 1 0 1
# 1 0 1 0 1 0 1 0 0 0 0 0
# 0 1 0 1 0 0 0 0 0 1 0 1
# mask 1 h0, w1 h1, w0
# 0 1 0 1 0 1 0 1 0 0 0 0
# 1 0 1 0 -> 0 0 0 0 + 1 0 1 0
# 0 1 0 1 0 1 0 1 0 0 0 0
# 1 0 1 0 0 0 0 0 1 0 1 0
indices_h0 = tf.range(start=0,
limit=2*uv_h_c,
delta=2)
indices_h1 = tf.range(start=1,
limit=2*uv_h_c+1,
delta=2)
indices_w0 = tf.range(start=0,
limit=2*uv_w_c,
delta=2)
indices_w1 = tf.range(start=1,
limit=2*uv_w_c+1,
delta=2)
indices_h0 = tf.expand_dims(indices_h0,
axis=-1) # shape (uv_h/2,1)
indices_w0 = tf.expand_dims(indices_w0,
axis=-1) # shape (uv_w/2,1)
indices_h1 = tf.expand_dims(indices_h1,
axis=-1) # shape (uv_h/2,1)
indices_w1 = tf.expand_dims(indices_w1,
axis=-1) # shape (uv_w/2,1)
####################################
# Expand compressed "A"/"B" channel.
####################################
# In order to use scatter_nd, transpose the batch dimension to the end.
updates = tf.transpose(uv_c0,
[1,2,3,0]) # shape (uv_h/2,uv_w/2,uv_d,batch)
# Expand by a factor of 2 in one of the two to-be-expanded dimensions.
shape = [2,1,1,1] * tf.shape(updates) # shape (uv_h,uv_w/2,uv_d,batch)
# `indices` tells scatter_nd where to put the zeros it is inserting.
# `updates` is the input that scatter_nd is expanding.
# `shape` tells scatter_nd what the output shape should be.
# `which_mask` determines whether the indices are identical or complementary. For this particular transform, it is the same for both masks.
scatter = tf.scatter_nd(indices_h0,
updates,
shape) # shape (uv_h,uv_w/2,uv_d,batch)
# Next, we need to expand the other dimension, so transpose those two and repeat the above steps.
updates = tf.transpose(scatter,
[1,0,2,3]) # shape (uv_w/2,uv_h,uv_d,batch)
shape = [2,1,1,1] * tf.shape(updates) # shape (uv_w,uv_h,uv_d,batch)
# This one does differ depending on the mask.
if which_mask_index==0:
scatter = tf.scatter_nd(indices_w0,
updates,
shape) # shape (uv_w,uv_h,uv_d,batch)
elif which_mask_index==1:
scatter = tf.scatter_nd(indices_w1,
updates,
shape) # shape (uv_w,uv_h,uv_d,batch)
# Now we have to transpose back into (batch_dim, height, width, channels)
uv_c0 = tf.transpose(scatter,
[3,1,0,2]) # shape (batch,uv_h,uv_w,uv_d)
####################################
# Expand compressed "C"/"D" channel.
####################################
# Exactly the same as before, except that the indices are different.
updates = tf.transpose(uv_c1,
[1,2,3,0]) # shape (uv_h/2,uv_w/2,uv_d,batch)
shape = [2,1,1,1] * tf.shape(updates) # shape (uv_h,uv_w/2,uv_d,batch)