-
Notifications
You must be signed in to change notification settings - Fork 13
/
create_samples.py
136 lines (101 loc) · 3.7 KB
/
create_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import os.path
import random
from PIL import Image
from my_image_folder import is_image_file
def oversample_num(wind):
    """Return how many copies of a sample to write for the given wind level.

    Samples with ``wind`` below 60 are written exactly once.  Stronger
    samples get one copy plus a random number of extra copies: at most 1
    extra below 80, at most 2 extra below 100, and at most 10 extra otherwise.
    """
    if wind < 60:
        # Weak samples are never duplicated (and consume no random numbers).
        return 1
    if wind < 80:
        max_extra = 1
    elif wind < 100:
        max_extra = 2
    else:
        max_extra = 10
    return 1 + random.randint(0, max_extra)
def save_file(f, fname, f_root):
    """Save one or more numbered copies of image ``f`` into ``f_root``.

    The wind level is parsed from the third ``_``-separated field of
    ``fname``.  When the module-level ``oversample`` flag is true the number
    of copies comes from ``oversample_num(wind)``; otherwise a single copy is
    written.  Every copy is named ``<base>_<index><ext>`` and the
    module-level ``count`` is increased by the number of copies.

    Args:
        f: object exposing ``save(path)`` (a PIL image in this script).
        fname: name of the source file; its third ``_`` field must be an
            integer.
        f_root: directory the copies are written to.

    Raises:
        IndexError: if ``fname`` has fewer than three ``_``-separated fields.
        ValueError: if the third field is not an integer.
    """
    global count
    wind = int(fname.split('_')[2])
    cps = oversample_num(wind) if oversample else 1
    count = count + cps
    # splitext keeps the real extension even when the base name itself
    # contains dots; splitting on every '.' would drop it.
    base, ext = os.path.splitext(fname)
    for i in range(cps):
        # Append the copy number to the base name, before the extension.
        f.save(os.path.join(f_root, '%s_%d%s' % (base, i, ext)))
def if_match(f1, f2):
    """Tell whether ``f1`` and ``f2`` look like consecutive frames of one typhoon.

    Both names are split on ``_``.  The first fields must be equal, and the
    last characters of the second (date) fields must form one of the steps
    0->6, 6->2, 2->8 or 8->0, i.e. the hour suffixes 00, 06, 12, 18 in
    cyclic order.  Only that last character is compared; the remainder of
    the date field is not checked.
    """
    parts_1 = f1.split('_')
    parts_2 = f2.split('_')
    if parts_1[0] != parts_2[0]:
        return False
    # Last digit of each timestamp: '0', '6', '2' or '8' for 00/06/12/18.
    step = (parts_1[1][-1], parts_2[1][-1])
    return step in (('0', '6'), ('6', '2'), ('2', '8'), ('8', '0'))
def cut_pics(p, box=(128, 128, 384, 384)):
    """Crop image ``p`` to ``box`` and return the cropped image.

    Args:
        p: object exposing ``crop(box)`` (a PIL image in this script).
        box: crop rectangle passed straight to ``p.crop``.  The default,
            ``(128, 128, 384, 384)``, is the region this script has always
            kept (described by the original author as the central area).

    Returns:
        Whatever ``p.crop(box)`` returns.
    """
    return p.crop(box)
def merge_pics(p1, p2):
    """Combine two images into a single RGB image.

    Both inputs are converted to RGB first.  The result takes its red
    channel from ``p1`` (the frame 6 hours earlier) and its green and blue
    channels from ``p2`` (the current frame).  The blue channel is carried
    along only to complete the RGB image; the original author marks it as
    not meaningful.
    """
    earlier_rgb = p1.convert('RGB')
    current_rgb = p2.convert('RGB')
    earlier_bands = earlier_rgb.split()
    current_bands = current_rgb.split()
    channels = (earlier_bands[0], current_bands[1], current_bands[2])
    return Image.merge('RGB', channels)
def create_sample(source_dir, fname_1, fname_2, target_dir):
    """Combine two raw images into one sample and save it to ``target_dir``.

    Args:
        source_dir: directory containing both raw images.
        fname_1: file name of the earlier image.
        fname_2: file name of the current image; the saved sample is named
            after this file.
        target_dir: directory the sample is written to.

    Returns:
        ``None`` when the sample was written.  Otherwise a tuple whose first
        element is a message and whose remaining elements are the two
        offending names: the pair is skipped when either path is not an
        image file or when ``if_match`` rejects the two names.
    """
    # Build the paths from the arguments rather than from whatever loop
    # variables the caller happens to have at module scope.
    complete_fname_1 = os.path.join(source_dir, fname_1)
    complete_fname_2 = os.path.join(source_dir, fname_2)
    if not (is_image_file(complete_fname_1) and is_image_file(complete_fname_2)):
        return 'Not image file: ', complete_fname_1, complete_fname_2
    if not if_match(fname_1, fname_2):
        return 'Two images are not matched: ', fname_1, fname_2
    img_1 = cut_pics(Image.open(complete_fname_1))
    img_2 = cut_pics(Image.open(complete_fname_2))
    im = merge_pics(img_1, img_2)
    save_file(im, fname_2, target_dir)
    return None
if __name__ == '__main__':
    # Build train/test sample sets from the raw images under ./tys_raw/.
    # Within every directory the sorted file list is split 80% / 20%; each
    # sample is made from a file and the file just before it in that order.
    path_ = os.path.abspath('.')
    raw_dir = path_ + '/tys_raw/'
    train_root = path_ + '/train_set/'
    test_root = path_ + '/test_set/'
    for out_dir in (train_root, test_root):
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)

    # Module-level state read and updated by save_file().
    count = 0
    oversample = True

    for root, _, fnames in sorted(os.walk(raw_dir)):
        fnames = sorted(fnames)
        # 80% samples as train set and 20% samples as test set
        boundary = int(len(fnames) * 0.8)
        phases = (
            ('train', train_root, range(1, boundary)),
            # Never start at index 0: fnames[i - 1] would wrap around to the
            # last file of the directory.
            ('test', test_root, range(max(boundary, 1), len(fnames))),
        )
        for set_name, target_root, indices in phases:
            # Count each set on its own so that one set's total never leaks
            # into the next set or the next directory.
            count = 0
            for i in indices:
                info = create_sample(root, fnames[i - 1], fnames[i], target_root)
                if info:
                    print(info)
                if count > 30000:
                    print('Exceed the upper limit of a single file.')
                    break
                if i % 100 == 99:
                    # Single pre-formatted string: prints the same text under
                    # Python 2 and Python 3.
                    print('have processed  %d  files.' % (i + 1))
            print('items in %s set:  %d' % (set_name, count))