3
3
from torch .nn .utils import clip_grad_norm_
4
4
import time
5
5
import sys
6
+ import random
6
7
import torchvision .models as models
7
8
from models .transformer import *
8
9
from .BigGAN_networks import *
@@ -96,7 +97,8 @@ def __init__(self, args):
96
97
super (Generator , self ).__init__ ()
97
98
self .args = args
98
99
INP_CHANNEL = self .args .num_examples
99
- if self .args .is_seq : INP_CHANNEL = 1
100
+ if self .args .is_seq :
101
+ INP_CHANNEL = 1
100
102
101
103
encoder_layer = TransformerEncoderLayer (self .args .tn_hidden_dim , self .args .tn_nheads ,
102
104
self .args .tn_dim_feedforward ,
@@ -267,12 +269,14 @@ def Eval(self, ST, QRS, QRS_pos):
267
269
hs = self .reparameterize (hs_mu , hs_logvar ).permute (1 , 0 , 2 ).unsqueeze (0 )
268
270
269
271
h = hs .transpose (1 , 2 )[- 1 ] # torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
270
- if self .args .add_noise : h = h + self .noise .sample (h .size ()).squeeze (- 1 ).to (self .args .device )
272
+ if self .args .add_noise :
273
+ h = h + self .noise .sample (h .size ()).squeeze (- 1 ).to (self .args .device )
271
274
272
275
h = self .linear_q (h )
273
276
h = h .contiguous ()
274
277
275
- if self .args .all_chars : h = torch .stack ([h [i ][QR [i ]] for i in range (self .args .self .args .batch_size )], 0 )
278
+ if self .args .all_chars :
279
+ h = torch .stack ([h [i ][QR [i ]] for i in range (self .args .self .args .batch_size )], 0 )
276
280
277
281
h = h .view (h .size (0 ), h .shape [1 ] * 2 , 4 , - 1 )
278
282
h = h .permute (0 , 3 , 2 , 1 )
@@ -282,7 +286,7 @@ def Eval(self, ST, QRS, QRS_pos):
282
286
OUT_IMGS .append (h .detach ())
283
287
284
288
return OUT_IMGS
285
-
289
+
286
290
def compute_style (self , ST ):
287
291
B , N , R , C = ST .shape
288
292
FEAT_ST = self .Feat_Encoder (ST .view (B * N , 1 , R , C ))
@@ -321,7 +325,8 @@ def forward(self, ST, QR, QRs=None, QR_pos=None, QRs_pos=None, mode='train'):
321
325
322
326
h = hs .transpose (1 , 2 )[- 1 ] # torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
323
327
324
- if self .args .add_noise : h = h + self .noise .sample (h .size ()).squeeze (- 1 ).to (self .args .device )
328
+ if self .args .add_noise :
329
+ h = h + self .noise .sample (h .size ()).squeeze (- 1 ).to (self .args .device )
325
330
326
331
h = self .linear_q (h )
327
332
h = h .contiguous ()
@@ -354,7 +359,7 @@ def __init__(self, args):
354
359
self .netW = WDiscriminator (
355
360
resolution = self .args .resolution , n_classes = self .args .vocab_size , output_dim = self .args .num_writers
356
361
).to (self .args .device )
357
-
362
+
358
363
self .netconverter = strLabelConverter (self .args .alphabet + self .args .special_alphabet )
359
364
360
365
self .netOCR = CRNN (self .args ).to (self .args .device )
@@ -409,7 +414,8 @@ def __init__(self, args):
409
414
self .eval_text_encode = self .eval_text_encode .to (self .args .device ).repeat (self .args .batch_size , 1 , 1 )
410
415
411
416
def save_images_for_fid_calculation (self , path , loader , split = 'train' ):
412
- if not isinstance (path , Path ): path = Path (path )
417
+ if not isinstance (path , Path ):
418
+ path = Path (path )
413
419
path .mkdir (exist_ok = True , parents = True )
414
420
415
421
self .real_base = path / f'Real_{ split } '
@@ -666,7 +672,7 @@ def forward(self):
666
672
self .text_encode = self .text_encode .to (self .args .device ).detach ()
667
673
self .len_text = self .len_text .detach ()
668
674
669
- self .words = [word .encode ('utf-8' ) for word in np . random .choice (self .lex , self .args .batch_size )]
675
+ self .words = [word .encode ('utf-8' ) for word in random .choices (self .lex , k = self .args .batch_size )]
670
676
self .text_encode_fake , self .len_text_fake , self .encode_pos_fake = self .netconverter .encode (self .words )
671
677
self .text_encode_fake = self .text_encode_fake .to (self .args .device )
672
678
self .one_hot_fake = make_one_hot (self .text_encode_fake , self .len_text_fake , self .args .vocab_size ).to (
@@ -676,7 +682,7 @@ def forward(self):
676
682
self .encode_pos_fake_js = []
677
683
678
684
for _ in range (self .args .num_words - 1 ):
679
- self .words_j = [word .encode ('utf-8' ) for word in np . random .choice (self .lex , self .args .batch_size )]
685
+ self .words_j = [word .encode ('utf-8' ) for word in random .choices (self .lex , k = self .args .batch_size )]
680
686
self .text_encode_fake_j , self .len_text_fake_j , self .encode_pos_fake_j = self .netconverter .encode (self .words_j )
681
687
self .text_encode_fake_j = self .text_encode_fake_j .to (self .args .device )
682
688
self .text_encode_fake_js .append (self .text_encode_fake_j )
@@ -699,7 +705,7 @@ def backward_D_OCR(self):
699
705
self .pred_real_OCR = self .netOCR (self .real .detach ())
700
706
preds_size = torch .IntTensor ([self .pred_real_OCR .size (0 )] * self .args .batch_size ).detach ()
701
707
loss_OCR_real = self .OCR_criterion (self .pred_real_OCR , self .text_encode .detach (), preds_size ,
702
- self .len_text .detach ())
708
+ self .len_text .detach ())
703
709
self .loss_OCR_real = torch .mean (loss_OCR_real [~ torch .isnan (loss_OCR_real )])
704
710
705
711
loss_total = self .loss_D + self .loss_OCR_real
@@ -844,7 +850,7 @@ def backward_G_only(self):
844
850
pred_fake_OCR = self .netOCR (self .fake )
845
851
preds_size = torch .IntTensor ([pred_fake_OCR .size (0 )] * self .args .batch_size ).detach ()
846
852
loss_OCR_fake = self .OCR_criterion (pred_fake_OCR , self .text_encode_fake .detach (), preds_size ,
847
- self .len_text_fake .detach ())
853
+ self .len_text_fake .detach ())
848
854
self .loss_OCR_fake = torch .mean (loss_OCR_fake [~ torch .isnan (loss_OCR_fake )])
849
855
850
856
self .loss_G = self .loss_G + self .Lcycle + self .lda1 + self .lda2 - self .KLD
0 commit comments