-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathparameters.py
150 lines (122 loc) · 6.09 KB
/
parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Parameters used in the feature extraction, neural network model, and training the SELDnet can be changed here.
#
# Ideally, do not change the values of the default parameters. Create separate cases with unique <task-id> as seen in
# the code below (if-else loop) and use them. This way you can easily reproduce a configuration on a later time.
def get_params(argv='1'):
    """Return the SELDnet configuration dictionary for the given <task-id>.

    Parameters
    ----------
    argv : str
        Task id selecting a parameter preset: '1' (defaults), '2'-'7'
        (dataset/feature/output-format combinations), '999' (quick test).
        Any other value prints an error and terminates the process.

    Returns
    -------
    dict
        All feature-extraction, model, and training parameters, including
        the derived entries computed at the bottom of this function
        (feature_sequence_length, t_pool_size, patience, modality-suffixed
        output directories, and unique_classes).
    """
    print("SET: {}".format(argv))
    # ########### default parameters ##############
    params = dict(
        quick_test=True,  # To do quick test. Trains/test on small subset of dataset, and # of epochs
        finetune_mode=True,  # Finetune on existing model, requires the pretrained model path set - pretrained_model_weights
        pretrained_model_weights='3_1_dev_split0_multiaccdoa_foa_model.h5',

        # INPUT PATH
        # dataset_dir='DCASE2020_SELD_dataset/',  # Base folder containing the foa/mic and metadata folders
        dataset_dir='../DCASE2024_SELD_dataset/',

        # OUTPUT PATHS
        # feat_label_dir='DCASE2020_SELD_dataset/feat_label_hnet/',  # Directory to dump extracted features and labels
        feat_label_dir='../DCASE2024_SELD_dataset/seld_feat_label/',

        model_dir='models',  # Dumps the trained models and training curves in this folder
        dcase_output_dir='results',  # recording-wise results are dumped in this path.

        # DATASET LOADING PARAMETERS
        mode='dev',  # 'dev' - development or 'eval' - evaluation dataset
        dataset='foa',  # 'foa' - ambisonic or 'mic' - microphone signals

        # FEATURE PARAMS
        fs=24000,
        hop_len_s=0.02,
        label_hop_len_s=0.1,
        max_audio_len_s=60,
        nb_mel_bins=64,

        use_salsalite=False,  # Used for MIC dataset only. If true use salsalite features, else use GCC features
        fmin_doa_salsalite=50,
        fmax_doa_salsalite=2000,
        fmax_spectra_salsalite=9000,

        # MODEL TYPE
        modality='audio',  # 'audio' or 'audio_visual'
        multi_accdoa=False,  # False - Single-ACCDOA or True - Multi-ACCDOA
        thresh_unify=15,  # Required for Multi-ACCDOA only. Threshold of unification for inference in degrees.

        # DNN MODEL PARAMETERS
        label_sequence_length=50,  # Feature sequence length
        batch_size=128,  # Batch size
        dropout_rate=0.05,  # Dropout rate, constant for all layers
        nb_cnn2d_filt=64,  # Number of CNN nodes, constant for each layer
        f_pool_size=[4, 4, 2],  # CNN frequency pooling, length of list = number of CNN layers, list value = pooling per layer

        nb_heads=8,
        nb_self_attn_layers=2,
        nb_transformer_layers=2,

        nb_rnn_layers=2,
        rnn_size=128,

        nb_fnn_layers=1,
        fnn_size=128,  # FNN contents, length of list = number of layers, list value = number of nodes

        nb_epochs=250,  # Train for maximum epochs
        lr=1e-3,

        # METRIC
        average='macro',  # Supports 'micro': sample-wise average and 'macro': class-wise average,
        segment_based_metrics=False,  # If True, uses segment-based metrics, else uses frame-based metrics
        evaluate_distance=True,  # If True, computes distance errors and apply distance threshold to the detections
        lad_doa_thresh=20,  # DOA error threshold for computing the detection metrics
        lad_dist_thresh=float('inf'),  # Absolute distance error threshold for computing the detection metrics
        lad_reldist_thresh=1.0,  # Relative distance error threshold for computing the detection metrics
    )

    # ########### User defined parameters ##############
    if argv == '1':
        print("USING DEFAULT PARAMETERS\n")

    elif argv == '2':
        print("FOA + ACCDOA\n")
        params['quick_test'] = False
        params['dataset'] = 'foa'
        params['multi_accdoa'] = False

    elif argv == '3':
        print("FOA + multi ACCDOA\n")
        params['quick_test'] = False
        params['dataset'] = 'foa'
        params['multi_accdoa'] = True

    elif argv == '4':
        print("MIC + GCC + ACCDOA\n")
        params['quick_test'] = False
        params['dataset'] = 'mic'
        params['use_salsalite'] = False
        params['multi_accdoa'] = False

    elif argv == '5':
        print("MIC + SALSA + ACCDOA\n")
        params['quick_test'] = False
        params['dataset'] = 'mic'
        params['use_salsalite'] = True
        params['multi_accdoa'] = False

    elif argv == '6':
        print("MIC + GCC + multi ACCDOA\n")
        params['pretrained_model_weights'] = '6_1_dev_split0_multiaccdoa_mic_gcc_model.h5'
        params['quick_test'] = False
        params['dataset'] = 'mic'
        params['use_salsalite'] = False
        params['multi_accdoa'] = True

    elif argv == '7':
        print("MIC + SALSA + multi ACCDOA\n")
        params['quick_test'] = False
        params['dataset'] = 'mic'
        params['use_salsalite'] = True
        params['multi_accdoa'] = True

    elif argv == '999':
        print("QUICK TEST MODE\n")
        params['quick_test'] = True

    else:
        print('ERROR: unknown argument {}'.format(argv))
        exit()

    # Number of feature frames per label frame (0.1 s / 0.02 s = 5). Using
    # round() on a true division instead of float floor division ('//'),
    # which only happens to give the intended result here and is fragile
    # under IEEE-754 rounding for other hop-length combinations.
    feature_label_resolution = int(round(params['label_hop_len_s'] / params['hop_len_s']))
    params['feature_sequence_length'] = params['label_sequence_length'] * feature_label_resolution
    params['t_pool_size'] = [feature_label_resolution, 1, 1]  # CNN time pooling
    params['patience'] = int(params['nb_epochs'])  # Stop training if patience is reached

    # Keep per-modality outputs separate so audio and audio_visual runs don't overwrite each other.
    params['model_dir'] = params['model_dir'] + '_' + params['modality']
    params['dcase_output_dir'] = params['dcase_output_dir'] + '_' + params['modality']

    # Number of sound-event classes per DCASE challenge year; the first year
    # found in dataset_dir wins (same precedence as the original if/elif chain).
    classes_per_year = {'2020': 14, '2021': 12, '2022': 13, '2023': 13, '2024': 13}
    for year, nb_classes in classes_per_year.items():
        if year in params['dataset_dir']:
            params['unique_classes'] = nb_classes
            break

    for key, value in params.items():
        print("\t{}: {}".format(key, value))
    return params