-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcifar10_config.py
64 lines (53 loc) · 1.29 KB
/
cifar10_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import ml_collections
def d(**kwargs):
    """Wrap keyword arguments in an ``ml_collections.ConfigDict``.

    Shorthand used below so each sub-config reads like a call with named
    fields instead of a dict literal.

    Args:
      **kwargs: Field names and values for the new config.

    Returns:
      A ``ConfigDict`` whose top-level fields are exactly ``kwargs``.
    """
    config_fields = kwargs
    return ml_collections.ConfigDict(initial_dictionary=config_fields)
def get_config():
    """Build the CIFAR-10 / U-ViT experiment configuration.

    Returns:
      An ``ml_collections.ConfigDict`` holding the scalars ``seed`` and
      ``pred``, followed by the sub-configs ``train``, ``optimizer``,
      ``lr_scheduler``, ``nnet``, ``dataset`` and ``sample`` (in that order).
      How each field is interpreted is decided by the training/sampling code
      that consumes this config, which is not part of this file.
    """
    config = ml_collections.ConfigDict()

    # Run-wide scalar settings.
    config.seed = 1234
    config.pred = 'noise_pred'

    # Sub-configs; dict insertion order is the order they are attached.
    sections = {
        'train': d(
            n_steps=500_000,
            batch_size=128,
            mode='uncond',
            log_interval=10,
            eval_interval=5_000,
            save_interval=50_000,
        ),
        'optimizer': d(
            name='adamw',
            lr=2e-4,
            weight_decay=0.03,
            betas=(0.99, 0.999),
        ),
        'lr_scheduler': d(
            name='customized',
            warmup_steps=2500,
        ),
        'nnet': d(
            name='uvit',
            img_size=32,
            patch_size=2,
            embed_dim=512,
            depth=12,
            num_heads=8,
            mlp_ratio=4,
            qkv_bias=False,
            mlp_time_embed=False,
            # NOTE(review): -1 presumably disables class conditioning, which
            # would match train.mode='uncond' -- confirm in the network code.
            num_classes=-1,
            # NOTE(review): presumably long-skip-connection scaling options;
            # confirm what each `scalelong` value selects in the 'uvit' code.
            scalelong=0,  # choose a scaling method
            kappa=0.5,  # scaling coefficient
        ),
        'dataset': d(
            name='cifar10',
            path='assets/datasets/cifar10',
            random_flip=True,
        ),
        'sample': d(
            sample_steps=1000,
            n_samples=50_000,
            mini_batch_size=500,
            algorithm='euler_maruyama_sde',
            path='',
        ),
    }
    for section_name, section in sections.items():
        setattr(config, section_name, section)

    return config