forked from arneschneuing/DiffSBDD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path dataset.py
70 lines (57 loc) · 2.63 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from itertools import accumulate
import numpy as np
import torch
from torch.utils.data import Dataset
class ProcessedLigandPocketDataset(Dataset):
    """Ligand/pocket dataset backed by one concatenated ``.npz`` archive.

    Every array in the archive (except ``'names'`` and ``'receptors'``) holds
    all samples concatenated along axis 0. Per-sample pieces are recovered by
    splitting wherever the corresponding mask changes value: ``lig_mask`` is
    used for keys whose name contains ``'lig'``, ``pocket_mask`` for all other
    keys. ``'names'`` and ``'receptors'`` are stored as-is and indexed
    directly per sample.

    Args:
        npz_path: path to the ``.npz`` archive.
        center: if True, translate each sample so that the joint mean of its
            ligand and pocket coordinates (``lig_coords`` / ``pocket_coords``)
            is at the origin.
        transform: optional callable applied to each sample dict returned by
            ``__getitem__``.
    """

    def __init__(self, npz_path, center=True, transform=None):
        self.transform = transform

        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only load trusted archives.
        with np.load(npz_path, allow_pickle=True) as f:
            data = {key: val for key, val in f.items()}

        # Split points depend only on the mask, so compute them at most once
        # per mask (lazily, so a missing mask still raises KeyError exactly
        # when a key that needs it is encountered).
        split_points = {}

        # split data based on mask
        self.data = {}
        for k, v in data.items():
            if k in ('names', 'receptors'):
                self.data[k] = v
                continue

            mask_key = 'lig_mask' if 'lig' in k else 'pocket_mask'
            if mask_key not in split_points:
                split_points[mask_key] = self._split_indices(data[mask_key])
            self.data[k] = [torch.from_numpy(x)
                            for x in np.split(v, split_points[mask_key])]

            # add number of nodes for convenience
            if k == 'lig_mask':
                self.data['num_lig_atoms'] = \
                    torch.tensor([len(x) for x in self.data['lig_mask']])
            elif k == 'pocket_mask':
                self.data['num_pocket_nodes'] = \
                    torch.tensor([len(x) for x in self.data['pocket_mask']])

        if center:
            lig_coords = self.data['lig_coords']
            pocket_coords = self.data['pocket_coords']
            for i in range(len(lig_coords)):
                lig, pocket = lig_coords[i], pocket_coords[i]
                # joint centroid over ligand and pocket coordinates
                mean = (lig.sum(0) + pocket.sum(0)) / (len(lig) + len(pocket))
                lig_coords[i] = lig - mean
                pocket_coords[i] = pocket - mean

    @staticmethod
    def _split_indices(mask):
        """Return the indices at which ``mask`` changes value (np.split input)."""
        return np.where(np.diff(mask))[0] + 1

    def __len__(self):
        """Number of samples, taken from the ``'names'`` entry."""
        return len(self.data['names'])

    def __getitem__(self, idx):
        """Return a dict with the ``idx``-th entry of every stored key.

        If a ``transform`` was given, it is applied to the dict before
        returning.
        """
        data = {key: val[idx] for key, val in self.data.items()}
        if self.transform is not None:
            data = self.transform(data)
        return data

    @staticmethod
    def collate_fn(batch):
        """Merge a list of sample dicts into a single batch dict.

        - ``'names'`` / ``'receptors'``: collected into Python lists.
        - ``'num_lig_atoms'`` / ``'num_pocket_nodes'`` /
          ``'num_virtual_atoms'``: stacked into a tensor via ``torch.tensor``.
        - keys containing ``'mask'``: replaced by the sample's position in the
          batch (0, 1, ...), repeated once per row of that sample, then
          concatenated. The result uses torch's default floating dtype.
        - everything else: concatenated along dim 0.

        The set of keys is taken from ``batch[0]``.
        """
        out = {}
        for prop in batch[0].keys():
            values = [x[prop] for x in batch]
            if prop in ('names', 'receptors'):
                out[prop] = values
            elif prop in ('num_lig_atoms', 'num_pocket_nodes',
                          'num_virtual_atoms'):
                out[prop] = torch.tensor(values)
            elif 'mask' in prop:
                # make sure indices in batch start at zero (needed for
                # torch_scatter)
                out[prop] = torch.cat([i * torch.ones(len(v))
                                       for i, v in enumerate(values)], dim=0)
            else:
                out[prop] = torch.cat(values, dim=0)
        return out