---
GPU: '0'  # if you have 2 GPUs, use '0' or '1'
#########################################################################################
# Generator Hyper-parameters
#########################################################################################
EMB_DIM: 32  # embedding dimension
HIDDEN_DIM: 512  # hidden state dimension of lstm cell
SEQ_LENGTH: 100  # sequence length
START_TOKEN: 0
PRE_GEN_EPOCH: 100  # supervise (maximum likelihood estimation) epochs for generator (default: 120)
PRE_DIS_EPOCH: 100  # supervise (maximum likelihood estimation) epochs for discriminator (default: 50)
SEED: 88
BATCH_SIZE: 32
generator_lr: 0.001
ROLLOUT_UPDATE_RATE: 0.9
reward_gamma: 0.99
# use x10 learning rate for adversarial training: mainly for slow & accurate pretraining & more weighted adv training
x10adv_g: true
#########################################################################################
# Discriminator Hyper-parameters
#########################################################################################
dis_embedding_dim: 32
dis_filter_sizes: [20, 20, 20, 20, 20]
dis_num_filters: [400, 400, 400, 400, 400]
dis_dropout_keep_prob: 0.75
dis_l2_reg_lambda: 0.2
dis_batch_size: 32
rollout_num: 32
discriminator_lr: 0.0001
#########################################################################################
# Basic Training Parameters
#########################################################################################
TOTAL_BATCH: 2000
# vocab size for our custom data
vocab_size: 3216
# positive data, containing real music sequences
positive_file: 'dataset/train'
# negative data from the generator, containing fake sequences
# specify different name if experimenting with multiple instances: causes EOF error & writing to same file from different instances
negative_file: 'dataset/generated'
valid_file: 'dataset/valid'
# # of real data tokens is 140000
# specify so that generated_num * seq_length = 140000 to match balanced real & fake data
generated_num: 1400
epochs_generator: 1
epochs_discriminator: 1
epochs_discriminator_multiplier: 3
pretrain: true
# RL is stochastic: scrap and restart from pretrained checkpoint if things start failing
infinite_loop: true
# our dataset achieves around 0.53 from pretraining
loop_threshold: 0.5