Skip to content

Commit

Permalink
Add/update comparison methods for Atom, Bond, Molecule, Species
Browse files Browse the repository at this point in the history
Python 3 no longer allows implicit sorting of custom classes.

This implements hash, eq, lt, gt special methods for these classes.

In general, the goal was to maintain prior behavior for hash and eq
and implement reasonable behavior for gt and lt.

Sorting of these classes is generally done to enable deterministic behavior,
rather than any quantitative reason, so that was the main consideration.
  • Loading branch information
mliu49 committed Sep 10, 2019
1 parent c90c407 commit 0c0c7ed
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 25 deletions.
105 changes: 80 additions & 25 deletions rmgpy/molecule/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,32 @@ def __setstate__(self, d):
self.atomType = atomTypes[d['atomType']] if d['atomType'] else None
self.lonePairs = d['lonePairs']

def __hash__(self):
"""
Define a custom hash method to allow Atom objects to be used in dictionaries and sets.
"""
return hash(('Atom', self.symbol))

def __eq__(self, other):
"""Method to test equality of two Atom objects."""
return self is other

def __lt__(self, other):
"""Define less than comparison. For comparing against other Atom objects (e.g. when sorting)."""
if isinstance(other, Atom):
return self.sorting_key < other.sorting_key
else:
raise NotImplementedError('Cannot perform less than comparison between Atom and '
'{0}.'.format(type(other).__name__))

def __gt__(self, other):
"""Define greater than comparison. For comparing against other Atom objects (e.g. when sorting)."""
if isinstance(other, Atom):
return self.sorting_key > other.sorting_key
else:
raise NotImplementedError('Cannot perform greater than comparison between Atom and '
'{0}.'.format(type(other).__name__))

@property
def mass(self):
return self.element.mass
Expand Down Expand Up @@ -235,14 +261,6 @@ def equivalent(self, other, strict=True):
return False
return True

def get_descriptor(self):
"""
Return a tuple used for sorting atoms.
Currently uses atomic number, connectivity value,
radical electrons, lone pairs, and charge
"""
return self.number, -getVertexConnectivityValue(self), self.radicalElectrons, self.lonePairs, self.charge

def isSpecificCaseOf(self, other):
"""
Return ``True`` if `self` is a specific case of `other`, or ``False``
Expand Down Expand Up @@ -551,6 +569,34 @@ def __reduce__(self):
"""
return (Bond, (self.vertex1, self.vertex2, self.order))

def __hash__(self):
"""
Define a custom hash method to allow Bond objects to be used in dictionaries and sets.
"""
return hash(('Bond', self.order,
self.atom1.symbol if self.atom1 is not None else '',
self.atom2.symbol if self.atom2 is not None else ''))

def __eq__(self, other):
"""Method to test equality of two Bond objects."""
return self is other

def __lt__(self, other):
"""Define less than comparison. For comparing against other Bond objects (e.g. when sorting)."""
if isinstance(other, Bond):
return self.sorting_key < other.sorting_key
else:
raise NotImplementedError('Cannot perform less than comparison between Bond and '
'{0}.'.format(type(other).__name__))

def __gt__(self, other):
"""Define greater than comparison. For comparing against other Bond objects (e.g. when sorting)."""
if isinstance(other, Bond):
return self.sorting_key > other.sorting_key
else:
raise NotImplementedError('Cannot perform greater than comparison between Bond and '
'{0}.'.format(type(other).__name__))

@property
def atom1(self):
return self.vertex1
Expand Down Expand Up @@ -862,26 +908,35 @@ def __deepcopy__(self, memo):
return self.copy(deep=True)

def __hash__(self):
return hash((self.fingerprint))
"""
Define a custom hash method to allow Molecule objects to be used in dictionaries and sets.
def __richcmp__(x, y, op):
if op == 2: # Py_EQ
return x.is_equal(y)
if op == 3: # Py_NE
return not x.is_equal(y)
else:
raise NotImplementedError("Can only check equality of molecules, not > or <")
Use the fingerprint property, which is currently defined as the molecular formula, though
this is not an ideal hash since there will be significant hash collision, leading to inefficient lookups.
"""
return hash(('Molecule', self.fingerprint))

def is_equal(self, other):
def __eq__(self, other):
"""Method to test equality of two Molecule objects."""
if not isinstance(other, Molecule):
return False # different type
elif self is other:
return True # same reference in memory
elif self.fingerprint != other.fingerprint:
return False
return self is other or (isinstance(other, Molecule) and
self.fingerprint == other.fingerprint and
self.isIsomorphic(other))

def __lt__(self, other):
"""Define less than comparison. For comparing against other Molecule objects (e.g. when sorting)."""
if isinstance(other, Molecule):
return self.sorting_key < other.sorting_key
else:
raise NotImplementedError('Cannot perform less than comparison between Molecule and '
'{0}.'.format(type(other).__name__))

def __gt__(self, other):
"""Define greater than comparison. For comparing against other Molecule objects (e.g. when sorting)."""
if isinstance(other, Molecule):
return self.sorting_key > other.sorting_key
else:
return self.isIsomorphic(other)
raise NotImplementedError('Cannot perform greater than comparison between Molecule and '
'{0}.'.format(type(other).__name__))

def __str__(self):
"""
Expand Down Expand Up @@ -1063,7 +1118,7 @@ def sortAtoms(self):
if vertex.sortingLabel < 0:
self.updateConnectivityValues()
break
self.atoms.sort(key=lambda a: a.get_descriptor(), reverse=True)
self.atoms.sort(reverse=True)
for index, vertex in enumerate(self.vertices):
vertex.sortingLabel = index

Expand Down
92 changes: 92 additions & 0 deletions rmgpy/molecule/moleculeTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def setUp(self):
"""
self.atom = Atom(element=getElement('C'), radicalElectrons=1, charge=0, label='*1', lonePairs=0)

self.atom1 = Atom(element=getElement('C'), radicalElectrons=0, lonePairs=0)
self.atom2 = Atom(element=getElement('C'), radicalElectrons=0, lonePairs=0)
self.atom3 = Atom(element=getElement('C'), radicalElectrons=1, lonePairs=0)
self.atom4 = Atom(element=getElement('H'), radicalElectrons=1, lonePairs=0)

def testMass(self):
"""
Test the Atom.mass property.
Expand All @@ -67,6 +72,33 @@ def testSymbol(self):
"""
self.assertTrue(self.atom.symbol == self.atom.element.symbol)

def test_equality(self):
"""Test that we can perform equality comparison with Atom objects"""
self.assertEqual(self.atom1, self.atom1)
self.assertNotEqual(self.atom1, self.atom2)
self.assertNotEqual(self.atom1, self.atom3)
self.assertNotEqual(self.atom1, self.atom4)

def test_less_than(self):
"""Test that we can perform less than comparison with Atom objects"""
self.assertFalse(self.atom1 < self.atom2) # Because the sorting keys should be identical
self.assertLess(self.atom2, self.atom3)
self.assertLess(self.atom4, self.atom1)

def test_greater_than(self):
"""Test that we can perform greater than comparison with Atom objects"""
self.assertFalse(self.atom2 > self.atom1) # Because the sorting keys should be identical
self.assertGreater(self.atom3, self.atom1)
self.assertGreater(self.atom1, self.atom4)

def test_hash(self):
"""Test behavior of Atom hashing using dictionaries and sets"""
# Test dictionary behavior
self.assertEqual(len(dict.fromkeys([self.atom1, self.atom2, self.atom3, self.atom4])), 4)

# Test set behavior
self.assertEqual(len({self.atom1, self.atom2, self.atom3, self.atom4}), 4)

def testIsHydrogen(self):
"""
Test the Atom.isHydrogen() method.
Expand Down Expand Up @@ -390,6 +422,38 @@ def setUp(self):
self.bond = Bond(atom1=None, atom2=None, order=2)
self.orderList = [1, 2, 3, 4, 1.5, 0.30000000000000004]

self.bond1 = Bond(atom1=None, atom2=None, order=1)
self.bond2 = Bond(atom1=None, atom2=None, order=1)
self.bond3 = Bond(atom1=None, atom2=None, order=2)
self.bond4 = Bond(atom1=None, atom2=None, order=3)

def test_equality(self):
"""Test that we can perform equality comparison with Bond objects"""
self.assertEqual(self.bond1, self.bond1)
self.assertNotEqual(self.bond1, self.bond2)
self.assertNotEqual(self.bond1, self.bond3)
self.assertNotEqual(self.bond1, self.bond4)

def test_less_than(self):
"""Test that we can perform less than comparison with Bond objects"""
self.assertFalse(self.bond1 < self.bond2) # Because the sorting keys should be identical
self.assertLess(self.bond2, self.bond3)
self.assertLess(self.bond3, self.bond4)

def test_greater_than(self):
"""Test that we can perform greater than comparison with Bond objects"""
self.assertFalse(self.bond2 > self.bond1) # Because the sorting keys should be identical
self.assertGreater(self.bond3, self.bond1)
self.assertGreater(self.bond4, self.bond1)

def test_hash(self):
"""Test behavior of Bond hashing using dictionaries and sets"""
# Test dictionary behavior
self.assertEqual(len(dict.fromkeys([self.bond1, self.bond2, self.bond3, self.bond4])), 4)

# Test set behavior
self.assertEqual(len({self.bond1, self.bond2, self.bond3, self.bond4}), 4)

def testGetOrderStr(self):
"""
test the Bond.getOrderStr() method
Expand Down Expand Up @@ -773,6 +837,34 @@ def setUp(self):

self.mHBonds = Molecule().fromSMILES('C(NC=O)OO')

self.mol1 = Molecule(SMILES='C')
self.mol2 = Molecule(SMILES='C')
self.mol3 = Molecule(SMILES='CC')

def test_equality(self):
"""Test that we can perform equality comparison with Molecule objects"""
self.assertEqual(self.mol1, self.mol1)
self.assertEqual(self.mol1, self.mol2)
self.assertNotEqual(self.mol1, self.mol3)

def test_less_than(self):
"""Test that we can perform less than comparison with Molecule objects"""
self.assertFalse(self.mol1 < self.mol2) # Because the sorting keys should be identical
self.assertLess(self.mol1, self.mol3)

def test_greater_than(self):
"""Test that we can perform greater than comparison with Molecule objects"""
self.assertFalse(self.mol2 > self.mol1) # Because the sorting keys should be identical
self.assertGreater(self.mol3, self.mol1)

def test_hash(self):
"""Test behavior of Molecule hashing using dictionaries and sets"""
# Test dictionary behavior
self.assertEqual(len(dict.fromkeys([self.mol1, self.mol2, self.mol3])), 2)

# Test set behavior
self.assertEqual(len({self.mol1, self.mol2, self.mol3}), 2)

def testClearLabeledAtoms(self):
"""
Test the Molecule.clearLabeledAtoms() method.
Expand Down
30 changes: 30 additions & 0 deletions rmgpy/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,36 @@ def __reduce__(self):
return (Species, (self.index, self.label, self.thermo, self.conformer, self.molecule, self.transportData,
self.molecularWeight, self.energyTransferModel, self.reactive, self.props))

def __hash__(self):
"""
Define a custom hash method to allow Species objects to be used in dictionaries and sets.
Use the fingerprint property, which is taken from the first molecule entry.
This is currently defined as the molecular formula, which is not an ideal hash, since there will be significant
hash collisions, leading to inefficient lookups.
"""
return hash(('Species', self.fingerprint))

def __eq__(self, other):
"""Define equality comparison. Define as a reference comparison"""
return self is other

def __lt__(self, other):
"""Define less than comparison. For comparing against other Species objects (e.g. when sorting)."""
if isinstance(other, Species):
return self.sorting_key < other.sorting_key
else:
raise NotImplementedError('Cannot perform less than comparison between Species and '
'{0}.'.format(type(other).__name__))

def __gt__(self, other):
"""Define greater than comparison. For comparing against other Species objects (e.g. when sorting)."""
if isinstance(other, Species):
return self.sorting_key > other.sorting_key
else:
raise NotImplementedError('Cannot perform greater than comparison between Species and '
'{0}.'.format(type(other).__name__))

@property
def sorting_key(self):
"""Returns a sorting key for comparing Species objects. Read-only"""
Expand Down
50 changes: 50 additions & 0 deletions rmgpy/speciesTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,56 @@ def testOutput(self):
self.assertEqual(self.species.molecularWeight.units, species.molecularWeight.units)
self.assertEqual(self.species.reactive, species.reactive)

def test_equality(self):
"""Test that we can perform equality comparison with Species objects"""
spc1 = Species(SMILES='C')
spc2 = Species(SMILES='C')

self.assertNotEqual(spc1, spc2)
self.assertEqual(spc1, spc1)
self.assertEqual(spc2, spc2)

def test_less_than(self):
"""Test that we can perform less than comparison with Species objects"""
spc1 = Species(index=1, label='a', SMILES='C')
spc2 = Species(index=2, label='a', SMILES='C')
spc3 = Species(index=2, label='b', SMILES='C')
spc4 = Species(index=1, label='a', SMILES='CC')

self.assertLess(spc1, spc2)
self.assertLess(spc1, spc3)
self.assertLess(spc2, spc3)
self.assertLess(spc1, spc4)
self.assertLess(spc2, spc4)
self.assertLess(spc3, spc4)

def test_greater_than(self):
"""Test that we can perform greater than comparison with Species objects"""
spc1 = Species(index=1, label='a', SMILES='C')
spc2 = Species(index=2, label='a', SMILES='C')
spc3 = Species(index=2, label='b', SMILES='C')
spc4 = Species(index=1, label='a', SMILES='CC')

self.assertGreater(spc2, spc1)
self.assertGreater(spc3, spc1)
self.assertGreater(spc3, spc2)
self.assertGreater(spc4, spc1)
self.assertGreater(spc4, spc2)
self.assertGreater(spc4, spc3)

def test_hash(self):
"""Test behavior of Species hashing using dictionaries and sets"""
spc1 = Species(index=1, label='a', SMILES='C')
spc2 = Species(index=2, label='a', SMILES='C')
spc3 = Species(index=2, label='b', SMILES='C')
spc4 = Species(index=1, label='a', SMILES='CC')

# Test dictionary behavior
self.assertEqual(len(dict.fromkeys([spc1, spc2, spc3, spc4])), 4)

# Test set behavior
self.assertEqual(len({spc1, spc2, spc3, spc4}), 4)

def testToAdjacencyList(self):
"""
Test that toAdjacencyList() works as expected.
Expand Down

0 comments on commit 0c0c7ed

Please sign in to comment.