diff --git a/python/torch_bindings.py b/python/torch_bindings.py index febad42..275c9c4 100644 --- a/python/torch_bindings.py +++ b/python/torch_bindings.py @@ -1,6 +1,8 @@ import torch import molgrid as mg import types +from itertools import islice + def tensor_as_grid(t): '''Return a Grid view of tensor t''' gname = 'Grid' @@ -157,7 +159,7 @@ def extra_repr(self): self.gmaker.get_resolution(), self.gmaker.get_dimension(), self.center[0], self.center[1], self.center[2]) -class MolDataset(torch.utils.data.Dataset): +class MolMapDataset(torch.utils.data.Dataset): '''A pytorch mappable dataset for molgrid training files.''' def __init__(self, *args, random_translation: float=0.0, @@ -177,7 +179,6 @@ def __init__(self, *args, ''' self._random_translation, self._random_rotation = random_translation, random_rotation - print(self._random_translation, self._random_rotation) if 'typers' in kwargs: typers = kwargs.pop('typers') self.examples = mg.ExampleDataset(*typers,**kwargs) @@ -212,7 +213,7 @@ def __getstate__(self): settings = self.examples.settings() keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')} if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers - raise NotImplementedError('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers))) + raise NotImplementedError('MolMapDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers))) keyword_dict['typers'] = self.typers keyword_dict['random_translation'] = self._random_translation keyword_dict['random_rotation'] = self._random_rotation @@ -233,11 +234,9 @@ def __setstate__(self,state): self.examples.populate(self.types_files) - self.num_labels = self.examples.num_labels() - @staticmethod def collateMolDataset(batch): - '''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset. + '''collate_fn for use in torch.utils.data.Dataloader when using the MolMapDataset. Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch''' batch_list = list(zip(*batch)) lengths = torch.tensor(batch_list[0]) @@ -248,3 +247,110 @@ def collateMolDataset(batch): labels = torch.stack(batch_list[5], dim=0) return lengths, centers, coords, types, radii, labels + +class MolIterDataset(torch.utils.data.IterableDataset): + '''A pytorch iterable dataset for molgrid training files. Use with a DataLoader(batch_size=None) for best results.''' + def __init__(self, *args, + random_translation: float=0.0, + random_rotation: bool=False, + **kwargs): + '''Initialize mappable MolGridDataset. + :param input(s): File name(s) of training example files + :param typers: A tuple of AtomTypers to use + :type typers: tuple + :param cache_structs: retain coordinates in memory for faster training + :param add_hydrogens: protonate molecules read using openbabel + :param duplicate_first: clone the first coordinate set to be paired with each of the remaining (receptor-ligand pairs) + :param make_vector_types: convert index types into one-hot encoded vector types + :param data_root: prefix for data files + :param recmolcache: precalculated molcache2 file for receptor (first molecule); if doesn't exist, will look in data _root + :param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root + ''' + + # molgrid.set_random_seed(kwargs['random_seed']) + self._random_translation, self._random_rotation = random_translation, random_rotation + if 'typers' in kwargs: + typers = kwargs.pop('typers') + self.examples = mg.ExampleProvider(*typers,**kwargs) + self.typers = typers + else: + self.examples = mg.ExampleProvider(**kwargs) + self.typers = None + self.types_files = list(args) + self.examples.populate(self.types_files) + + self._num_labels = self.examples.num_labels() + + def generate(self): + for batch in self.examples: + yield self.batch_to_tensors(batch) + + def batch_to_tensors(self, batch): + batch_lengths = torch.zeros(len(batch), dtype=torch.int64) + batch_centers = torch.zeros((len(batch), 3), dtype=torch.float32) + batch_coords = [] + batch_atomtypes = [] + batch_radii = [] + batch_labels = torch.zeros((len(batch),self._num_labels), dtype=torch.float32) + for idx, ex in enumerate(batch): + length, center, coords, atomtypes, radii, labels = self.example_to_tensor(ex) + batch_lengths[idx] = length + batch_centers[idx,:] = center + batch_coords.append(coords) + batch_atomtypes.append(atomtypes) + batch_radii.append(radii) + batch_labels[idx,:] = labels + pad_coords = torch.nn.utils.rnn.pad_sequence(batch_coords, batch_first=True) + pad_atomtypes = torch.nn.utils.rnn.pad_sequence(batch_atomtypes, batch_first=True) + pad_radii = torch.nn.utils.rnn.pad_sequence(batch_radii, batch_first=True) + return batch_lengths, batch_centers, pad_coords, pad_atomtypes, pad_radii, batch_labels + + + def example_to_tensor(self, ex): + center = torch.tensor(list(ex.coord_sets[-1].center())) + coordinates = ex.merge_coordinates() + if self._random_translation > 0 or self._random_rotation: + mg.Transform(ex.coord_sets[-1].center(), self._random_translation, self._random_rotation).forward(coordinates, coordinates) + if coordinates.has_vector_types() and coordinates.size() > 0: + atomtypes = torch.tensor(coordinates.type_vector.tonumpy(),dtype=torch.long).type('torch.FloatTensor') + else: + atomtypes = torch.tensor(coordinates.type_index.tonumpy(),dtype=torch.long).type('torch.FloatTensor') + coords = torch.tensor(coordinates.coords.tonumpy()) + length = len(coords) + radii = torch.tensor(coordinates.radii.tonumpy()) + labels = torch.tensor(ex.labels) + return length, center, coords, atomtypes, radii, labels + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + return self.generate() + dataset = worker_info.dataset + worker_id = worker_info.id + n_workers = worker_info.num_workers + + return islice(self.generate(), worker_id, None, n_workers) + + def __getstate__(self): + settings = self.examples.settings() + keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')} + if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers + raise NotImplementedError('MolIterDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers))) + keyword_dict['typers'] = self.typers + keyword_dict['random_translation'] = self._random_translation + keyword_dict['random_rotation'] = self._random_rotation + return keyword_dict, self.types_files + + def __setstate__(self,state): + kwargs=state[0] + self._random_translation = kwargs.pop('random_translation') + self._random_rotation = kwargs.pop('random_rotation') + if 'typers' in kwargs: + typers = kwargs.pop('typers') + self.examples = mg.ExampleProvider(*typers, **kwargs) + self.typers = typers + else: + self.examples = mg.ExampleProvider(**kwargs) + self.typers = None + self.types_files = list(state[1]) + self.examples.populate(self.types_files) diff --git a/test/test_example_provider.py b/test/test_example_provider.py index 2565ee2..4f7fe02 100644 --- a/test/test_example_provider.py +++ b/test/test_example_provider.py @@ -354,12 +354,12 @@ def test_example_provider_iterator_interface(): break -def test_pytorch_dataset(): +def test_pytorch_mapdataset(): fname = datadir + "/small.types" e = molgrid.ExampleProvider(data_root=datadir + "/structs") e.populate(fname) - m = molgrid.MolDataset(fname, data_root=datadir + "/structs") + m = molgrid.MolMapDataset(fname, data_root=datadir + "/structs") assert len(m) == 1000 @@ -384,7 +384,7 @@ def test_pytorch_dataset(): '''Testing out the collate_fn when used with torch.utils.data.DataLoader''' torch_loader = torch.utils.data.DataLoader( - m, batch_size=8, collate_fn=molgrid.MolDataset.collateMolDataset) + m, batch_size=8, collate_fn=molgrid.MolMapDataset.collateMolDataset) iterator = iter(torch_loader) next(iterator) lengths, center, coords, types, radii, labels = next(iterator) @@ -422,6 +422,73 @@ def test_pytorch_dataset(): singlegrid = molgrid.MGrid4f(*shape) gmaker.forward(ex, singlegrid.cpu()) np.testing.assert_allclose(mgrid[2].tonumpy(),singlegrid.tonumpy(),atol=1e-5) + +def test_pytorch_iterdataset(): + fname = datadir + "/small.types" + + BSIZE = 25 + e = molgrid.ExampleProvider(data_root=datadir + "/structs", default_batch_size=BSIZE) + e.populate(fname) + m = molgrid.MolIterDataset(fname, data_root=datadir + "/structs", default_batch_size=BSIZE) + m_iter = iter(m) + + ex = e.next() + coordinates = ex.merge_coordinates() + + lengths, centers, coords, types, radii, labels = next(m_iter) + + assert list(centers.shape) == [BSIZE,3] + np.testing.assert_allclose(coords[0,:lengths[0],:], coordinates.coords.tonumpy()) + np.testing.assert_allclose(types[0,:lengths[0]], coordinates.type_index.tonumpy()) + np.testing.assert_allclose(radii[0,:lengths[0]], coordinates.radii.tonumpy()) + + assert len(labels) == BSIZE + assert len(labels[0]) == 3 + assert labels[0,0] == 1 + np.testing.assert_allclose(labels[0,1], 6.05) + np.testing.assert_allclose(labels[0,-1], 0.162643) + + # ensure it works with more than 1 worker + m.examples.reset() + torch_loader = torch.utils.data.DataLoader( + m, batch_size=None, num_workers=2) + iterator = iter(torch_loader) + next(iterator) + lengths, center, coords, types, radii, labels = next(iterator) + assert len(lengths) == BSIZE + assert center.shape[0] == BSIZE + assert coords.shape[0] == BSIZE + assert types.shape[0] == BSIZE + assert radii.shape[0] == BSIZE + assert labels.shape[0] == BSIZE + + e.reset() + e.next_batch() + ex = e.next_batch() + coordinates = ex[2].merge_coordinates() + np.testing.assert_allclose(center[2], np.array(list(ex[2].coord_sets[-1].center()))) + np.testing.assert_allclose(coords[2,:lengths[2]], coordinates.coords.tonumpy()) + np.testing.assert_allclose(types[2,:lengths[2]], coordinates.type_index.tonumpy()) + np.testing.assert_allclose(radii[2,:lengths[2]], coordinates.radii.tonumpy()) + assert len(labels[2]) == e.num_labels() + assert labels[2,0] == ex[2].labels[0] + assert labels[2,1] == ex[2].labels[1] + + gmaker = molgrid.GridMaker() + shape = gmaker.grid_dimensions(e.num_types()) + mgrid = molgrid.MGrid5f(BSIZE,*shape) + + gmaker.forward(center, coords, types, radii, mgrid.cpu()) + + mgridg = molgrid.MGrid5f(BSIZE,*shape) + gmaker.forward(center.cuda(), coords.cuda(), types.cuda(), radii.cuda(), mgridg.gpu()) + + np.testing.assert_allclose(mgrid.tonumpy(),mgridg.tonumpy(),atol=1e-5) + + #compare against standard provider + egrid = molgrid.MGrid5f(BSIZE,*shape) + gmaker.forward(ex, egrid.cpu()) + np.testing.assert_allclose(mgridg.tonumpy(),egrid.tonumpy(),atol=1e-5) def test_duplicated_examples():