Commit b934749

Committed Nov 28, 2023
Changed from np.random.choice to random.choices
1 parent 6ef5a91

1 file changed: models/model.py (+17, -11)
@@ -3,6 +3,7 @@
 from torch.nn.utils import clip_grad_norm_
 import time
 import sys
+import random
 import torchvision.models as models
 from models.transformer import *
 from .BigGAN_networks import *
@@ -96,7 +97,8 @@ def __init__(self, args):
         super(Generator, self).__init__()
         self.args = args
         INP_CHANNEL = self.args.num_examples
-        if self.args.is_seq: INP_CHANNEL = 1
+        if self.args.is_seq:
+            INP_CHANNEL = 1

         encoder_layer = TransformerEncoderLayer(self.args.tn_hidden_dim, self.args.tn_nheads,
                                                 self.args.tn_dim_feedforward,
@@ -267,12 +269,14 @@ def Eval(self, ST, QRS, QRS_pos):
             hs = self.reparameterize(hs_mu, hs_logvar).permute(1, 0, 2).unsqueeze(0)

             h = hs.transpose(1, 2)[-1]  # torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
-            if self.args.add_noise: h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)
+            if self.args.add_noise:
+                h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)

             h = self.linear_q(h)
             h = h.contiguous()

-            if self.args.all_chars: h = torch.stack([h[i][QR[i]] for i in range(self.args.self.args.batch_size)], 0)
+            if self.args.all_chars:
+                h = torch.stack([h[i][QR[i]] for i in range(self.args.self.args.batch_size)], 0)

             h = h.view(h.size(0), h.shape[1] * 2, 4, -1)
             h = h.permute(0, 3, 2, 1)
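For context: `self.reparameterize(hs_mu, hs_logvar)` in this hunk is, judging by its name and arguments, the usual VAE reparameterization trick. The diff does not show its body; a minimal sketch of what such a method typically looks like:

import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mu + sigma * eps with eps ~ N(0, I); sampling through eps keeps
    # the result differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std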
@@ -282,7 +286,7 @@ def Eval(self, ST, QRS, QRS_pos):
             OUT_IMGS.append(h.detach())

         return OUT_IMGS
-
+
     def compute_style(self, ST):
         B, N, R, C = ST.shape
         FEAT_ST = self.Feat_Encoder(ST.view(B * N, 1, R, C))
@@ -321,7 +325,8 @@ def forward(self, ST, QR, QRs=None, QR_pos=None, QRs_pos=None, mode='train'):

         h = hs.transpose(1, 2)[-1]  # torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)

-        if self.args.add_noise: h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)
+        if self.args.add_noise:
+            h = h + self.noise.sample(h.size()).squeeze(-1).to(self.args.device)

         h = self.linear_q(h)
         h = h.contiguous()
@@ -354,7 +359,7 @@ def __init__(self, args):
         self.netW = WDiscriminator(
             resolution=self.args.resolution, n_classes=self.args.vocab_size, output_dim=self.args.num_writers
         ).to(self.args.device)
-
+
         self.netconverter = strLabelConverter(self.args.alphabet + self.args.special_alphabet)

         self.netOCR = CRNN(self.args).to(self.args.device)
@@ -409,7 +414,8 @@ def __init__(self, args):
         self.eval_text_encode = self.eval_text_encode.to(self.args.device).repeat(self.args.batch_size, 1, 1)

     def save_images_for_fid_calculation(self, path, loader, split='train'):
-        if not isinstance(path, Path): path = Path(path)
+        if not isinstance(path, Path):
+            path = Path(path)
         path.mkdir(exist_ok=True, parents=True)

         self.real_base = path / f'Real_{split}'
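Incidentally, the `isinstance` guard being reformatted here is redundant: `Path()` accepts a `Path` as well as a `str` and returns an equivalent object, so the conversion can be unconditional. A hypothetical simplification (not part of this commit; `ensure_dir` is an illustrative name):

from pathlib import Path

def ensure_dir(path) -> Path:
    # Path(path) is effectively a no-op when path is already a Path,
    # so no isinstance check is needed before converting.
    path = Path(path)
    path.mkdir(exist_ok=True, parents=True)
    return path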
@@ -666,7 +672,7 @@ def forward(self):
         self.text_encode = self.text_encode.to(self.args.device).detach()
         self.len_text = self.len_text.detach()

-        self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, self.args.batch_size)]
+        self.words = [word.encode('utf-8') for word in random.choices(self.lex, k=self.args.batch_size)]
         self.text_encode_fake, self.len_text_fake, self.encode_pos_fake = self.netconverter.encode(self.words)
         self.text_encode_fake = self.text_encode_fake.to(self.args.device)
         self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, self.args.vocab_size).to(
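This is the substantive change of the commit. Both calls sample `batch_size` words with replacement (`np.random.choice` defaults to `replace=True`), so the sampling semantics are preserved; `random.choices` operates on the Python list directly and yields plain `str` objects rather than NumPy scalars, and it avoids converting the lexicon to an ndarray on every call, which can be faster for large lists. A quick illustration with a stand-in lexicon:

import random
import numpy as np

lex = ['hello', 'world', 'foo', 'bar']  # stand-in for self.lex
batch_size = 3

old = np.random.choice(lex, batch_size)  # numpy array of numpy.str_, with replacement
new = random.choices(lex, k=batch_size)  # plain list of str, with replacement

print(type(old[0]), type(new[0]))  # <class 'numpy.str_'> <class 'str'>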
@@ -676,7 +682,7 @@ def forward(self):
         self.encode_pos_fake_js = []

         for _ in range(self.args.num_words - 1):
-            self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, self.args.batch_size)]
+            self.words_j = [word.encode('utf-8') for word in random.choices(self.lex, k=self.args.batch_size)]
             self.text_encode_fake_j, self.len_text_fake_j, self.encode_pos_fake_j = self.netconverter.encode(self.words_j)
             self.text_encode_fake_j = self.text_encode_fake_j.to(self.args.device)
             self.text_encode_fake_js.append(self.text_encode_fake_j)
@@ -699,7 +705,7 @@ def backward_D_OCR(self):
         self.pred_real_OCR = self.netOCR(self.real.detach())
         preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.args.batch_size).detach()
         loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size,
-                                           self.len_text.detach())
+                                           self.len_text.detach())
         self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])

         loss_total = self.loss_D + self.loss_OCR_real
@@ -844,7 +850,7 @@ def backward_G_only(self):
         pred_fake_OCR = self.netOCR(self.fake)
         preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.args.batch_size).detach()
         loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size,
-                                           self.len_text_fake.detach())
+                                           self.len_text_fake.detach())
         self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])

         self.loss_G = self.loss_G + self.Lcycle + self.lda1 + self.lda2 - self.KLD
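The `torch.mean(loss[~torch.isnan(loss)])` pattern in this hunk and the previous one averages only the finite per-sample losses; CTC-style criteria with `reduction='none'` can return NaN for degenerate samples (e.g. a target longer than the input sequence), and one such value would otherwise poison the whole batch loss. Whether `OCR_criterion` is exactly `nn.CTCLoss` is not shown in the diff; the masking pattern itself, in isolation:

import torch

# Hypothetical per-sample losses with NaN entries, as a CTC criterion
# with reduction='none' can produce for degenerate samples.
loss = torch.tensor([1.2, float('nan'), 0.7, float('nan')])

# Keep only the finite entries before averaging.
mean_loss = torch.mean(loss[~torch.isnan(loss)])
print(mean_loss)  # tensor(0.9500)

On recent PyTorch versions, `torch.nanmean(loss)` expresses the same intent in a single call.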
