-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
33 lines (23 loc) · 950 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch.utils.data as data
import torch
import h5py
import cv2
import numpy as np
from EdgeDetection import Edge
class Dataset_Pro(data.Dataset):
    """Map-style dataset serving samples stored in a single HDF5 file.

    The file must contain the datasets ``"GT"``, ``"LMS"``, ``"MS"`` and
    ``"PAN"``. Every one of them is indexed as ``[index, :, :, :]``, so all
    are expected to be rank-4 arrays with the sample index on the first axis
    (the original author's note: NxCxHxW, e.g. N x 191 x 64 x 64 -- confirm
    per dataset). A fifth array, ``edge_pan``, is derived from ``PAN`` by the
    project helper ``Edge``.
    """

    def __init__(self, file_path):
        """Open *file_path* read-only and bind its datasets.

        Args:
            file_path: Path to the HDF5 file.

        Raises:
            KeyError: If any of "GT", "LMS", "MS" or "PAN" is missing. This is
                raised here rather than as an obscure failure on first access.
        """
        super().__init__()
        # Read-only: this class never writes to the file. "r+" would demand
        # write permission and, with HDF5 file locking, can stop other
        # processes from opening the same file. Named h5_file so it does not
        # shadow the module alias ``data`` used in the class header.
        h5_file = h5py.File(file_path, "r")
        # Keep a reference to the file so it stays open for as long as the
        # dataset handles below are used for per-sample reads.
        self._h5_file = h5_file
        self.gt = h5_file["GT"]
        self.lms = h5_file["LMS"]
        self.ms = h5_file["MS"]
        self.pan = h5_file["PAN"]
        # NOTE(review): Edge comes from EdgeDetection (not visible here).
        # __getitem__ assumes it returns a rank-4 NumPy array aligned sample
        # for sample with PAN -- confirm.
        self.edge_pan = Edge(self.pan)

    def __getitem__(self, index):
        """Return sample *index* as a tuple of float32 tensors.

        The tuple order is ``(gt, lms, ms, pan, edge_pan)``.
        """
        sources = (self.gt, self.lms, self.ms, self.pan, self.edge_pan)
        # torch.from_numpy needs a NumPy array. Slicing an h5py dataset (or
        # an ndarray) yields one. .float() casts to float32.
        return tuple(torch.from_numpy(src[index, :, :, :]).float() for src in sources)

    def __len__(self):
        """Return the number of samples, i.e. the first-axis length of GT."""
        return self.gt.shape[0]