-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_params.py
115 lines (94 loc) · 3.29 KB
/
test_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import json
import os
from pathlib import Path
import random
import subprocess
import cv2
from tqdm import tqdm
# JSON file holding the parameter sets, keyed by the value of --param.
PARAMS_PATH = "frontend/static/test_params.json"
# Root folder under which sampled images and segmentation results are written.
OUTPUT_PATH = "frontend/static/data"
# Default for --data_path.
DATA_PATH = "D:/GithubProjects/tensorflow_datasets"
# Alternative dataset location; pass it via --data_path when needed.
DATA_PATH_LAPTOP = "C:/Users/victo/tensorflow_datasets"

# Command-line interface of the script.
parser = argparse.ArgumentParser(description="Tests multiple parameters")
parser.add_argument(
    "--n", default=25, type=int, help="Number of images to test each parameter on"
)
parser.add_argument(
    "--data_path", default=DATA_PATH, help="Path to the tensorflow_datasets folder"
)
parser.add_argument(
    "--param",
    default="pps_test",
    choices=["pps_test", "cnl_test"],
    help="Path to params to use",
)
def copy_images(files: list[Path], n, dataset, path):
    """Copy up to ``n`` randomly chosen images into ``{path}/{dataset}/test``.

    Image files are visited in random order; each one is decoded with
    ``cv2.imread`` and re-encoded with ``cv2.imwrite`` under its original file
    name until ``n`` images have been written. The names of the images that
    were actually written are saved as a JSON list in
    ``{path}/{dataset}/image_names.json``.

    Args:
        files: Candidate files. Only ``.png``, ``.jpg`` and ``.jpeg``
            (case-insensitive) are considered.
        n: Number of images wanted; converted with ``int``.
        dataset: Name of the dataset sub-folder below ``path``.
        path: Root output folder.

    Raises:
        ValueError: If ``n`` is negative.
        FileExistsError: If ``image_names.json`` already exists, because it is
            opened in exclusive-create mode.
    """
    wanted = int(n)
    if wanted < 0:
        raise ValueError("n must be non-negative")

    # Filter before sampling so that non-image files do not use up sample slots.
    image_suffixes = (".png", ".jpg", ".jpeg")
    candidates = [f for f in files if f.suffix.lower() in image_suffixes]
    # Shuffling a fresh list leaves the caller's list untouched and lets us
    # fall through to the next candidate when one cannot be read.
    random.shuffle(candidates)

    # Create the folders up front so the JSON file below can always be written,
    # even when no image ends up being copied.
    target_dir = f"{path}/{dataset}/test"
    os.makedirs(target_dir, exist_ok=True)

    image_names = []
    for candidate in candidates:
        if len(image_names) >= wanted:
            break
        img = cv2.imread(str(candidate))
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            print(f"Could not read image {candidate}, skipping...")
            continue
        if not cv2.imwrite(f"{target_dir}/{candidate.name}", img):
            print(f"Could not write image {candidate.name}, skipping...")
            continue
        image_names.append(candidate.name)

    if len(image_names) < wanted:
        print(
            f"Only {len(image_names)} of {wanted} requested images "
            f"were copied for {dataset}"
        )

    with open(f"{path}/{dataset}/image_names.json", "x") as json_file:
        json.dump(image_names, json_file, indent=2)
def get_args(params, args, dataset, path):
    """Build the command line that runs ``segment_images`` for one parameter set.

    Args:
        params: Mapping with the keys ``id``, ``points_per_side``,
            ``pred_iou_thresh``, ``stability_score_thresh``, ``crop_n_layers``
            and ``crop_n_layers_downscale_factor``.
        args: Parsed CLI namespace; ``args.param`` and ``args.n`` are read.
        dataset: Dataset name passed as ``--dataset``.
        path: Folder passed as ``--data_path``.

    Returns:
        A list of strings suitable for ``subprocess.run``.
    """
    # Local import keeps this block self-contained.
    import sys

    # Run the module with the interpreter executing this script, so the same
    # environment is used on every platform. Fall back to the previous
    # hard-coded name if the interpreter path is unavailable.
    interpreter = sys.executable or "python.exe"

    command = [
        interpreter,
        "-m",
        "segment_images",
        "--data_path",
        path,
        "--dataset",
        dataset,
        "--output",
        f"{OUTPUT_PATH}/{args.param}/{params['id']}",
        "--size",
        str(args.n),
    ]
    # Each tunable is forwarded as "--<key> <value>", in this fixed order.
    for key in (
        "points_per_side",
        "pred_iou_thresh",
        "stability_score_thresh",
        "crop_n_layers",
        "crop_n_layers_downscale_factor",
    ):
        command += [f"--{key}", str(params[key])]
    # NOTE(review): the meaning of "-a" is defined by segment_images and is
    # not visible here; it is forwarded unchanged.
    command.append("-a")
    return command
def main(args):
    """Sample test images and run ``segment_images`` for every parameter set.

    Steps:
      1. Load the parameter sets stored under ``args.param`` in ``PARAMS_PATH``.
      2. List the ``test`` folder of each dataset below ``args.data_path``.
      3. Copy ``args.n`` random images per dataset into
         ``{OUTPUT_PATH}/{args.param}/images``.
      4. For each parameter set, run the segmentation subprocess once per
         dataset.

    Args:
        args: Parsed CLI namespace providing ``param``, ``data_path`` and ``n``.
    """
    with open(PARAMS_PATH, "r") as json_file:
        params = json.load(json_file)[args.param]

    copy_order = ("oxford_flowers102", "imagenet2012", "oxford_iiit_pet")
    segment_order = ("imagenet2012", "oxford_iiit_pet", "oxford_flowers102")

    # List every dataset before copying anything, so a missing dataset folder
    # fails before any output is produced.
    dataset_files = {
        name: list(Path(args.data_path, name, "test").iterdir())
        for name in copy_order
    }

    images_output_path = f"{OUTPUT_PATH}/{args.param}/images"
    os.makedirs(images_output_path, exist_ok=True)
    for name in copy_order:
        copy_images(dataset_files[name], args.n, name, images_output_path)

    for p in tqdm(params):
        print("Segmenting for:", p["id"])
        for name in segment_order:
            completed = subprocess.run(get_args(p, args, name, images_output_path))
            # Report failures instead of ignoring them, but keep going so one
            # bad run does not abort the remaining parameter sets.
            if completed.returncode != 0:
                print(
                    f"segment_images exited with code {completed.returncode} "
                    f"for dataset {name}, params {p['id']}"
                )
if __name__ == "__main__":
    # Script entry point: parse the command line and pass the namespace to main().
    main(parser.parse_args())