-
Notifications
You must be signed in to change notification settings - Fork 6
/
opts.py
76 lines (56 loc) · 3.02 KB
/
opts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import print_function
import configargparse
def model_opts(parser):
    """
    Register the autoencoder (model) options on ``parser``.

    These options are passed to the construction of the model.
    Be careful with these as they will be used during unmixing.

    Args:
        parser: An argparse-compatible parser (for example a
            ``configargparse`` parser). It is modified in place: a
            'Model AE' argument group holding ``--encoder_type``,
            ``--soft_threshold`` and ``--activation`` is added to it.

    Returns:
        None.
    """
    group = parser.add_argument_group('Model AE')

    # ``add_argument`` is used instead of configargparse's ``add`` shorthand
    # so that a plain ``argparse`` parser is accepted as well.
    group.add_argument(
        '--encoder_type', '-encoder_type', type=str, default='shallow',
        choices=['deep', 'shallow'],
        help="Allows the user to choose between two levels of encoder "
             "complexity. Options are: [deep|shallow]")

    # SLReLU unavailable, add assert in main
    group.add_argument(
        '--soft_threshold', '-soft_threshold', type=str, default='SReLU',
        choices=['SReLU', 'SLReLU'],
        help="Type of soft-thresholding for final layer of encoder. "
             "Options are: [SReLU|SLReLU]")

    # No default is given, so the parsed value is None unless supplied.
    group.add_argument(
        '--activation', '-activation', type=str,
        choices=['ReLU', 'Leaky-ReLU', 'Sigmoid'],
        help="Activation function for hidden layers of encoder. "
             "For shallow AE there won't be any activation. Options are: "
             "[ReLU|Leaky-ReLU|Sigmoid]")
def train_opts(parser):
    """
    Register the training options on ``parser``.

    These options are passed to the training of the model.
    Be careful with these as they will be used during unmixing.

    Args:
        parser: An argparse-compatible parser (for example a
            ``configargparse`` parser). It is modified in place: a
            'General' group (paths, checkpointing, data dimensions) and a
            'Hyperparameters' group are added to it. ``--src_dir`` is the
            only required option.

    Returns:
        None.
    """
    # ``add_argument`` is used instead of configargparse's ``add`` shorthand
    # so that a plain ``argparse`` parser is accepted as well.
    group = parser.add_argument_group('General')
    group.add_argument(
        '--src_dir', '-src_dir', type=str, required=True,
        help="System path to the Samson directory.")
    group.add_argument(
        '--save_checkpt', '-save_checkpt', type=int, default=0,
        help="Number of epochs after which a check point of "
             "model parameters should be saved.")
    group.add_argument(
        '--save_dir', '-save_dir', type=str, default="../logs",
        help="System path to save model weights.")
    group.add_argument(
        '--train_from', '-train_from', type=str, default=None,
        help="Path to checkpoint file to continue training from.")
    group.add_argument(
        '--num_bands', '-num_bands', type=int, default=156,
        help="Number of spectral bands present in input image.")
    group.add_argument(
        '--end_members', '-end_members', type=int, default=3,
        help="Number of end-members to be extracted from HSI.")

    group = parser.add_argument_group('Hyperparameters')
    group.add_argument(
        '--batch_size', '-batch_size', type=int, default=20,
        help="Maximum batch size for training.")
    group.add_argument(
        '--learning_rate', '-learning_rate', type=float, default=1e-3,
        help="Learning rate for training the network.")
    group.add_argument(
        '--epochs', '-epochs', type=int, default=100,
        help="Number of iterations that the network should be trained for.")
    group.add_argument(
        '--gaussian_dropout', '-gaussian_dropout', type=float, default=1.0,
        help="Mean of multiplicative Gaussian noise used for regularization.")
    group.add_argument(
        '--threshold', '-threshold', type=float, default=5.0,
        help="Defines the threshold for the soft-thresholding operation.")
    group.add_argument(
        '--objective', '-objective', type=str, default='MSE',
        choices=['MSE', 'SAD', 'SID'],
        help="Objective function used to train the Autoencoder. "
             "Options are: [MSE|SAD|SID]")