-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDataset.py
34 lines (29 loc) · 1.13 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class CatsDogsDataset(Dataset):
    """Dataset that loads images listed in an annotation table.

    Implements ``__len__`` and ``__getitem__`` on top of
    ``torch.utils.data.Dataset``. Each item is an ``(image, label)`` pair:
    the image is read from disk on access and converted to RGB, and the label
    is the row's ``'CLASS-ID'`` value minus one.
    """

    def __init__(self, file_list, transform=None, image_dir='./images/'):
        """Store the annotation table and the loading options.

        Args:
            file_list: Annotation table. It is accessed as
                ``file_list['CLASS-ID'].iloc[idx]`` and
                ``file_list['image'].iloc[idx]`` and must support ``len()``;
                a pandas ``DataFrame`` with those two columns fits.
            transform: Optional callable applied to the RGB ``PIL.Image``
                before it is returned. ``None`` returns the image as is.
            image_dir: Directory that the ``'image'`` column entries are
                joined onto. Defaults to the previously hard-coded
                ``'./images/'``.
        """
        # Attribute name kept as-is so existing code reading it keeps working.
        self.folder_patch = image_dir
        self.annotations = file_list
        self.transform = transform
        # Length is captured once, at construction time.
        self.filelength = len(file_list)

    def __len__(self):
        """Return the number of annotation rows seen at construction."""
        return self.filelength

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Positional row index into the annotation table.

        Returns:
            A tuple ``(img, label)`` where ``img`` is the RGB image (after
            ``transform`` if one was given) and ``label`` is
            ``CLASS-ID - 1``.
            NOTE(review): the ``- 1`` presumably maps 1-based ids to 0-based
            class indices; confirm against the annotation file.

        Raises:
            RuntimeError: Re-raised from ``transform`` after diagnostics have
                been printed. Previously the error was swallowed and the
                untransformed ``PIL.Image`` was returned instead.
        """
        class_id = self.annotations['CLASS-ID'].iloc[idx]
        # os.path.join means image_dir does not need a trailing separator.
        # Note that an absolute path in the 'image' column overrides image_dir.
        img_path = os.path.join(
            self.folder_patch, self.annotations['image'].iloc[idx]
        )

        # convert() returns a new image, so the underlying file can be closed
        # deterministically as soon as the conversion has run.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            try:
                img = self.transform(img)
            except RuntimeError as e:
                self._report_transform_failure(e, img, img_path)
                # Propagate: returning the untransformed image would hand the
                # caller a sample of a different type than every other item.
                raise

        return img, class_id - 1

    @staticmethod
    def _report_transform_failure(error, img, img_path):
        """Print diagnostics for an image whose transform raised."""
        print(f"Exception: {error}")
        # img is still the pre-transform PIL image here; .size is (width, height).
        print("Shape before normalization:", img.size)
        print(img_path)
        img_tensor = transforms.ToTensor()(img)
        print("Input Tensor Shape:", img_tensor.shape)
        print("Input Tensor Values:", img_tensor)