-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathdata_manager.py
68 lines (48 loc) · 2.13 KB
/
data_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import glob
import cv2
import random
import numpy as np
import pickle
import os
from torch.utils import data
class TrainDataset(data.Dataset):
    """Paired cloudy / ground-truth training images.

    ``config`` must provide ``datasets_dir``, ``train_list``, ``test_list``
    and ``train_size``. Each sample is read from
    ``<datasets_dir>/ground_truth/<name>`` and
    ``<datasets_dir>/cloudy_image/<name>``, which share the same filename.

    If the train list file is missing or empty, the filenames found in
    ``ground_truth`` are shuffled and split. The first
    ``int(train_size * N)`` names go to the train list and the rest go to
    the test list. Both lists are written one filename per line.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        train_list_file = os.path.join(config.datasets_dir, config.train_list)
        # Split the dataset into train / test lists if that has not been done yet.
        if not os.path.exists(train_list_file) or os.path.getsize(train_list_file) == 0:
            # os.listdir order is arbitrary. Sorting first makes the split
            # reproducible whenever the global `random` state is seeded.
            files = sorted(os.listdir(os.path.join(config.datasets_dir, 'ground_truth')))
            random.shuffle(files)
            n_train = int(config.train_size * len(files))
            self._write_list(train_list_file, files[:n_train])
            self._write_list(os.path.join(config.datasets_dir, config.test_list), files[n_train:])
        # Read one filename per line. Unlike np.loadtxt, this keeps names that
        # contain spaces or '#' intact. It also always yields a 1-D array:
        # np.loadtxt returns a 0-d array for a single entry, which breaks len().
        self.imlist = np.array(self._read_list(train_list_file), dtype=str)

    @staticmethod
    def _write_list(path, names):
        """Write ``names`` to ``path`` as UTF-8 text, one name per line."""
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(f'{name}\n' for name in names)

    @staticmethod
    def _read_list(path):
        """Return the non-blank lines of ``path``, stripped of surrounding whitespace."""
        with open(path, encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _read_image(path):
        """Load ``path`` with cv2 as a colour (BGR, HxWx3) float32 image.

        Raises:
            FileNotFoundError: if cv2 cannot read the file.
        """
        img = cv2.imread(path, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img.astype(np.float32)

    def __getitem__(self, index):
        """Return ``(x, t, M)`` for the ``index``-th training pair.

        Items:
            x: cloudy image, float32 scaled by 1/255, channel-first (CxHxW).
            t: ground-truth image, same format as ``x``.
            M: HxW float32 mask equal to
               ``clip(sum over channels of (t - x), 0, 1)``. It is computed on
               the 0-255 scale before normalisation.
        """
        name = str(self.imlist[index])
        root = self.config.datasets_dir
        t = self._read_image(os.path.join(root, 'ground_truth', name))
        x = self._read_image(os.path.join(root, 'cloudy_image', name))
        # Must be computed before the /255 scaling below.
        M = np.clip((t - x).sum(axis=2), 0, 1).astype(np.float32)
        x = (x / 255).transpose(2, 0, 1)
        t = (t / 255).transpose(2, 0, 1)
        return x, t, M

    def __len__(self):
        """Number of training samples in the train list."""
        return len(self.imlist)
class TestDataset(data.Dataset):
    """Cloudy images to run inference on, without ground truth.

    Serves every regular file in ``<test_dir>/cloudy_image``, in sorted
    filename order.

    Args:
        test_dir: directory that contains a ``cloudy_image`` sub-directory.
        in_ch: stored as ``self.in_ch``; not used by this class itself.
        out_ch: stored as ``self.out_ch``; not used by this class itself.
    """

    def __init__(self, test_dir, in_ch, out_ch):
        super().__init__()
        self.test_dir = test_dir
        self.in_ch = in_ch
        self.out_ch = out_ch
        image_dir = os.path.join(test_dir, 'cloudy_image')
        # os.listdir order is arbitrary. Sort so samples are enumerated
        # deterministically, and skip sub-directories, which cv2 cannot read.
        self.test_files = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __getitem__(self, index):
        """Return ``(x, filename)`` for the ``index``-th image.

        ``x`` is the image read by cv2 in colour mode, as float32 scaled by
        1/255 and transposed to channel-first (CxHxW).

        Raises:
            FileNotFoundError: if cv2 cannot read the file.
        """
        filename = os.path.basename(self.test_files[index])
        path = os.path.join(self.test_dir, 'cloudy_image', filename)
        img = cv2.imread(path, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        x = img.astype(np.float32) / 255
        return x.transpose(2, 0, 1), filename

    def __len__(self):
        """Number of images found in ``cloudy_image``."""
        return len(self.test_files)