main.py
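"""Command-line entry point for training a Conv1D-LSTM audio classifier.

The train sub-command loads a training and a validation set, converts the
waveforms to spectrograms, and fits the model with early stopping.
"""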
import argparse
import sys

from cnn1dlstm.prepare_dataset import LoadAndPreprocessToModel
from cnn1dlstm.dataloader import SpectrogramDataset
from cnn1dlstm.spectrogram import Spectrogram
from cnn1dlstm.training import TrainLoaderDataset
def train(args):
    dataset_train = args.dataset_train
    dataset_val = args.dataset_val
    max_wave_size = args.max_wave_size
    noise_value = args.noise_value
    patience = args.patience
    n_epochs = args.n_epochs
    learning_rate = args.learning_rate
    weight_decay_ = args.weight_decay_

    # Load and shuffle the training set, then encode the string labels.
    load_preprocess = LoadAndPreprocessToModel(dataset_train)
    filenames, labels_ = load_preprocess.shuffle_dataset()
    labels, encoder = load_preprocess.encoder_labels(labels_)

    # Same preprocessing for the validation set.
    load_preprocess_val = LoadAndPreprocessToModel(dataset_val)
    filenames_val, labels_val_ = load_preprocess_val.shuffle_dataset()
    labels_val, _ = load_preprocess_val.encoder_labels(labels_val_)

    # Infer the spectrogram input shape from the first training file
    # and build the Conv1D-LSTM model for two classes.
    spec = Spectrogram(max_wave_size)
    input_shape, preprocessing_fn = spec.waveform_spectrogram(filenames[0])
    print('Input shape:', input_shape)
    model = spec.model_conv1dlstm(input_shape, 2)

    # Wrap the (filename, encoded label) pairs into spectrogram datasets,
    # with noise augmentation controlled by noise_value.
    spectro = SpectrogramDataset(filenames, labels, spec, noise_value)
    spectro_val = SpectrogramDataset(filenames_val, labels_val, spec, noise_value)
    print('Training samples:', len(spectro))

    # Train with early stopping on the validation set.
    trainer = TrainLoaderDataset(spectro, spectro_val)
    model, train_loss = trainer.train_model_early_stopping(
        model, patience, n_epochs, learning_rate, weight_decay_
    )
    return model, train_loss
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train a Conv1D-LSTM classifier on audio spectrograms')
    subparsers = parser.add_subparsers(dest='mode')

    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--dataset_train', type=str, required=True, help='Directory of the training dataset')
    parser_train.add_argument('--dataset_val', type=str, required=True, help='Directory of the validation dataset')
    parser_train.add_argument('--max_wave_size', type=int, required=True, help='Maximum waveform length, in samples')
    parser_train.add_argument('--noise_value', type=float, required=True, help='Amount of noise augmentation applied to the spectrograms')
    parser_train.add_argument('--patience', type=int, help='Early-stopping patience, in epochs')
    parser_train.add_argument('--n_epochs', type=int, help='Maximum number of training epochs')
    parser_train.add_argument('--learning_rate', type=float, help='Optimizer learning rate')
    parser_train.add_argument('--weight_decay_', type=float, help='Optimizer weight decay')

    args = parser.parse_args()
    if args.mode == 'train':
        train(args)
    else:
        parser.print_help()
        sys.exit(1)
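# Example invocation (illustrative only: the paths and hyperparameter values
# below are assumptions, not taken from the repository):
#
#   python main.py train \
#       --dataset_train data/train --dataset_val data/val \
#       --max_wave_size 16000 --noise_value 0.005 \
#       --patience 5 --n_epochs 50 --learning_rate 1e-3 --weight_decay_ 1e-4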