Merge pull request #121 from gnina/MolDataset
Updated MolDataset
dkoes authored Jul 2, 2024
2 parents 5a642b1 + 0b10133 commit efc5cb3
Showing 2 changed files with 32 additions and 39 deletions.
python/torch_bindings.py: 65 changes (29 additions, 36 deletions)
@@ -159,7 +159,10 @@ def extra_repr(self):
 
 class MolDataset(torch.utils.data.Dataset):
     '''A pytorch mappable dataset for molgrid training files.'''
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args,
+                 random_translation: float=0.0,
+                 random_rotation: bool=False,
+                 **kwargs):
         '''Initialize mappable MolGridDataset.
         :param input(s): File name(s) of training example files
         :param typers: A tuple of AtomTypers to use
@@ -173,9 +176,10 @@ def __init__(self, *args, **kwargs):
         :param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root
         '''
 
+        self._random_translation, self._random_rotation = random_translation, random_rotation
+        print(self._random_translation, self._random_rotation)
         if 'typers' in kwargs:
-            typers = kwargs['typers']
-            del kwargs['typers']
+            typers = kwargs.pop('typers')
             self.examples = mg.ExampleDataset(*typers,**kwargs)
             self.typers = typers
         else:
@@ -184,39 +188,42 @@ def __init__(self, *args, **kwargs):
         self.types_files = list(args)
         self.examples.populate(self.types_files)
 
-        self.num_labels = self.examples.num_labels()
-
-
     def __len__(self):
         return len(self.examples)
 
     def __getitem__(self, idx):
         ex = self.examples[idx]
-        center = torch.tensor([i for i in ex.coord_sets[-1].center()])
+        center = torch.tensor(list(ex.coord_sets[-1].center()))
         coordinates = ex.merge_coordinates()
+        if self._random_translation > 0 or self._random_rotation:
+            mg.Transform(ex.coord_sets[-1].center(), self._random_translation, self._random_rotation).forward(coordinates, coordinates)
         if coordinates.has_vector_types() and coordinates.size() > 0:
             atomtypes = torch.tensor(coordinates.type_vector.tonumpy(),dtype=torch.long).type('torch.FloatTensor')
         else:
             atomtypes = torch.tensor(coordinates.type_index.tonumpy(),dtype=torch.long).type('torch.FloatTensor')
         coords = torch.tensor(coordinates.coords.tonumpy())
+        length = len(coords)
         radii = torch.tensor(coordinates.radii.tonumpy())
-        labels = [ex.labels[lab] for lab in range(self.num_labels)]
-        return center, coords, atomtypes, radii, labels
+        labels = torch.tensor(ex.labels)
+        return length, center, coords, atomtypes, radii, labels
 
-
     def __getstate__(self):
-        settings = self.examples.settings()
+        settings = self.examples.settings()
         keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
         if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers
             raise NotImplementedError('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
         keyword_dict['typers'] = self.typers
+        keyword_dict['random_translation'] = self._random_translation
+        keyword_dict['random_rotation'] = self._random_rotation
         return keyword_dict, self.types_files
 
     def __setstate__(self,state):
         kwargs=state[0]
 
+        self._random_translation = kwargs.pop('random_translation')
+        self._random_rotation = kwargs.pop('random_rotation')
         if 'typers' in kwargs:
-            typers = kwargs['typers']
-            del kwargs['typers']
+            typers = kwargs.pop('typers')
             self.examples = mg.ExampleDataset(*typers, **kwargs)
             self.typers = typers
         else:
@@ -225,33 +232,19 @@ def __setstate__(self,state):
         self.types_files = list(state[1])
         self.examples.populate(self.types_files)
 
-
-        self.num_labels = self.examples.num_labels()
-
     @staticmethod
     def collateMolDataset(batch):
         '''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset.
         Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
-        lens = []
-        centers = []
-        lcoords = []
-        ltypes = []
-        lradii = []
-        labels = []
-        for center,coords,types,radii,label in batch:
-            lens.append(coords.shape[0])
-            centers.append(center)
-            lcoords.append(coords)
-            ltypes.append(types)
-            lradii.append(radii)
-            labels.append(torch.tensor(label))
-
-
-        lengths = torch.tensor(lens)
-        lcoords = torch.nn.utils.rnn.pad_sequence(lcoords, batch_first=True)
-        ltypes = torch.nn.utils.rnn.pad_sequence(ltypes, batch_first=True)
-        lradii = torch.nn.utils.rnn.pad_sequence(lradii, batch_first=True)
-
-        centers = torch.stack(centers,dim=0)
-        labels = torch.stack(labels,dim=0)
+        batch_list = list(zip(*batch))
+        lengths = torch.tensor(batch_list[0])
+        centers = torch.stack(batch_list[1], dim=0)
+        coords = torch.nn.utils.rnn.pad_sequence(batch_list[2], batch_first=True)
+        types = torch.nn.utils.rnn.pad_sequence(batch_list[3], batch_first=True)
+        radii = torch.nn.utils.rnn.pad_sequence(batch_list[4], batch_first=True)
+        labels = torch.stack(batch_list[5], dim=0)
 
-        return lengths, centers, lcoords, ltypes, lradii, labels
+        return lengths, centers, coords, types, radii, labels
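
For orientation, a minimal usage sketch of the updated class (not part of the commit). The file name, data_root path, batch size, and translation value are placeholders; random_translation and random_rotation are the options added by this PR; it also assumes MolDataset is importable from the molgrid package (the class is defined in python/torch_bindings.py above), so adjust the import to match your install.

import torch
import molgrid

# Placeholder inputs: 'train.types' and 'structs/' are hypothetical.
dataset = molgrid.MolDataset('train.types', data_root='structs/',
                             random_translation=2.0,  # added by this PR: random jitter applied in __getitem__
                             random_rotation=True)    # added by this PR: random rotation applied in __getitem__

loader = torch.utils.data.DataLoader(dataset, batch_size=8,
                                     collate_fn=molgrid.MolDataset.collateMolDataset)

for lengths, centers, coords, types, radii, labels in loader:
    # coords, types, and radii are zero-padded to the largest example in the batch;
    # lengths holds each example's true atom count.
    pass

Returning the per-example length from __getitem__ is what lets collateMolDataset pad variable-size molecules with pad_sequence while still recording each molecule's true size.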
test/test_example_provider.py: 6 changes (3 additions, 3 deletions)
@@ -366,7 +366,7 @@ def test_pytorch_dataset():
     ex = e.next()
     coordinates = ex.merge_coordinates()
 
-    center, coords, types, radii, labels = m[0]
+    lengths, center, coords, types, radii, labels = m[0]
 
     assert list(center.shape) == [3]
     np.testing.assert_allclose(coords, coordinates.coords.tonumpy())
@@ -378,7 +378,7 @@ def test_pytorch_dataset():
     np.testing.assert_allclose(labels[1], 6.05)
     np.testing.assert_allclose(labels[-1], 0.162643)
 
-    center, coords, types, radii, labels = m[-1]
+    lengths, center, coords, types, radii, labels = m[-1]
     assert labels[0] == 0
     np.testing.assert_allclose(labels[1], -10.3)
 
@@ -396,7 +396,7 @@ def test_pytorch_dataset():
     assert radii.shape[0] == 8
     assert labels.shape[0] == 8
 
-    mcenter, mcoords, mtypes, mradii, mlabels = m[10]
+    mlengths, mcenter, mcoords, mtypes, mradii, mlabels = m[10]
     np.testing.assert_allclose(center[2], mcenter)
     np.testing.assert_allclose(coords[2][:lengths[2]], mcoords)
     np.testing.assert_allclose(types[2][:lengths[2]], mtypes)
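As the updated assertions show, the lengths tensor is what lets you strip the padding again after collation. A small illustrative sketch along the same lines (m, the batch composition, and the index are assumptions, not part of the commit):

# Collate a few items by hand, then un-pad one example; m is assumed to be a MolDataset.
batch = [m[i] for i in range(4)]
lengths, centers, coords, types, radii, labels = molgrid.MolDataset.collateMolDataset(batch)
i = 2
n = int(lengths[i])               # true atom count of example i
example_coords = coords[i, :n]    # rows beyond n are zero padding from pad_sequence
example_types = types[i, :n]
example_radii = radii[i, :n]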
