-
Notifications
You must be signed in to change notification settings - Fork 333
/
eval_ofa_net.py
72 lines (62 loc) · 2.33 KB
/
eval_ofa_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import torch
import argparse
from ofa.imagenet_classification.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_classification.run_manager import ImagenetRunConfig, RunManager
from ofa.model_zoo import ofa_net
# Names accepted by --net.
_SUPPORTED_NETS = [
    "ofa_mbv3_d234_e346_k357_w1.0",
    "ofa_mbv3_d234_e346_k357_w1.2",
    "ofa_proxyless_d234_e346_k357_w1.3",
    "ofa_resnet50",
]

# Command-line options as (flags, keyword arguments) pairs, registered in this
# order so the generated --help output lists them in the same sequence.
_CLI_OPTIONS = (
    (
        ("-p", "--path"),
        {"help": "The path of imagenet", "type": str, "default": "/dataset/imagenet"},
    ),
    (
        ("-g", "--gpu"),
        {"help": "The gpu(s) to use", "type": str, "default": "all"},
    ),
    (
        ("-b", "--batch-size"),
        {
            "help": "The batch on every device for validation",
            "type": int,
            "default": 100,
        },
    ),
    (
        ("-j", "--workers"),
        {"help": "Number of workers", "type": int, "default": 20},
    ),
    (
        ("-n", "--net"),
        {
            "metavar": "OFANET",
            "default": "ofa_resnet50",
            "choices": _SUPPORTED_NETS,
            "help": "OFA networks",
        },
    ),
)

parser = argparse.ArgumentParser()
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)

# Parsed from the process command line; exposes path, gpu, batch_size,
# workers and net as attributes.
args = parser.parse_args()
# Resolve which GPUs to expose to CUDA and scale the batch size accordingly.
if args.gpu == "all":
    # Use every device torch reports; the range may be empty when no CUDA
    # device is available.
    device_list = range(torch.cuda.device_count())
else:
    # Accept a comma-separated list of device ids. Blank entries (from a
    # trailing comma or an empty value) are skipped instead of crashing on
    # int(""), and int() itself tolerates surrounding spaces ("0, 1").
    device_list = [int(token) for token in args.gpu.split(",") if token.strip()]

# Rebuild the string from the parsed ids so that the value exported to CUDA is
# a clean "0,1,..." list that always matches len(device_list); the raw user
# input could contain spaces or stray commas.
args.gpu = ",".join(str(device_id) for device_id in device_list)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# --batch-size is given per device, so multiply by the number of selected
# devices; max(..., 1) keeps the batch size unchanged when none are selected.
args.batch_size = args.batch_size * max(len(device_list), 1)
# Point the ImageNet data provider at the dataset location given by --path.
ImagenetDataProvider.DEFAULT_PATH = args.path

# Load the pretrained OFA network chosen with --net and build the run config
# from the (already device-scaled) batch size and the worker count.
ofa_network = ofa_net(args.net, pretrained=True)
run_config = ImagenetRunConfig(
    test_batch_size=args.batch_size,
    n_worker=args.workers,
)

# Randomly sample a sub-network;
# you can also manually set the sub-network using:
#     ofa_network.set_active_subnet(ks=7, e=6, d=4)
ofa_network.sample_active_subnet()
subnet = ofa_network.get_active_subnet(preserve_weight=True)

# Test sampled subnet
run_manager = RunManager(".tmp/eval_subnet", subnet, run_config, init=False)

# assign image size: 128, 132, ..., 224
run_config.data_provider.assign_active_img_size(224)

# NOTE(review): presumably re-estimates the subnet's running (normalization)
# statistics before validation -- confirm against RunManager.
run_manager.reset_running_statistics(net=subnet)

print("Test random subnet:")
print(subnet.module_str)

# validate() yields the loss together with a (top1, top5) pair.
loss, accuracies = run_manager.validate(net=subnet)
top1, top5 = accuracies
print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f" % (loss, top1, top5))