From 0c0c7ed13c5eac36a2b4d073b91156e448e5c3d7 Mon Sep 17 00:00:00 2001 From: Max Liu Date: Mon, 26 Aug 2019 21:03:34 -0400 Subject: [PATCH] Add/update comparison methods for Atom, Bond, Molecule, Species Python 3 no longer allows implicit sorting of custom classes. This implements hash, eq, lt, gt special methods for these classes. In general, the goal was to maintain prior behavior for hash and eq and implement reasonable behavior for gt and lt. Sorting of these classes is generally done to enable deterministic behavior, rather than any quantitative reason, so that was the main consideration. --- rmgpy/molecule/molecule.py | 105 +++++++++++++++++++++++++-------- rmgpy/molecule/moleculeTest.py | 92 +++++++++++++++++++++++++++++ rmgpy/species.py | 30 ++++++++++ rmgpy/speciesTest.py | 50 ++++++++++++++++ 4 files changed, 252 insertions(+), 25 deletions(-) diff --git a/rmgpy/molecule/molecule.py b/rmgpy/molecule/molecule.py index 4b5bab855d..f836deefa8 100644 --- a/rmgpy/molecule/molecule.py +++ b/rmgpy/molecule/molecule.py @@ -164,6 +164,32 @@ def __setstate__(self, d): self.atomType = atomTypes[d['atomType']] if d['atomType'] else None self.lonePairs = d['lonePairs'] + def __hash__(self): + """ + Define a custom hash method to allow Atom objects to be used in dictionaries and sets. + """ + return hash(('Atom', self.symbol)) + + def __eq__(self, other): + """Method to test equality of two Atom objects.""" + return self is other + + def __lt__(self, other): + """Define less than comparison. For comparing against other Atom objects (e.g. when sorting).""" + if isinstance(other, Atom): + return self.sorting_key < other.sorting_key + else: + raise NotImplementedError('Cannot perform less than comparison between Atom and ' + '{0}.'.format(type(other).__name__)) + + def __gt__(self, other): + """Define greater than comparison. For comparing against other Atom objects (e.g. when sorting).""" + if isinstance(other, Atom): + return self.sorting_key > other.sorting_key + else: + raise NotImplementedError('Cannot perform greater than comparison between Atom and ' + '{0}.'.format(type(other).__name__)) + @property def mass(self): return self.element.mass @@ -235,14 +261,6 @@ def equivalent(self, other, strict=True): return False return True - def get_descriptor(self): - """ - Return a tuple used for sorting atoms. - Currently uses atomic number, connectivity value, - radical electrons, lone pairs, and charge - """ - return self.number, -getVertexConnectivityValue(self), self.radicalElectrons, self.lonePairs, self.charge - def isSpecificCaseOf(self, other): """ Return ``True`` if `self` is a specific case of `other`, or ``False`` @@ -551,6 +569,34 @@ def __reduce__(self): """ return (Bond, (self.vertex1, self.vertex2, self.order)) + def __hash__(self): + """ + Define a custom hash method to allow Bond objects to be used in dictionaries and sets. + """ + return hash(('Bond', self.order, + self.atom1.symbol if self.atom1 is not None else '', + self.atom2.symbol if self.atom2 is not None else '')) + + def __eq__(self, other): + """Method to test equality of two Bond objects.""" + return self is other + + def __lt__(self, other): + """Define less than comparison. For comparing against other Bond objects (e.g. when sorting).""" + if isinstance(other, Bond): + return self.sorting_key < other.sorting_key + else: + raise NotImplementedError('Cannot perform less than comparison between Bond and ' + '{0}.'.format(type(other).__name__)) + + def __gt__(self, other): + """Define greater than comparison. For comparing against other Bond objects (e.g. when sorting).""" + if isinstance(other, Bond): + return self.sorting_key > other.sorting_key + else: + raise NotImplementedError('Cannot perform greater than comparison between Bond and ' + '{0}.'.format(type(other).__name__)) + @property def atom1(self): return self.vertex1 @@ -862,26 +908,35 @@ def __deepcopy__(self, memo): return self.copy(deep=True) def __hash__(self): - return hash((self.fingerprint)) + """ + Define a custom hash method to allow Molecule objects to be used in dictionaries and sets. - def __richcmp__(x, y, op): - if op == 2: # Py_EQ - return x.is_equal(y) - if op == 3: # Py_NE - return not x.is_equal(y) - else: - raise NotImplementedError("Can only check equality of molecules, not > or <") + Use the fingerprint property, which is currently defined as the molecular formula, though + this is not an ideal hash since there will be significant hash collision, leading to inefficient lookups. + """ + return hash(('Molecule', self.fingerprint)) - def is_equal(self, other): + def __eq__(self, other): """Method to test equality of two Molecule objects.""" - if not isinstance(other, Molecule): - return False # different type - elif self is other: - return True # same reference in memory - elif self.fingerprint != other.fingerprint: - return False + return self is other or (isinstance(other, Molecule) and + self.fingerprint == other.fingerprint and + self.isIsomorphic(other)) + + def __lt__(self, other): + """Define less than comparison. For comparing against other Molecule objects (e.g. when sorting).""" + if isinstance(other, Molecule): + return self.sorting_key < other.sorting_key + else: + raise NotImplementedError('Cannot perform less than comparison between Molecule and ' + '{0}.'.format(type(other).__name__)) + + def __gt__(self, other): + """Define greater than comparison. For comparing against other Molecule objects (e.g. when sorting).""" + if isinstance(other, Molecule): + return self.sorting_key > other.sorting_key else: - return self.isIsomorphic(other) + raise NotImplementedError('Cannot perform greater than comparison between Molecule and ' + '{0}.'.format(type(other).__name__)) def __str__(self): """ @@ -1063,7 +1118,7 @@ def sortAtoms(self): if vertex.sortingLabel < 0: self.updateConnectivityValues() break - self.atoms.sort(key=lambda a: a.get_descriptor(), reverse=True) + self.atoms.sort(reverse=True) for index, vertex in enumerate(self.vertices): vertex.sortingLabel = index diff --git a/rmgpy/molecule/moleculeTest.py b/rmgpy/molecule/moleculeTest.py index 4fc7116cd6..be491724cb 100644 --- a/rmgpy/molecule/moleculeTest.py +++ b/rmgpy/molecule/moleculeTest.py @@ -49,6 +49,11 @@ def setUp(self): """ self.atom = Atom(element=getElement('C'), radicalElectrons=1, charge=0, label='*1', lonePairs=0) + self.atom1 = Atom(element=getElement('C'), radicalElectrons=0, lonePairs=0) + self.atom2 = Atom(element=getElement('C'), radicalElectrons=0, lonePairs=0) + self.atom3 = Atom(element=getElement('C'), radicalElectrons=1, lonePairs=0) + self.atom4 = Atom(element=getElement('H'), radicalElectrons=1, lonePairs=0) + def testMass(self): """ Test the Atom.mass property. @@ -67,6 +72,33 @@ def testSymbol(self): """ self.assertTrue(self.atom.symbol == self.atom.element.symbol) + def test_equality(self): + """Test that we can perform equality comparison with Atom objects""" + self.assertEqual(self.atom1, self.atom1) + self.assertNotEqual(self.atom1, self.atom2) + self.assertNotEqual(self.atom1, self.atom3) + self.assertNotEqual(self.atom1, self.atom4) + + def test_less_than(self): + """Test that we can perform less than comparison with Atom objects""" + self.assertFalse(self.atom1 < self.atom2) # Because the sorting keys should be identical + self.assertLess(self.atom2, self.atom3) + self.assertLess(self.atom4, self.atom1) + + def test_greater_than(self): + """Test that we can perform greater than comparison with Atom objects""" + self.assertFalse(self.atom2 > self.atom1) # Because the sorting keys should be identical + self.assertGreater(self.atom3, self.atom1) + self.assertGreater(self.atom1, self.atom4) + + def test_hash(self): + """Test behavior of Atom hashing using dictionaries and sets""" + # Test dictionary behavior + self.assertEqual(len(dict.fromkeys([self.atom1, self.atom2, self.atom3, self.atom4])), 4) + + # Test set behavior + self.assertEqual(len({self.atom1, self.atom2, self.atom3, self.atom4}), 4) + def testIsHydrogen(self): """ Test the Atom.isHydrogen() method. @@ -390,6 +422,38 @@ def setUp(self): self.bond = Bond(atom1=None, atom2=None, order=2) self.orderList = [1, 2, 3, 4, 1.5, 0.30000000000000004] + self.bond1 = Bond(atom1=None, atom2=None, order=1) + self.bond2 = Bond(atom1=None, atom2=None, order=1) + self.bond3 = Bond(atom1=None, atom2=None, order=2) + self.bond4 = Bond(atom1=None, atom2=None, order=3) + + def test_equality(self): + """Test that we can perform equality comparison with Bond objects""" + self.assertEqual(self.bond1, self.bond1) + self.assertNotEqual(self.bond1, self.bond2) + self.assertNotEqual(self.bond1, self.bond3) + self.assertNotEqual(self.bond1, self.bond4) + + def test_less_than(self): + """Test that we can perform less than comparison with Bond objects""" + self.assertFalse(self.bond1 < self.bond2) # Because the sorting keys should be identical + self.assertLess(self.bond2, self.bond3) + self.assertLess(self.bond3, self.bond4) + + def test_greater_than(self): + """Test that we can perform greater than comparison with Bond objects""" + self.assertFalse(self.bond2 > self.bond1) # Because the sorting keys should be identical + self.assertGreater(self.bond3, self.bond1) + self.assertGreater(self.bond4, self.bond1) + + def test_hash(self): + """Test behavior of Bond hashing using dictionaries and sets""" + # Test dictionary behavior + self.assertEqual(len(dict.fromkeys([self.bond1, self.bond2, self.bond3, self.bond4])), 4) + + # Test set behavior + self.assertEqual(len({self.bond1, self.bond2, self.bond3, self.bond4}), 4) + def testGetOrderStr(self): """ test the Bond.getOrderStr() method @@ -773,6 +837,34 @@ def setUp(self): self.mHBonds = Molecule().fromSMILES('C(NC=O)OO') + self.mol1 = Molecule(SMILES='C') + self.mol2 = Molecule(SMILES='C') + self.mol3 = Molecule(SMILES='CC') + + def test_equality(self): + """Test that we can perform equality comparison with Molecule objects""" + self.assertEqual(self.mol1, self.mol1) + self.assertEqual(self.mol1, self.mol2) + self.assertNotEqual(self.mol1, self.mol3) + + def test_less_than(self): + """Test that we can perform less than comparison with Molecule objects""" + self.assertFalse(self.mol1 < self.mol2) # Because the sorting keys should be identical + self.assertLess(self.mol1, self.mol3) + + def test_greater_than(self): + """Test that we can perform greater than comparison with Molecule objects""" + self.assertFalse(self.mol2 > self.mol1) # Because the sorting keys should be identical + self.assertGreater(self.mol3, self.mol1) + + def test_hash(self): + """Test behavior of Molecule hashing using dictionaries and sets""" + # Test dictionary behavior + self.assertEqual(len(dict.fromkeys([self.mol1, self.mol2, self.mol3])), 2) + + # Test set behavior + self.assertEqual(len({self.mol1, self.mol2, self.mol3}), 2) + def testClearLabeledAtoms(self): """ Test the Molecule.clearLabeledAtoms() method. diff --git a/rmgpy/species.py b/rmgpy/species.py index 6760dc1a24..c788ff6cbc 100644 --- a/rmgpy/species.py +++ b/rmgpy/species.py @@ -184,6 +184,36 @@ def __reduce__(self): return (Species, (self.index, self.label, self.thermo, self.conformer, self.molecule, self.transportData, self.molecularWeight, self.energyTransferModel, self.reactive, self.props)) + def __hash__(self): + """ + Define a custom hash method to allow Species objects to be used in dictionaries and sets. + + Use the fingerprint property, which is taken from the first molecule entry. + This is currently defined as the molecular formula, which is not an ideal hash, since there will be significant + hash collisions, leading to inefficient lookups. + """ + return hash(('Species', self.fingerprint)) + + def __eq__(self, other): + """Define equality comparison. Define as a reference comparison""" + return self is other + + def __lt__(self, other): + """Define less than comparison. For comparing against other Species objects (e.g. when sorting).""" + if isinstance(other, Species): + return self.sorting_key < other.sorting_key + else: + raise NotImplementedError('Cannot perform less than comparison between Species and ' + '{0}.'.format(type(other).__name__)) + + def __gt__(self, other): + """Define greater than comparison. For comparing against other Species objects (e.g. when sorting).""" + if isinstance(other, Species): + return self.sorting_key > other.sorting_key + else: + raise NotImplementedError('Cannot perform greater than comparison between Species and ' + '{0}.'.format(type(other).__name__)) + @property def sorting_key(self): """Returns a sorting key for comparing Species objects. Read-only""" diff --git a/rmgpy/speciesTest.py b/rmgpy/speciesTest.py index 1356ce5746..df6ef54e54 100644 --- a/rmgpy/speciesTest.py +++ b/rmgpy/speciesTest.py @@ -154,6 +154,56 @@ def testOutput(self): self.assertEqual(self.species.molecularWeight.units, species.molecularWeight.units) self.assertEqual(self.species.reactive, species.reactive) + def test_equality(self): + """Test that we can perform equality comparison with Species objects""" + spc1 = Species(SMILES='C') + spc2 = Species(SMILES='C') + + self.assertNotEqual(spc1, spc2) + self.assertEqual(spc1, spc1) + self.assertEqual(spc2, spc2) + + def test_less_than(self): + """Test that we can perform less than comparison with Species objects""" + spc1 = Species(index=1, label='a', SMILES='C') + spc2 = Species(index=2, label='a', SMILES='C') + spc3 = Species(index=2, label='b', SMILES='C') + spc4 = Species(index=1, label='a', SMILES='CC') + + self.assertLess(spc1, spc2) + self.assertLess(spc1, spc3) + self.assertLess(spc2, spc3) + self.assertLess(spc1, spc4) + self.assertLess(spc2, spc4) + self.assertLess(spc3, spc4) + + def test_greater_than(self): + """Test that we can perform greater than comparison with Species objects""" + spc1 = Species(index=1, label='a', SMILES='C') + spc2 = Species(index=2, label='a', SMILES='C') + spc3 = Species(index=2, label='b', SMILES='C') + spc4 = Species(index=1, label='a', SMILES='CC') + + self.assertGreater(spc2, spc1) + self.assertGreater(spc3, spc1) + self.assertGreater(spc3, spc2) + self.assertGreater(spc4, spc1) + self.assertGreater(spc4, spc2) + self.assertGreater(spc4, spc3) + + def test_hash(self): + """Test behavior of Species hashing using dictionaries and sets""" + spc1 = Species(index=1, label='a', SMILES='C') + spc2 = Species(index=2, label='a', SMILES='C') + spc3 = Species(index=2, label='b', SMILES='C') + spc4 = Species(index=1, label='a', SMILES='CC') + + # Test dictionary behavior + self.assertEqual(len(dict.fromkeys([spc1, spc2, spc3, spc4])), 4) + + # Test set behavior + self.assertEqual(len({spc1, spc2, spc3, spc4}), 4) + def testToAdjacencyList(self): """ Test that toAdjacencyList() works as expected.