-
Notifications
You must be signed in to change notification settings - Fork 22
/
datasets.py
executable file
·40 lines (31 loc) · 1.1 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Copyright (c) 2021, InterDigital R&D France. All rights reserved.
#
# This source code is made available under the license found in the
# LICENSE.txt in the root directory of this source tree.
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torchvision import transforms, utils
class LatentDataset(data.Dataset):
    """Dataset pairing pre-computed latent vectors with their labels.

    Both arrays are loaded with ``np.load`` and split along their first axis
    into a leading training portion and a trailing held-out portion.

    Args:
        latent_dir: Path to a ``.npy`` file holding the latent vectors
            (despite the name, this is a file path, not a directory).
        label_dir: Path to a ``.npy`` file holding one label entry per
            latent row (again a file path).
        training_set: If True, expose the first ``train_ratio`` fraction of
            rows; otherwise expose the remaining rows.
        train_ratio: Fraction of rows assigned to the training split.
            Defaults to 0.9, matching the original hard-coded split.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``, or if the two
            arrays do not have the same number of rows (which would silently
            pair latents with the wrong labels).
    """

    def __init__(self, latent_dir, label_dir, training_set=True, *, train_ratio=0.9):
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

        dlatents = np.load(latent_dir)
        labels = np.load(label_dir)
        # The split index is applied to both arrays, so they must line up
        # row-for-row; otherwise samples would be paired with wrong labels.
        if len(dlatents) != len(labels):
            raise ValueError(
                f"latents ({len(dlatents)} rows) and labels ({len(labels)} rows) "
                "must have the same length"
            )

        # int() truncates, so the training split never exceeds the ratio.
        train_len = int(train_ratio * len(labels))
        if training_set:
            self.dlatents = dlatents[:train_len]
            self.labels = labels[:train_len]
        else:
            self.dlatents = dlatents[train_len:]
            self.labels = labels[train_len:]
        self.length = len(self.labels)

    def __len__(self):
        """Return the number of samples in this split."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(latent, label)`` for row ``idx`` as freshly copied tensors."""
        dlatent = torch.tensor(self.dlatents[idx])
        lbl = torch.tensor(self.labels[idx])
        return dlatent, lbl