Skip to content

Commit

Permalink
Merge pull request #122 from gnina/MolDataset
Browse files Browse the repository at this point in the history
MolDataset -> MolMapDataset and MolIterDataset
  • Loading branch information
dkoes authored Jul 9, 2024
2 parents efc5cb3 + c355145 commit 921de63
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 9 deletions.
118 changes: 112 additions & 6 deletions python/torch_bindings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import molgrid as mg
import types
from itertools import islice

def tensor_as_grid(t):
'''Return a Grid view of tensor t'''
gname = 'Grid'
Expand Down Expand Up @@ -157,7 +159,7 @@ def extra_repr(self):
self.gmaker.get_resolution(), self.gmaker.get_dimension(), self.center[0], self.center[1], self.center[2])


class MolDataset(torch.utils.data.Dataset):
class MolMapDataset(torch.utils.data.Dataset):
'''A pytorch mappable dataset for molgrid training files.'''
def __init__(self, *args,
random_translation: float=0.0,
Expand All @@ -177,7 +179,6 @@ def __init__(self, *args,
'''

self._random_translation, self._random_rotation = random_translation, random_rotation
print(self._random_translation, self._random_rotation)
if 'typers' in kwargs:
typers = kwargs.pop('typers')
self.examples = mg.ExampleDataset(*typers,**kwargs)
Expand Down Expand Up @@ -212,7 +213,7 @@ def __getstate__(self):
settings = self.examples.settings()
keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers
raise NotImplementedError('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
raise NotImplementedError('MolMapDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
keyword_dict['typers'] = self.typers
keyword_dict['random_translation'] = self._random_translation
keyword_dict['random_rotation'] = self._random_rotation
Expand All @@ -233,11 +234,9 @@ def __setstate__(self,state):
self.examples.populate(self.types_files)


self.num_labels = self.examples.num_labels()

@staticmethod
def collateMolDataset(batch):
'''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset.
'''collate_fn for use in torch.utils.data.Dataloader when using the MolMapDataset.
Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
batch_list = list(zip(*batch))
lengths = torch.tensor(batch_list[0])
Expand All @@ -248,3 +247,110 @@ def collateMolDataset(batch):
labels = torch.stack(batch_list[5], dim=0)

return lengths, centers, coords, types, radii, labels

class MolIterDataset(torch.utils.data.IterableDataset):
'''A pytorch iterable dataset for molgrid training files. Use with a DataLoader(batch_size=None) for best results.'''
def __init__(self, *args,
random_translation: float=0.0,
random_rotation: bool=False,
**kwargs):
'''Initialize mappable MolGridDataset.
:param input(s): File name(s) of training example files
:param typers: A tuple of AtomTypers to use
:type typers: tuple
:param cache_structs: retain coordinates in memory for faster training
:param add_hydrogens: protonate molecules read using openbabel
:param duplicate_first: clone the first coordinate set to be paired with each of the remaining (receptor-ligand pairs)
:param make_vector_types: convert index types into one-hot encoded vector types
:param data_root: prefix for data files
:param recmolcache: precalculated molcache2 file for receptor (first molecule); if doesn't exist, will look in data _root
:param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root
'''

# molgrid.set_random_seed(kwargs['random_seed'])
self._random_translation, self._random_rotation = random_translation, random_rotation
if 'typers' in kwargs:
typers = kwargs.pop('typers')
self.examples = mg.ExampleProvider(*typers,**kwargs)
self.typers = typers
else:
self.examples = mg.ExampleProvider(**kwargs)
self.typers = None
self.types_files = list(args)
self.examples.populate(self.types_files)

self._num_labels = self.examples.num_labels()

def generate(self):
for batch in self.examples:
yield self.batch_to_tensors(batch)

def batch_to_tensors(self, batch):
batch_lengths = torch.zeros(len(batch), dtype=torch.int64)
batch_centers = torch.zeros((len(batch), 3), dtype=torch.float32)
batch_coords = []
batch_atomtypes = []
batch_radii = []
batch_labels = torch.zeros((len(batch),self._num_labels), dtype=torch.float32)
for idx, ex in enumerate(batch):
length, center, coords, atomtypes, radii, labels = self.example_to_tensor(ex)
batch_lengths[idx] = length
batch_centers[idx,:] = center
batch_coords.append(coords)
batch_atomtypes.append(atomtypes)
batch_radii.append(radii)
batch_labels[idx,:] = labels
pad_coords = torch.nn.utils.rnn.pad_sequence(batch_coords, batch_first=True)
pad_atomtypes = torch.nn.utils.rnn.pad_sequence(batch_atomtypes, batch_first=True)
pad_radii = torch.nn.utils.rnn.pad_sequence(batch_radii, batch_first=True)
return batch_lengths, batch_centers, pad_coords, pad_atomtypes, pad_radii, batch_labels


def example_to_tensor(self, ex):
center = torch.tensor(list(ex.coord_sets[-1].center()))
coordinates = ex.merge_coordinates()
if self._random_translation > 0 or self._random_rotation:
mg.Transform(ex.coord_sets[-1].center(), self._random_translation, self._random_rotation).forward(coordinates, coordinates)
if coordinates.has_vector_types() and coordinates.size() > 0:
atomtypes = torch.tensor(coordinates.type_vector.tonumpy(),dtype=torch.long).type('torch.FloatTensor')
else:
atomtypes = torch.tensor(coordinates.type_index.tonumpy(),dtype=torch.long).type('torch.FloatTensor')
coords = torch.tensor(coordinates.coords.tonumpy())
length = len(coords)
radii = torch.tensor(coordinates.radii.tonumpy())
labels = torch.tensor(ex.labels)
return length, center, coords, atomtypes, radii, labels

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
return self.generate()
dataset = worker_info.dataset
worker_id = worker_info.id
n_workers = worker_info.num_workers

return islice(self.generate(), worker_id, None, n_workers)

def __getstate__(self):
settings = self.examples.settings()
keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers
raise NotImplementedError('MolIterDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
keyword_dict['typers'] = self.typers
keyword_dict['random_translation'] = self._random_translation
keyword_dict['random_rotation'] = self._random_rotation
return keyword_dict, self.types_files

def __setstate__(self,state):
kwargs=state[0]
self._random_translation = kwargs.pop('random_translation')
self._random_rotation = kwargs.pop('random_rotation')
if 'typers' in kwargs:
typers = kwargs.pop('typers')
self.examples = mg.ExampleProvider(*typers, **kwargs)
self.typers = typers
else:
self.examples = mg.ExampleProvider(**kwargs)
self.typers = None
self.types_files = list(state[1])
self.examples.populate(self.types_files)
73 changes: 70 additions & 3 deletions test/test_example_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,12 @@ def test_example_provider_iterator_interface():
break


def test_pytorch_dataset():
def test_pytorch_mapdataset():
fname = datadir + "/small.types"

e = molgrid.ExampleProvider(data_root=datadir + "/structs")
e.populate(fname)
m = molgrid.MolDataset(fname, data_root=datadir + "/structs")
m = molgrid.MolMapDataset(fname, data_root=datadir + "/structs")

assert len(m) == 1000

Expand All @@ -384,7 +384,7 @@ def test_pytorch_dataset():

'''Testing out the collate_fn when used with torch.utils.data.DataLoader'''
torch_loader = torch.utils.data.DataLoader(
m, batch_size=8, collate_fn=molgrid.MolDataset.collateMolDataset)
m, batch_size=8, collate_fn=molgrid.MolMapDataset.collateMolDataset)
iterator = iter(torch_loader)
next(iterator)
lengths, center, coords, types, radii, labels = next(iterator)
Expand Down Expand Up @@ -422,6 +422,73 @@ def test_pytorch_dataset():
singlegrid = molgrid.MGrid4f(*shape)
gmaker.forward(ex, singlegrid.cpu())
np.testing.assert_allclose(mgrid[2].tonumpy(),singlegrid.tonumpy(),atol=1e-5)

def test_pytorch_iterdataset():
fname = datadir + "/small.types"

BSIZE = 25
e = molgrid.ExampleProvider(data_root=datadir + "/structs", default_batch_size=BSIZE)
e.populate(fname)
m = molgrid.MolIterDataset(fname, data_root=datadir + "/structs", default_batch_size=BSIZE)
m_iter = iter(m)

ex = e.next()
coordinates = ex.merge_coordinates()

lengths, centers, coords, types, radii, labels = next(m_iter)

assert list(centers.shape) == [BSIZE,3]
np.testing.assert_allclose(coords[0,:lengths[0],:], coordinates.coords.tonumpy())
np.testing.assert_allclose(types[0,:lengths[0]], coordinates.type_index.tonumpy())
np.testing.assert_allclose(radii[0,:lengths[0]], coordinates.radii.tonumpy())

assert len(labels) == BSIZE
assert len(labels[0]) == 3
assert labels[0,0] == 1
np.testing.assert_allclose(labels[0,1], 6.05)
np.testing.assert_allclose(labels[0,-1], 0.162643)

# ensure it works with more than 1 worker
m.examples.reset()
torch_loader = torch.utils.data.DataLoader(
m, batch_size=None, num_workers=2)
iterator = iter(torch_loader)
next(iterator)
lengths, center, coords, types, radii, labels = next(iterator)
assert len(lengths) == BSIZE
assert center.shape[0] == BSIZE
assert coords.shape[0] == BSIZE
assert types.shape[0] == BSIZE
assert radii.shape[0] == BSIZE
assert labels.shape[0] == BSIZE

e.reset()
e.next_batch()
ex = e.next_batch()
coordinates = ex[2].merge_coordinates()
np.testing.assert_allclose(center[2], np.array(list(ex[2].coord_sets[-1].center())))
np.testing.assert_allclose(coords[2,:lengths[2]], coordinates.coords.tonumpy())
np.testing.assert_allclose(types[2,:lengths[2]], coordinates.type_index.tonumpy())
np.testing.assert_allclose(radii[2,:lengths[2]], coordinates.radii.tonumpy())
assert len(labels[2]) == e.num_labels()
assert labels[2,0] == ex[2].labels[0]
assert labels[2,1] == ex[2].labels[1]

gmaker = molgrid.GridMaker()
shape = gmaker.grid_dimensions(e.num_types())
mgrid = molgrid.MGrid5f(BSIZE,*shape)

gmaker.forward(center, coords, types, radii, mgrid.cpu())

mgridg = molgrid.MGrid5f(BSIZE,*shape)
gmaker.forward(center.cuda(), coords.cuda(), types.cuda(), radii.cuda(), mgridg.gpu())

np.testing.assert_allclose(mgrid.tonumpy(),mgridg.tonumpy(),atol=1e-5)

#compare against standard provider
egrid = molgrid.MGrid5f(BSIZE,*shape)
gmaker.forward(ex, egrid.cpu())
np.testing.assert_allclose(mgridg.tonumpy(),egrid.tonumpy(),atol=1e-5)


def test_duplicated_examples():
Expand Down

0 comments on commit 921de63

Please sign in to comment.