-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathoptions.py
43 lines (40 loc) · 2.45 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import os
from Dataset.param_aug import ParamDiffAug
def args_parser():
parser = argparse.ArgumentParser()
path_dir = os.path.dirname(__file__)
parser.add_argument('--path_cifar10', type=str, default=os.path.join(path_dir, 'data/CIFAR10/'))
parser.add_argument('--path_cifar100', type=str, default=os.path.join(path_dir, 'data/CIFAR100/'))
parser.add_argument('--num_classes', type=int, default=10)
parser.add_argument('--num_clients', type=int, default=20)
parser.add_argument('--num_online_clients', type=int, default=8)
parser.add_argument('--num_rounds', type=int, default=200)
parser.add_argument('--num_epochs_local_training', type=int, default=10) #
parser.add_argument('--batch_size_local_training', type=int, default=32)
parser.add_argument('--match_epoch', type=int, default=100)
parser.add_argument('--crt_epoch', type=int, default=300)
parser.add_argument('--batch_real', type=int, default=32)
parser.add_argument('--num_of_feature', type=int, default=100)
parser.add_argument('--lr_feature', type=float, default=0.1, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_size_test', type=int, default=500)
parser.add_argument('--lr_local_training', type=float, default=0.1)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--non_iid_alpha', type=float, default=0.5)
parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--imb_type', default="exp", type=str, help='imbalance type')
parser.add_argument('--imb_factor', default=0.02, type=float, help='imbalance factor')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
parser.add_argument('--save_path', type=str, default=os.path.join(path_dir, 'result/'))
parser.add_argument('--method', type=str, default='DSA', help='DC/DSA')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
help='differentiable Siamese augmentation strategy')
# FedProx
parser.add_argument('--mu', type=float, default=0.01)
# FedAvgM
parser.add_argument('--init_belta', type=float, default=0.97)
args = parser.parse_args()
args.dsa_param = ParamDiffAug()
args.dsa = True if args.method == 'DSA' else False
return args