# defaults.py
import os
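# Global training, dataset and logging settings.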
imsize = (288, 160)
semantic_nc = None
image_channels = 3
max_images_to_train = 12e6
cse_nc = None
project = "fba"
semantic_labels = None
_output_dir = os.environ.get("BASE_OUTPUT_DIR", "outputs")
_cache_dir = ".fba_cache"
_checkpoint_dir = "checkpoints"
logger_backend = "wandb"  # choices: ["wandb", "tensorboard", "none"]
# Tag used for matplotlib plots
log_tag = None
# URLs for published model checkpoints and metrics
checkpoint_url = None
metrics_url = None
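# Optimizer settings. Lazy regularization (StyleGAN2-style) evaluates the
# gradient penalty only every `lazy_reg_interval` discriminator steps.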
optimizer = dict(
    lazy_regularization=True,
    lazy_reg_interval=16,
    D_opts=dict(type="Adam", lr=0.001, betas=(0.0, 0.99)),
    G_opts=dict(type="Adam", lr=0.001, betas=(0.0, 0.99)),
)
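# A minimal sketch (not part of this file) of how the optimizer dicts above
# could be mapped onto torch optimizers; `G` and `D` stand for the instantiated
# generator and discriminator modules, which this config does not construct:
#
#   import torch
#   G_optim = torch.optim.Adam(
#       G.parameters(), lr=optimizer["G_opts"]["lr"], betas=optimizer["G_opts"]["betas"]
#   )
#   D_optim = torch.optim.Adam(
#       D.parameters(), lr=optimizer["D_opts"]["lr"], betas=optimizer["D_opts"]["betas"]
#   )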
# Exponential moving average (EMA) of generator weights; times are measured in images seen (nimg).
EMA = dict(nimg_half_time=10e3, rampup_nimg=0)
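# Training hooks; intervals (num_ims_per_log, ims_per_checkpoint, ims_per_save)
# appear to be measured in number of images seen.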
hooks = {
    "time": dict(type="TimeLoggerHook", num_ims_per_log=10e3),
    "metric": dict(type="MetricHook", ims_per_log=2e5, n_diversity_samples=1),
    "checkpoint": dict(type="CheckpointHook", ims_per_checkpoint=2e5, test_checkpoints=[]),
    "image_saver": dict(
        type="ImageSaveHook", ims_per_save=2e5, n_diverse_samples=4,
        n_diverse_images=8, nims2log=16, save_train_G=False,
    ),
}
ims_per_log = 2048
random_seed = 0
jit_transform = False
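# Dataloader settings for training and validation; the `loader` keys mirror
# torch.utils.data.DataLoader arguments.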
data_train = dict(
    loader=dict(num_workers=8, drop_last=True, pin_memory=True, batch_size=32, prefetch_factor=2),
    sampler=dict(drop_last=True, shuffle=True),
)
data_val = dict(
    loader=dict(num_workers=8, pin_memory=True, batch_size=32, prefetch_factor=2),
    sampler=dict(drop_last=True, shuffle=False),
)
loss = dict(
    type="LossHandler",
    gan_criterion=dict(type="nsgan", weight=1),
    gradient_penalty=dict(type="r1_regularization", weight=5, mask_out=True),
    epsilon_penalty=dict(type="epsilon_penalty", weight=0.001),
)
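# Generator architecture (U-Net based). `cnum` appears to be the base channel
# count and `max_cnum_mul` the cap on the channel multiplier.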
generator = dict(
    type="UnetGenerator",
    scale_grad=True,
    min_fmap_resolution=32,
    cnum=32,
    max_cnum_mul=16,
    n_middle_blocks=2,
    z_channels=512,
    mask_output=True,
    input_semantic=False,
    style_cfg=dict(type="NoneStyle"),
    embed_z=True,
    class_specific_z=False,
    conv_clamp=256,
    input_cse=False,
    latent_space=None,
    use_cse=True,
    modulate_encoder=False,
    norm_type="instance_norm_std",
    norm_unet=False,
    unet_skip="residual",
)
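# Discriminator architecture (feature pyramid network based).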
discriminator = dict(
    type="FPNDiscriminator",
    min_fmap_resolution=8,
    max_cnum_mul=8,
    cnum=32,
    input_condition=True,
    semantic_input_mode=None,
    conv_clamp=256,
    input_cse=False,
    output_fpn=False,
)