-
Notifications
You must be signed in to change notification settings - Fork 2
/
configs.py
68 lines (59 loc) · 1.61 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from d3pm_absorbing import D3PMAbsorbing
def d3pm_text8():
    """Baseline D3PM absorbing-state config for the text8 dataset.

    Returns:
        A ``(model_class, model_args, training_args)`` triple: the model
        class to instantiate, keyword args for its constructor, and keyword
        args for the training loop.
    """
    model_args = {
        "vocab_size": 27,          # presumably a-z plus space for text8 — verify
        "n_embed": 768,
        "n_heads": 768 // 64,      # head dim fixed at 64
        "n_blocks": 12,
        "n_cond": 128,
        "dropout": 0.025,
        "T": 1000,                 # number of diffusion timesteps
        "lambda_ce": 0.05,         # cross-entropy auxiliary loss weight
    }
    training_args = {
        "batch_size": 256,
        "learning_rate": 1e-3,
        "min_lr": 1e-5,
        "gradient_accumulation_steps": 4,
        "warmup_iters": 2_500,
        "max_iters": 500_000,
        "eval_iters": 1000,
        "weight_decay": 0.1,
        "training_seed": 1,
    }
    return D3PMAbsorbing, model_args, training_args
def d3pm_text8_4gpu():
    """text8 config adapted for 4-GPU runs.

    Starts from :func:`d3pm_text8` and lowers gradient accumulation and
    eval iterations, since more GPUs supply the effective batch size.
    """
    model_cls, model_args, training_args = d3pm_text8()
    training_args.update(gradient_accumulation_steps=1, eval_iters=250)
    return model_cls, model_args, training_args
def d3pm_openwebtext_8gpu():
    """D3PM absorbing-state config for OpenWebText on 8 GPUs.

    Returns:
        A ``(model_class, model_args, training_args)`` triple: the model
        class to instantiate, keyword args for its constructor, and keyword
        args for the training loop.
    """
    model_args = {
        "vocab_size": 50257,       # GPT-2 BPE vocabulary size — verify tokenizer
        "n_embed": 768,
        "n_heads": 768 // 64,      # head dim fixed at 64
        "n_blocks": 12,
        "n_cond": 128,
        "dropout": 0.0,
        "T": 1000,                 # number of diffusion timesteps
        "lambda_ce": 0.05,         # cross-entropy auxiliary loss weight
    }
    training_args = {
        "dataset": "openwebtext",
        "batch_size": 16,
        "seq_len": 1024,
        "learning_rate": 6e-4,
        "min_lr": 1e-5,
        "gradient_accumulation_steps": 8,
        "warmup_iters": 2_500,
        "max_iters": 500_000,
        "eval_iters": 400,
        "weight_decay": 0.1,
        "training_seed": 9,
    }
    return D3PMAbsorbing, model_args, training_args
def d3pm_openwebtext_32gpu():
    """OpenWebText config adapted for 32-GPU runs.

    Starts from :func:`d3pm_openwebtext_8gpu` and lowers gradient
    accumulation and eval iterations, since more GPUs supply the
    effective batch size.
    """
    model_cls, model_args, training_args = d3pm_openwebtext_8gpu()
    training_args.update(gradient_accumulation_steps=2, eval_iters=50)
    return model_cls, model_args, training_args