-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathModules.py
538 lines (455 loc) · 19.4 KB
/
Modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
from typing import List
import torch
import math
from argparse import Namespace # for type
import logging
from nvlabs.torch_utils.ops import conv2d_gradfix
from nvlabs.torch_utils.ops.conv2d_gradfix import conv2d as nvlabs_conv2d, no_weight_gradients
# Module-level switch on nvlabs' conv2d_gradfix; its `conv2d` (nvlabs_conv2d) and
# `no_weight_gradients` helpers are used by Conv2d and Gradient_Penalty in this file.
# NOTE(review): presumably required for the custom conv path to take effect — confirm
# against the nvlabs torch_utils implementation.
conv2d_gradfix.enabled = True
class HifiSinger(torch.nn.Module):
    '''
    Acoustic model: Encoder -> Duration_Predictor (length regulation) -> Decoder.

    Predicts mel frames, a silence/voice probability and a pitch contour per frame,
    together with the raw per-token duration predictions.
    '''
    def __init__(self, hyper_parameters: Namespace):
        super(HifiSinger, self).__init__()
        self.hp = hyper_parameters

        self.layer_Dict = torch.nn.ModuleDict()
        self.layer_Dict['Encoder'] = Encoder(self.hp)
        self.layer_Dict['Duration_Predictor'] = Duration_Predictor(self.hp)
        self.layer_Dict['Decoder'] = Decoder(self.hp)

    def forward(
        self,
        durations,
        tokens,
        notes,
        token_lengths= None,    # token_length == duration_length == note_length
        ):
        '''
        durations: [Batch, Token_t] frames per token. Every row must sum to the same total:
            durations[0].sum() is used as the frame length of the whole batch and the
            per-item pitch tracks are stacked.
        tokens: [Batch, Token_t]
        notes: [Batch, Token_t]
        token_lengths: [Batch] or None. When None, no encoder padding mask is used.

        Returns (mels [Batch, Mel_dim, Frame_t], silences [Batch, Frame_t] in [0, 1],
        pitches [Batch, Frame_t], predicted_durations [Batch, Token_t]).
        Padded frames hold -Max_Abs_Mel (mels) and 0.0 (silences, pitches).
        '''
        encoder_Masks = None
        if token_lengths is not None:
            encoder_Masks = self.Mask_Generate(
                lengths= token_lengths,
                max_lengths= tokens.size(1)
                )

        encodings = self.layer_Dict['Encoder'](
            tokens= tokens,
            durations= durations,
            notes= notes,
            masks= encoder_Masks
            )
        encodings, predicted_Durations = self.layer_Dict['Duration_Predictor'](
            encodings= encodings,
            durations= durations
            )

        # The last token's frames are excluded from the valid length.
        # NOTE(review): presumably the last token is padding that absorbs the remainder
        # of each row — confirm against the collater.
        decoder_Masks = self.Mask_Generate(
            lengths= durations[:, :-1].sum(dim= 1),
            max_lengths= durations[0].sum()
            )

        predicted_Mels, predicted_Silences, predicted_Pitches = self.layer_Dict['Decoder'](
            encodings= encodings,
            masks= decoder_Masks
            )

        # The decoder output is added on top of the frame-expanded score notes / Max_Note.
        predicted_Pitches = predicted_Pitches + torch.stack([
            note.repeat_interleave(duration) / self.hp.Max_Note
            for note, duration in zip(notes, durations)
            ], dim= 0)

        # Out-of-place masking so autograd sees it: padded frames become constants and
        # receive no gradient. The sigmoid is applied before masking so that padded
        # silence frames are really 0.0 (masking the logits with 0.0 gave 0.5).
        predicted_Mels = predicted_Mels.masked_fill(decoder_Masks.unsqueeze(1), -self.hp.Sound.Max_Abs_Mel)
        predicted_Silences = torch.sigmoid(predicted_Silences).masked_fill(decoder_Masks, 0.0)  # 0.0 -> Silence, 1.0 -> Voice
        predicted_Pitches = predicted_Pitches.masked_fill(decoder_Masks, 0.0)

        return predicted_Mels, predicted_Silences, predicted_Pitches, predicted_Durations

    def Mask_Generate(self, lengths, max_lengths= None):
        '''
        lengths: [Batch] valid lengths.
        max_lengths: size of the time axis; defaults to lengths.max(). 0 is honoured.
        Returns a [Batch, Time] bool mask that is True at padded positions (index >= length).
        '''
        if max_lengths is None:
            max_lengths = torch.max(lengths)
        sequence = torch.arange(max_lengths, device= lengths.device)[None, :]
        return sequence >= lengths[:, None]  # [Batch, Time]
class Discriminators(torch.nn.Module):
    '''
    Ensemble of Discriminator modules, one per entry of hp.Discriminator.Frequency_Range.

    All members share Stacks / Channels / Kernel_Size and differ only in the mel-bin
    band they look at.
    '''
    def __init__(self, hyper_parameters: Namespace) -> None:
        super().__init__()
        self.hp = hyper_parameters

        config = self.hp.Discriminator
        self.layer_Dict = torch.nn.ModuleDict({
            f'Discriminator_{position}': Discriminator(
                stacks= config.Stacks,
                channels= config.Channels,
                kernel_size= config.Kernel_Size,
                frequency_range= band
                )
            for position, band in enumerate(config.Frequency_Range)
            })

    def forward(
        self,
        x: torch.FloatTensor,
        lengths: torch.LongTensor
        ):
        '''
        x: mel batch passed unchanged to every member (Discriminator reads it as
            [Batch, Mel_dim, Time]).
        lengths: [Batch] valid frame counts.
        Returns a list with one output per frequency range, in configuration order.
        '''
        outputs = []
        for position in range(len(self.hp.Discriminator.Frequency_Range)):
            member = self.layer_Dict[f'Discriminator_{position}']
            outputs.append(member(x, lengths))
        return outputs
class Discriminator(torch.nn.Module):
    '''
    Convolutional discriminator over a random time crop of one mel-bin band.

    stacks - 1 bias-free Conv2d + LeakyReLU(0.2) layers with `channels` filters,
    followed by a 1x1 Conv2d projection to a single channel.
    '''
    def __init__(
        self,
        stacks: int,
        kernel_size: int,
        channels: int,
        frequency_range: List[int]
        ) -> None:
        super(Discriminator, self).__init__()
        self.frequency_Range = frequency_range

        self.layer = torch.nn.Sequential()
        in_Channels = 1
        for index in range(stacks - 1):
            convolution = Conv2d(
                in_channels= in_Channels,
                out_channels= channels,
                kernel_size= kernel_size,
                bias= False,
                w_init_gain= 'linear'
                )
            activation = torch.nn.LeakyReLU(negative_slope= 0.2, inplace= True)
            self.layer.add_module(f'Conv_{index}', convolution)
            self.layer.add_module(f'Leaky_ReLU_{index}', activation)
            in_Channels = channels

        projection = Conv2d(
            in_channels= in_Channels,
            out_channels= 1,
            kernel_size= 1,
            bias= True,
            w_init_gain= 'linear'
            )
        self.layer.add_module('Projection', projection)

    def forward(
        self,
        x: torch.FloatTensor,
        lengths: torch.LongTensor
        ):
        '''
        x: [Batch, Mel_dim, Time]
        lengths: [Batch]

        Each item is cropped to the shortest length in the batch, at a random offset
        drawn from torch's global RNG, and restricted to mel bins
        frequency_Range[0]:frequency_Range[1].
        Returns [Batch, Sampled_Dim, Min_Time] (assuming the conv stack keeps spatial size).
        '''
        crop_Length = lengths.min()
        low_Bin, high_Bin = self.frequency_Range[0], self.frequency_Range[1]

        crops = []
        for mel, length in zip(x, lengths):
            start = torch.randint(
                low= 0,
                high= length - crop_Length + 1,
                size= (1,)
                ).to(x.device)
            crops.append(mel[low_Bin:high_Bin, start:start + crop_Length])

        stacked = torch.stack(crops).unsqueeze(dim= 1)  # [Batch, 1, Sampled_Dim, Min_Time]
        return self.layer(stacked).squeeze(dim= 1)
class Encoder(torch.nn.Module):
    '''
    Token-level encoder.

    Sums phoneme, duration and note embeddings (each of size hp.Encoder.Size), adds
    sinusoidal positions (dropout 0.0) and applies hp.Encoder.FFT_Block.Stacks FFT blocks.
    '''
    def __init__(self, hyper_parameters: Namespace):
        super(Encoder, self).__init__()
        self.hp = hyper_parameters
        size = self.hp.Encoder.Size
        block = self.hp.Encoder.FFT_Block

        self.layer_Dict = torch.nn.ModuleDict()
        # (layer name, number of embedding rows)
        vocabularies = (
            ('Phoneme_Embedding', self.hp.Tokens),
            ('Duration_Embedding', self.hp.Max_Duration),
            ('Note_Embedding', self.hp.Max_Note),
            )
        for name, vocabulary_Size in vocabularies:
            self.layer_Dict[name] = torch.nn.Embedding(
                num_embeddings= vocabulary_Size,
                embedding_dim= size,
                )

        self.layer_Dict['Positional_Embedding'] = Sinusoidal_Positional_Embedding(
            channels= size,
            dropout= 0.0
            )

        for index in range(block.Stacks):
            self.layer_Dict[f'FFT_Block_{index}'] = FFT_Block(
                in_channels= size,
                heads= block.Heads,
                dropout_rate= block.Dropout_Rate,
                ff_in_kernel_size= block.FeedForward.In_Kernel_Size,
                ff_out_kernel_size= block.FeedForward.Out_Kernel_Size,
                ff_channels= block.FeedForward.Channels,
                )

    def forward(
        self,
        tokens: torch.LongTensor,
        durations: torch.LongTensor,
        notes: torch.LongTensor,
        masks: torch.BoolTensor= None
        ):
        '''
        tokens, durations, notes: [Batch, Token_t] integer indices. Durations index an
            embedding with hp.Max_Duration rows, so each value must be < Max_Duration.
        masks: [Batch, Token_t] bool key-padding mask forwarded to every FFT block, or None.
        Returns [Batch, Encoder.Size, Token_t].
        '''
        embedded = [
            self.layer_Dict[name](indices).transpose(2, 1)  # [Batch, Channels, Time]
            for name, indices in (
                ('Phoneme_Embedding', tokens),
                ('Duration_Embedding', durations),
                ('Note_Embedding', notes),
                )
            ]
        x = self.layer_Dict['Positional_Embedding'](embedded[0] + embedded[1] + embedded[2])

        for index in range(self.hp.Encoder.FFT_Block.Stacks):
            x = self.layer_Dict[f'FFT_Block_{index}'](x, masks)

        return x  # [Batch, Channels, Time]
class Duration_Predictor(torch.nn.Module):
    '''
    Predicts per-token durations and length-regulates the encodings (each token's
    encoding is repeated `duration` times along time).

    For each entry of hp.Duration_Predictor.Conv: Conv1d -> LayerNorm -> ReLU -> Dropout,
    then a 1x1 Conv1d + ReLU projection to one non-negative value per token.
    '''
    def __init__(self, hyper_parameters: Namespace):
        super(Duration_Predictor, self).__init__()
        self.hp = hyper_parameters

        self.layer_Dict = torch.nn.ModuleDict()
        previous_Channels = self.hp.Encoder.Size
        for index, (kernel_Size, channels) in enumerate(zip(
            self.hp.Duration_Predictor.Conv.Kernel_Size,
            self.hp.Duration_Predictor.Conv.Channels
            )):
            self.layer_Dict[f'Conv_{index}'] = Conv1d(
                in_channels= previous_Channels,
                out_channels= channels,
                kernel_size= kernel_Size,
                padding= (kernel_Size - 1) // 2,
                w_init_gain= 'relu'
                )
            self.layer_Dict[f'LayerNorm_{index}'] = torch.nn.LayerNorm(
                normalized_shape= channels
                )
            self.layer_Dict[f'ReLU_{index}'] = torch.nn.ReLU()
            self.layer_Dict[f'Dropout_{index}'] = torch.nn.Dropout(
                p= self.hp.Duration_Predictor.Conv.Dropout_Rate
                )
            previous_Channels = channels

        self.layer_Dict['Projection'] = torch.nn.Sequential()
        self.layer_Dict['Projection'].add_module('Conv', Conv1d(
            in_channels= previous_Channels,
            out_channels= 1,
            kernel_size= 1,
            w_init_gain= 'relu'
            ))
        self.layer_Dict['Projection'].add_module('ReLU', torch.nn.ReLU())

    def forward(
        self,
        encodings: torch.FloatTensor,
        durations: torch.LongTensor= None
        ):
        '''
        encodings: [Batch, Channels, Token_t]
        durations: [Batch, Token_t] frames per token, or None to derive them from the
            prediction (inference).
        Returns (regulated encodings [Batch, Channels, Frame_t],
                 predicted durations [Batch, Token_t]).
        '''
        x = encodings
        for index in range(len(self.hp.Duration_Predictor.Conv.Kernel_Size)):
            x = self.layer_Dict[f'Conv_{index}'](x)
            # LayerNorm normalises the last axis, so move channels there and back.
            x = self.layer_Dict[f'LayerNorm_{index}'](x.transpose(2, 1)).transpose(2, 1)
            x = self.layer_Dict[f'ReLU_{index}'](x)
            x = self.layer_Dict[f'Dropout_{index}'](x)
        # Drop the single projection channel: [Batch, 1, Token_t] -> [Batch, Token_t].
        predicted_Durations = self.layer_Dict['Projection'](x).squeeze(1)

        if durations is None:
            durations = self._Durations_from_Prediction(predicted_Durations)

        x = torch.stack([
            encoding.repeat_interleave(duration, dim= 1)
            for encoding, duration in zip(encodings, durations)
            ], dim= 0)

        return x, predicted_Durations

    def _Durations_from_Prediction(self, predicted_Durations: torch.Tensor) -> torch.LongTensor:
        '''
        Converts predicted durations [Batch, Token_t] into integer frame counts whose
        rows all sum to the same total, so the regulated sequences can be stacked.
        '''
        durations = predicted_Durations.ceil().long().clamp(0, self.hp.Max_Duration)
        # A row predicted entirely as zero would give an empty sequence; use 1 frame per token.
        durations = torch.stack([
            torch.ones_like(duration) if duration.sum() == 0 else duration
            for duration in durations
            ], dim= 0)

        # Longest row plus one extra frame, so every row's last token grows by at least 1.
        max_Durations = durations.sum(dim= 1).max() + 1
        # NOTE(review): this compares the *total* frame count with Max_Duration, which
        # elsewhere sizes the per-token duration embedding — confirm intent.
        if max_Durations > self.hp.Max_Duration:  # I assume this means failing
            return torch.ones_like(durations)

        # Extend each row's last token so that every row sums to max_Durations.
        return torch.cat([
            durations[:, :-1],
            durations[:, -1:] + max_Durations - durations.sum(dim= 1, keepdim= True)
            ], dim= 1)
class Decoder(torch.nn.Module):
    '''
    Frame-level decoder: sinusoidal positions, hp.Decoder.FFT_Block.Stacks FFT blocks,
    then a 1x1 Conv1d projecting to Mel_Dim mel channels + 1 silence logit + 1 pitch value.
    '''
    def __init__(self, hyper_parameters: Namespace):
        super(Decoder, self).__init__()
        self.hp = hyper_parameters

        self.layer_Dict = torch.nn.ModuleDict()
        self.layer_Dict['Positional_Embedding'] = Sinusoidal_Positional_Embedding(
            channels= self.hp.Encoder.Size
            )
        # (An empty, unused 'FFT_Block' Sequential used to be registered here; it held no
        # parameters, so removing it does not change the state_dict.)
        for index in range(self.hp.Decoder.FFT_Block.Stacks):
            self.layer_Dict[f'FFT_Block_{index}'] = FFT_Block(
                in_channels= self.hp.Encoder.Size,
                heads= self.hp.Decoder.FFT_Block.Heads,
                dropout_rate= self.hp.Decoder.FFT_Block.Dropout_Rate,
                ff_in_kernel_size= self.hp.Decoder.FFT_Block.FeedForward.In_Kernel_Size,
                ff_out_kernel_size= self.hp.Decoder.FFT_Block.FeedForward.Out_Kernel_Size,
                ff_channels= self.hp.Decoder.FFT_Block.FeedForward.Channels,
                )

        self.layer_Dict['Projection'] = Conv1d(
            in_channels= self.hp.Encoder.Size,
            out_channels= self.hp.Sound.Mel_Dim + 1 + 1,
            kernel_size= 1,
            w_init_gain= 'linear'
            )

    def forward(
        self,
        encodings: torch.FloatTensor,
        masks: torch.BoolTensor
        ):
        '''
        encodings: [Batch, Channels, Frame_t] length-regulated encodings.
        masks: [Batch, Frame_t] bool key-padding mask, forwarded to every FFT block.
        Returns (mels [Batch, Mel_Dim, Frame_t], silence logits [Batch, Frame_t],
                 pitches [Batch, Frame_t]).
        '''
        x = encodings
        x = self.layer_Dict['Positional_Embedding'](x)
        # Iterate over the Decoder's own stack count (the blocks built in __init__);
        # looping over the Encoder's count raised KeyError or skipped blocks.
        for index in range(self.hp.Decoder.FFT_Block.Stacks):
            x = self.layer_Dict[f'FFT_Block_{index}'](x, masks= masks)
        x = self.layer_Dict['Projection'](x)

        mels, silences, notes = torch.split(
            x,
            split_size_or_sections= [self.hp.Sound.Mel_Dim, 1, 1],
            dim= 1
            )

        return mels, silences.squeeze(1), notes.squeeze(1)
class FFT_Block(torch.nn.Module):
    '''
    Post-norm feed-forward Transformer block on channels-first input:

        self-attention + residual -> LayerNorm -> Dropout -> (zero padded frames)
        -> Conv1d / ReLU / Conv1d / Dropout + residual -> LayerNorm -> (zero padded frames)
    '''
    def __init__(
        self,
        in_channels: int,
        heads: int,
        dropout_rate: float,
        ff_in_kernel_size: int,
        ff_out_kernel_size: int,
        ff_channels: int
        ):
        super(FFT_Block, self).__init__()

        # Registration order is kept fixed: it determines parameter and state_dict order.
        layers = torch.nn.ModuleDict()
        layers['Multihead_Attention'] = torch.nn.MultiheadAttention(
            embed_dim= in_channels,
            num_heads= heads
            )
        layers['LayerNorm_0'] = torch.nn.LayerNorm(normalized_shape= in_channels)
        layers['Dropout'] = torch.nn.Dropout(p= dropout_rate)

        feed_Forward = torch.nn.Sequential()
        feed_Forward.add_module('Conv_0', Conv1d(
            in_channels= in_channels,
            out_channels= ff_channels,
            kernel_size= ff_in_kernel_size,
            padding= (ff_in_kernel_size - 1) // 2,
            w_init_gain= 'relu'
            ))
        feed_Forward.add_module('ReLU', torch.nn.ReLU())
        feed_Forward.add_module('Conv_1', Conv1d(
            in_channels= ff_channels,
            out_channels= in_channels,
            kernel_size= ff_out_kernel_size,
            padding= (ff_out_kernel_size - 1) // 2,
            w_init_gain= 'linear'
            ))
        feed_Forward.add_module('Dropout', torch.nn.Dropout(p= dropout_rate))
        layers['Conv'] = feed_Forward

        layers['LayerNorm_1'] = torch.nn.LayerNorm(normalized_shape= in_channels)
        self.layer_Dict = layers

    def forward(self, x: torch.FloatTensor, masks: torch.BoolTensor= None):
        '''
        x: [Batch, Channels, Time]
        masks: [Batch, Time] bool key-padding mask (True positions are ignored by attention
            and zeroed in the output), or None.
        Returns [Batch, Channels, Time].
        '''
        # MultiheadAttention (batch_first=False) expects [Time, Batch, Channels].
        attended = self.layer_Dict['Multihead_Attention'](
            query= x.permute(2, 0, 1),
            key= x.permute(2, 0, 1),
            value= x.permute(2, 0, 1),
            key_padding_mask= masks
            )[0]
        x = attended.permute(1, 2, 0) + x
        x = self.layer_Dict['LayerNorm_0'](x.transpose(2, 1)).transpose(2, 1)
        x = self.layer_Dict['Dropout'](x)

        keep = None if masks is None else torch.logical_not(masks).unsqueeze(1).float()
        if keep is not None:
            x = x * keep

        x = self.layer_Dict['Conv'](x) + x
        x = self.layer_Dict['LayerNorm_1'](x.transpose(2, 1)).transpose(2, 1)
        if keep is not None:
            x = x * keep

        return x
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class Sinusoidal_Positional_Embedding(torch.nn.Module):
    '''
    Adds fixed sinusoidal position encodings to a channels-first sequence, then dropout.

    channels: feature size. Odd values are supported: the last (sine) column then has no
        cosine partner.
    dropout: dropout probability applied after the addition.
    max_len: longest supported sequence length.
    '''
    def __init__(self, channels, dropout=0.1, max_len=5000):
        super(Sinusoidal_Positional_Embedding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, channels)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, channels, 2).float() * (-math.log(10000.0) / channels))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are channels // 2 odd columns, one fewer than div_term when channels is odd.
        pe[:, 1::2] = torch.cos(position * div_term[:channels // 2])
        pe = pe.unsqueeze(0).transpose(2, 1)    # [1, Channels, max_len]
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [Batch, Channels, Time], Time <= max_len.
        '''
        x = x + self.pe[:, :, :x.size(2)]
        return self.dropout(x)
class Conv2d(torch.nn.Conv2d):
    '''
    Conv2d with runtime weight scaling, run through nvlabs' conv2d.

    Weights are initialised from N(0, 1) and multiplied by 1 / sqrt(fan_in) at every
    forward (fan_in = in_channels / groups * kernel_h * kernel_w). Bias is initialised to 0.

    w_init_gain: stored on the module but not used by this class.
    clamp: if given, outputs are clamped in place to [-clamp, clamp].
    Remaining args/kwargs go to torch.nn.Conv2d.

    forward computes its own padding, int((kernel - stride) / 2) per axis; any `padding`
    passed to the constructor is not used there. `dilation` is not applied by forward.
    '''
    def __init__(self, w_init_gain= 'relu', clamp: float=None, *args, **kwargs):
        self.w_init_gain = w_init_gain
        self.clamp = clamp
        super().__init__(*args, **kwargs)

        # fan_in of one output unit; each group only sees in_channels / groups inputs.
        fan_In = (self.in_channels // self.groups) * self.kernel_size[0] * self.kernel_size[1]
        self.runtime_Coef = 1.0 / math.sqrt(fan_In)

    def reset_parameters(self):
        torch.nn.init.normal_(self.weight, mean=0.0, std= 1.0)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor):
        x = nvlabs_conv2d(
            input= x,
            weight= self.weight.to(x.device) * self.runtime_Coef,
            stride= self.stride,
            # Each axis uses its own stride (the width term previously used stride[0]).
            padding= (
                int((self.kernel_size[0] - self.stride[0]) / 2),
                int((self.kernel_size[1] - self.stride[1]) / 2)
                ),
            groups= self.groups
            )   # [Batch, Out, Resolution, Resolution]
        if self.bias is not None:
            x += self.bias.to(x.device)[None, :, None, None]
        if self.clamp is not None:
            x.clamp_(-self.clamp, self.clamp)

        return x
class Conv1d(Conv2d):
    '''
    1-D convolution implemented as Conv2d over a height-1 input.

    kernel_size must be passed as a keyword (int); it becomes a (1, kernel_size) kernel.
    forward takes [Batch, Channels, Time] and returns [Batch, Out_Channels, Time'].
    '''
    def __init__(self, w_init_gain= 'relu', clamp: float=None, *args, **kwargs):
        width = kwargs['kernel_size']
        kwargs['kernel_size'] = (1, width)
        super().__init__(w_init_gain, clamp, *args, **kwargs)

    def forward(self, x: torch.Tensor):
        as_Image = x.unsqueeze(2)   # [Batch, Channels, 1, Time]
        convolved = super().forward(as_Image)
        return convolved.squeeze(2)
class Gradient_Penalty(torch.nn.Module):
    '''
    Gradient penalty on real inputs (the form of the R1 penalty):

        mean over batch of  gamma / 2 * || d(sum discriminations) / d reals ||^2
    '''
    def __init__(
        self,
        gamma: float= 10.0
        ) -> None:
        super().__init__()
        self.gamma = gamma

    def forward(
        self,
        reals: torch.Tensor,
        discriminations: torch.Tensor,
        ) -> torch.Tensor:
        '''
        reals: [Batch, Channels, Time]. Real mels; must require grad and be part of the
            graph that produced `discriminations`.
        discriminations: discriminator outputs on `reals`; summed, so any shape works.
        Returns a scalar penalty (create_graph=True, so it can be backpropagated).
        '''
        with no_weight_gradients():
            input_Gradients, = torch.autograd.grad(
                outputs= discriminations.sum(),
                inputs= reals,
                create_graph= True,
                only_inputs= True
                )

        scale = self.gamma * 0.5
        per_Sample = input_Gradients.square().sum(dim= (1, 2)) * scale  # [Batch]
        # `reals[:, 0, 0] * 0.0` adds nothing numerically; presumably kept so the result
        # stays connected to `reals` in the autograd graph.
        return (per_Sample + reals[:, 0, 0] * 0.0).mean()
if __name__ == "__main__":
    # Smoke test: load one training batch, run the generator and the discriminators on it,
    # and print the discriminator outputs and their shapes.
    import yaml
    from Arg_Parser import Recursive_Parse
    from Datasets import Dataset, Collater

    # NOTE: yaml.Loader can construct arbitrary Python objects; only use it on trusted files.
    with open('Hyper_Parameters.yaml', encoding='utf-8') as f:
        hp = Recursive_Parse(yaml.load(f, Loader=yaml.Loader))

    with open(hp.Token_Path, encoding='utf-8') as f:
        token_Dict = yaml.load(f, Loader=yaml.Loader)

    dataset = Dataset(
        pattern_path= hp.Train.Train_Pattern.Path,
        Metadata_file= hp.Train.Train_Pattern.Metadata_File,
        token_dict= token_Dict,
        accumulated_dataset_epoch= hp.Train.Train_Pattern.Accumulated_Dataset_Epoch,
        )
    collater = Collater(
        token_dict= token_Dict,
        max_abs_mel= hp.Sound.Max_Abs_Mel
        )
    dataLoader = torch.utils.data.DataLoader(
        dataset= dataset,
        collate_fn= collater,
        sampler= torch.utils.data.RandomSampler(dataset),
        batch_size= hp.Train.Batch_Size,
        num_workers= hp.Train.Num_Workers,
        pin_memory= True
        )

    durations, tokens, notes, token_lengths, mels, silences, pitches, mel_lengths = next(iter(dataLoader))

    model = HifiSinger(hp)
    predicted_Mels, predicted_Silences, predicted_Pitches, predicted_Durations = model(
        tokens= tokens,
        durations= durations,
        notes= notes,
        token_lengths= token_lengths
        )

    discriminator = Discriminators(hp)
    discriminations = discriminator(predicted_Mels, mel_lengths)
    print(discriminations)
    for x in discriminations:
        print(x.shape)