-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmisc.txt
44 lines (38 loc) · 1.61 KB
/
misc.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# # convert list of images into tensors
# infected_samples = torch.Tensor(ds_infected)
# not_infected_samples = torch.Tensor(ds_not_infected)
# # create datasets
# inf_samples_dataset = data.TensorDataset(infected_samples)
# ninf_samples_dataset = data.TensorDataset(not_infected_samples)
# # create data loader
# inf_samples_loader = data.DataLoader(inf_samples_dataset)
# ninf_samples_loader = data.DataLoader(ninf_samples_dataset)
#
# labels = np.concatenate([np.repeat(1, len(ds_covid19_yes)),np.repeat(0, len(ds_covid19_not))])
# images = ds_covid19_yes + ds_covid19_not
# #transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), transforms.Resize(INPUT_SIZE), transforms.ToTensor()])
# transform = transforms.Compose([
# transforms.Resize(INPUT_SIZE),
# transforms.ToTensor()
# ])
# dataset = CustomDataset(images, labels, transform=transform)
# dataloader = data.DataLoader(dataset, batch_size=5) # create the data loader
# img, label = dataset.__getitem__(2) # fetch the third item (index 2)
# plt.imshow(img.numpy()[0])
# plt.show()
#%% md
class CustomDataset(Dataset):
    """Dataset pairing in-memory images with their labels.

    Each item is an ``(image, label)`` tuple. When a transform was supplied,
    the stored image is first wrapped with ``Image.fromarray`` and then
    passed through the transform; otherwise it is returned unchanged.
    """

    def __init__(self, images: "list[np.ndarray]", labels: np.ndarray, transform=None):
        """Store the samples and convert the labels to a tensor.

        Args:
            images: Indexable, sized collection of image arrays, one per
                sample.
            labels: One label per image, in the same order as ``images``.
                Converted with ``torch.Tensor(labels)``.
            transform: Optional callable applied to each image (after
                ``Image.fromarray``) every time an item is fetched.

        Raises:
            ValueError: If ``images`` and ``labels`` differ in length.
        """
        self.images = images
        self.labels = torch.Tensor(labels)
        self.transform = transform

        # __len__ is driven by the images alone, so a length mismatch would
        # otherwise surface only as an IndexError part-way through iteration
        # (too few labels) or go unnoticed (too many labels). Fail early.
        n_images = len(self.images)
        n_labels = len(self.labels)
        if n_images != n_labels:
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {n_images} images and {n_labels} labels"
            )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        x = self.images[index]
        y = self.labels[index]
        if self.transform:
            # The transform is fed an image object built from the raw array,
            # not the array itself.
            x = Image.fromarray(x)
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.images)
https://git-lfs.github.com/ # Git LFS: use it to manage the .pt files