Skip to content

Commit

Permalink
Refine AC2BO in xyz2mol
Browse files Browse the repository at this point in the history
  • Loading branch information
choglass committed Nov 14, 2024
1 parent 0b1763d commit 23f4b09
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 78 deletions.
58 changes: 38 additions & 20 deletions cell2mol/charge_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,37 +623,55 @@ def get_charge_manual(spec, debug: int=0):
smiles = "[O-]Cl(=O)(=O)=O"
charge = -1
order = [0, 1, 2, 3, 4] # Cl index is 1
for idx, a in enumerate(spec.atoms):
if a.label == "Cl":
new_order = order # Default to initial order in case no move is needed

# Find the index of the "Cl" atom and update `new_order`
for idx, atom in enumerate(spec.atoms):
if atom.label == "Cl":
new_order = move_element(order, 1, idx)
if debug >= 2: print(f" O4-Cl: {new_order=}")
break # Stop once we find and move Cl

# Debug output if needed
if debug >= 2:
print(f"O4-Cl: {new_order=}")

elif spec.formula == "N3":
smiles = "[N-]=[N+]=[N-]"
smiles = "[N-]=[N+]=[N-]"
charge = -1
order = [0, 1, 2]
for idx, a in enumerate(spec.atoms):
list_of_adj_atoms = []
for adj in a.adjacency:
if debug >= 2: print(f" N3: {adj=}", spec.get_parent("molecule").labels[adj])
list_of_adj_atoms.append(spec.get_parent("molecule").labels[adj])
numN = list_of_adj_atoms.count("N")
if numN == 2:
new_order = order # Default to initial order if no modification is needed

# Iterate over each atom to check adjacency
for idx, atom in enumerate(spec.atoms):
# Find adjacent atoms and check if there are exactly two "N" atoms
adjacent_labels = [spec.get_parent("molecule").labels[adj] for adj in atom.adjacency]
if debug >= 2: print(f"N3: atom index {idx}, adjacent_labels={adjacent_labels}")

# Check if the atom has exactly 2 nitrogen neighbors
if adjacent_labels.count("N") == 2:
new_order = move_element(order, 1, idx)
if debug >= 2: print(f" N3: {new_order=}")
break # Found the target atom, no need to check further

if debug >= 2: print(f"N3: new_order={new_order}")

elif spec.formula == "I3":
smiles = "I[I-]I"
charge = -1
order = [0, 1, 2]
for idx, a in enumerate(spec.atoms):
list_of_adj_atoms = []
for adj in a.adjacency:
if debug >= 2: print(f" I3: {adj=}", spec.get_parent("molecule").labels[adj])
list_of_adj_atoms.append(spec.get_parent("molecule").labels[adj])
numI = list_of_adj_atoms.count("I")
if numI == 2:
new_order = order # Default to initial order if no modification is needed

# Iterate over each atom to check adjacency
for idx, atom in enumerate(spec.atoms):
# Find adjacent atoms and check if there are exactly two "N" atoms
adjacent_labels = [spec.get_parent("molecule").labels[adj] for adj in atom.adjacency]
if debug >= 2: print(f"I3: atom index {idx}, adjacent_labels={adjacent_labels}")

# Check if the atom has exactly 2 nitrogen neighbors
if adjacent_labels.count("I") == 2:
new_order = move_element(order, 1, idx)
if debug >= 2: print(f" I3: {new_order=}")
break # Found the target atom, no need to check further

if debug >= 2: print(f"I3: new_order={new_order}")

temp_mol = Chem.MolFromSmiles(smiles, sanitize=False)
mol = Chem.RenumberAtoms(temp_mol, new_order)
Expand Down
37 changes: 11 additions & 26 deletions cell2mol/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ def get_coord_sphere_formula(self):
return self.coord_sphere_formula

#######################################################
def get_connected_groups(self, debug: int=0):
def get_connected_groups(self, debug: int=2):
from cell2mol.connectivity import split_group
# metal.groups will be used for the calculation of the relative metal radius
# and define the coordination geometry of the metal /hapicitiy/ hapttype
Expand All @@ -1130,6 +1130,7 @@ def get_connected_groups(self, debug: int=0):
for lig in mol.ligands:
for group in lig.groups:
if debug > 1: print(group.formula)
ligand_indices = [ a.get_parent_index("ligand") for a in group.atoms ]
tmplabels = []
tmpcoord = []
tmplabels.append(self.label)
Expand All @@ -1144,9 +1145,12 @@ def get_connected_groups(self, debug: int=0):
if all(tmpadjnum[1:]):
self.groups.append(group)
elif any(tmpadjnum[1:]):

if debug > 1: print(f"Metal {self.label} is connected to {group.formula} but not all atoms are connected")
conn_idx = [ idx for idx, num in enumerate(tmpadjnum[1:]) if num == 1 ]
splitted_groups = split_group(group, conn_idx, debug=debug)
conn_ligand_indices = [ ligand_indices[idx] for idx, num in enumerate(tmpadjnum[1:]) if num == 1 ]
print(f"get_connected_groups {tmpadjnum[1:]=} {conn_idx=} {conn_ligand_indices=} {ligand_indices=}")
splitted_groups = split_group(group, conn_idx, conn_ligand_indices, debug=debug)
for g in splitted_groups:
self.groups.append(g)
if debug > 1: print(f"Metal {self.label} is connected to {g.formula}")
Expand Down Expand Up @@ -1358,7 +1362,10 @@ def get_reference_molecules(self, ref_labels: list, ref_fracs: list, cov_factor:
for atom, idx in zip(newmolec.atoms, b):
atom.add_parent(refcell, index=idx)
# This must be below the frac_coord, so they are carried on to the ligands
if newmolec.iscomplex: newmolec.split_complex()
if newmolec.iscomplex:
newmolec.split_complex()
else:
newmolec.add_parent(newmolec, indices=[*range(0,newmolec.natoms,1)])
self.refmoleclist.append(newmolec)

if debug >= 0: print(f"GETREFS: found {len(self.refmoleclist)} reference molecules")
Expand Down Expand Up @@ -1731,28 +1738,6 @@ def assign_charges (self, debug: int=0):
prepare_mol(mol)

#######################################################
def assign_final_charge_to_unique_species(self, final_charges, debug: int=0):
for specie, final_charge in zip(self.unique_species, final_charges):
print(specie.unique_index, specie.formula)
if (specie.subtype == "molecule" and specie.iscomplex == False) or (specie.subtype == "ligand"):
charge_list = [cs.corr_total_charge for cs in specie.possible_cs]
idx = charge_list.index(final_charge)
cs = specie.possible_cs[idx]
specie.charge_state = cs
# print(specie.charge_state.protonation)
specie.set_charges(cs.corr_total_charge, cs.corr_atom_charges, cs.smiles, cs.rdkit_obj)
elif specie.subtype == "metal" :
charge_list = specie.possible_cs
idx = charge_list.index(final_charge)
cs = specie.possible_cs[idx]
specie.set_charge(cs)
for specie in self.unique_species:
print("Unique Species final charges")
if (specie.subtype == "molecule" and specie.iscomplex == False) or (specie.subtype == "ligand"):
print(specie.formula, specie.totcharge)
elif specie.subtype == "metal" :
print(specie.formula, specie.charge)
#######################################################
def assign_charges_for_refcell(self, debug: int=0):
for idx, ref in enumerate(self.refmoleclist):
print(f"Refenrence Molecule {idx}: {ref.formula}")
Expand Down Expand Up @@ -1812,7 +1797,7 @@ def assign_charges_old (self, debug: int=0) -> object:
# The whole process is done by 4 functions, which are run at the specie class level:
# 1) spec.get_protonation_states(), which determines which atoms of the specie must have added elements (see above) to have a meaningful Lewis structure
# 2) spec.get_possible_cs(), which retrieves the possible charge states associated with the specie
# 3) spec.get_charge(), which generates one connectivity for a set of charges
# 3) spec.get_possible_charge_state(), which generates one connectivity for a set of charges
# 4) cell.select_charge_distr() chooses the best connectivity among the generated ones.

# Basically, this function connects these other three functions,
Expand Down
31 changes: 21 additions & 10 deletions cell2mol/new_c2m_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ def cell2mol(newcell: object, refcell: object, sym_ops, reconstruction: bool=Tru
# Cell Reconstruction
all_molecules, reconstructed_molecules = reconstuct(refcell, newcell, sym_ops, debug=debug)
all_molecules.extend(reconstructed_molecules)

if newcell.error_get_fragments: return newcell
elif newcell.error_reconstruction: return newcell
tend = time.time()
if newcell.error_get_fragments:
if debug >= 1: print(f"\nCell Reconstruction Failed. Total execution time: {tend - tini:.2f} seconds")
return newcell
elif newcell.error_reconstruction:
if debug >= 1: print(f"\nCell Reconstruction Failed. Total execution time: {tend - tini:.2f} seconds")
return newcell
else:
tend = time.time()
if debug >= 1: print(f"\nCell Reconstruction Finished Normally. Total execution time: {tend - tini:.2f} seconds")

# Get moleclist for the unit cell
Expand Down Expand Up @@ -62,18 +65,26 @@ def cell2mol(newcell: object, refcell: object, sym_ops, reconstruction: bool=Tru

# Assign charge for the unit cell and check charge neutrality
newcell.assign_charges(debug=debug)

if newcell.error_get_poscharges : return newcell
elif newcell.error_multiple_distrib : return newcell
elif newcell.error_empty_distrib : return newcell
tend = time.time()
if newcell.error_get_poscharges :
if debug >= 1: print(f"Charge Assignment Failed. Total execution time: {tend - tini:.2f} seconds")
return newcell
elif newcell.error_multiple_distrib :
if debug >= 1: print(f"Charge Assignment Failed. Total execution time: {tend - tini:.2f} seconds")
return newcell
elif newcell.error_empty_distrib :
if debug >= 1: print(f"Charge Assignment Failed. Total execution time: {tend - tini:.2f} seconds")
return newcell
else :
tend = time.time()

if debug >= 1: print(f"Charge Assignment Finished Normally. Total execution time: {tend - tini:.2f} seconds")

newcell.check_charge_neutrality(debug=debug)
newcell.create_bonds(debug=debug)

if newcell.error_create_bonds: return newcell
if newcell.error_create_bonds:
if debug >= 1: print(f"Creating bonds Failed")
return newcell
else:
if debug >= 1: print("Creating bonds Finished Normally")
else:
Expand Down
5 changes: 4 additions & 1 deletion cell2mol/new_cell_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,10 @@ def get_moleclist (newcell, refcell, all_molecules, debug: int=0):
atom.add_parent(newcell, index=idx)
for atom, idx in zip(newmolec.atoms, mol.ref_indices):
atom.add_parent(refcell, index=idx)
if newmolec.iscomplex: newmolec.split_complex()
if newmolec.iscomplex:
newmolec.split_complex()
else:
newmolec.add_parent(newmolec, indices=[*range(0,newmolec.natoms,1)])
newcell.moleclist.append(newmolec)

for mol in newcell.moleclist:
Expand Down
7 changes: 5 additions & 2 deletions cell2mol/new_charge_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ def set_charge_state(reference, target, mode, debug: int=0):
if target.formula in ["O4-Cl", "N3", "I3"]:
cs = get_charge_manual(target, debug=debug)
else :
#if not hasattr(target, "possible_cs"): target.get_possible_cs(debug=debug)
target.get_possible_cs(debug=debug)
if not hasattr(target, "possible_cs"):
target.get_possible_cs(debug=debug)
else:
if debug >= 1: print("SET_CHARGE_STATE: possible_cs of reference already exists")
print(f"{target.formula=} {target.possible_cs=}")
charge_list = [cs.corr_total_charge for cs in target.possible_cs]
print(charge_list, final_charge, target.possible_cs)
idx = charge_list.index(final_charge)
Expand Down
18 changes: 12 additions & 6 deletions cell2mol/refcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cell2mol.cell_operations import frac2cart_fromparam
from cell2mol.new_cell_reconstruction import modify_cov_factor_due_to_H, modify_cov_factor_due_to_possible_charges
from cell2mol.other import handle_error
import time

# Constants
VERSION = "2.0"
Expand Down Expand Up @@ -51,28 +52,33 @@ def process_refcell(input_path, name, current_dir, debug=0):

def create_reference (input_path, name, cell_vector, cell_param, debug):
"""Create the reference cell object."""

tini = time.time()
ref_labels, ref_fracs = get_wyckoff_positions(input_path)
ref_pos = frac2cart_fromparam(ref_fracs, cell_param)

refcell = cell(name, ref_labels, ref_pos, ref_fracs, cell_vector, cell_param)
refcell.get_subtype("reference")
refcell.get_reference_molecules(ref_labels, ref_fracs, cov_factor=COV_FACTOR, debug=debug)
refcell = modify_cov_factor_due_to_H(refcell, debug=debug)

#refcell = modify_cov_factor_due_to_H(refcell, debug=debug)
if not refcell.has_isolated_H:
refcell.check_missing_H(debug=debug)
refcell.assess_errors(mode="hydrogens")
tend = time.time()
if debug >= 1: print(f"\nReference molecules are generated. Total execution time: {tend - tini:.2f} seconds")
return refcell

def get_unique_species_in_reference (refcell, debug):
"""Processes the reference cell to obtain unique species and handle any errors."""

tini = time.time()
refcell.get_unique_species(debug=debug)
if debug >= 1:
print(f"Unique species: {[specie.formula for specie in refcell.unique_species]}")
print(f"Species list: {[specie.formula for specie in refcell.species_list]}\n")

refcell = modify_cov_factor_due_to_possible_charges(refcell, debug=debug)
#refcell = modify_cov_factor_due_to_possible_charges(refcell, debug=debug)
refcell.get_selected_cs(debug=debug)
refcell.assess_errors(mode="possible_charges")
tend = time.time()
if debug >= 1: print(f"\nAssign possible charges of Reference molecules. Total execution time: {tend - tini:.2f} seconds")

# Run the main function
if __name__ == "__main__":
Expand Down
35 changes: 22 additions & 13 deletions cell2mol/xyz2mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from rdkit.Chem import AllChem, rdmolops

from cell2mol.elementdata import ElementData

from cell2mol.connectivity import labels2formula
elemdatabase = ElementData()

###############################
Expand Down Expand Up @@ -489,7 +489,8 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
# make a list of valences, e.g. for CO: [[4],[2,1]]
valences_list_of_lists = []
AC_valence = list(AC.sum(axis=1))
#print(f"{AC_valence=}")
print(f"{AC_valence=}")
formula = labels2formula([elemdatabase.elementsym[atom] for atom in atoms])
wrong = 0

for i, (atomicNum, valence) in enumerate(zip(atoms, AC_valence)):
Expand All @@ -505,7 +506,8 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
possible_valence = [x for x in atomic_valence[atomicNum] if x >= valence]
if atomicNum == 7:
#print("Possible valences for:", atomicNum,"are",possible_valence, valence)
possible_valence.append(valence)
if valence not in possible_valence:
possible_valence.append(valence)
# if atomicNum == 15:
# print("Possible valences for:", atomicNum,"are",possible_valence, valence)
if len(possible_valence) == 0:
Expand All @@ -524,7 +526,7 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
# print(f"AC2BO: {wrong=}")
return None, atomic_valence_electrons

#print(f"\tAC2BO: {valences_list_of_lists=}")
print(f"\tAC2BO: {valences_list_of_lists=}")

# convert [[4],[2,1]] to [[4,2],[4,1]]
valences_list = []
Expand All @@ -535,17 +537,17 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
valences_list.append(tmp)

best_BO = AC.copy()
# print("Final valences list:", list(valences_list), len(list(valences_list)))
BO_is_OK_list = []
# print(f"AC2BO: {valences_list=}")
print(f"AC2BO: {formula=} {len(valences_list)=}")
count = 0
for valences in valences_list:

#print(f"\tSending", valences, AC_valence, "to get_UA")
print(f"\tSending", valences, AC_valence, "to get_UA")
UA, DU_from_AC = get_UA(valences, AC_valence)

check_len = len(UA) == 0
#print (f"\tAC2BO: check_len", check_len)
#print(f"\tUA", UA)
print (f"\tAC2BO: check_len", check_len)
print(f"\tUA", UA)
if check_len:
check_bo = BO_is_OK(
AC,
Expand All @@ -561,9 +563,9 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
check_bo = None

if check_len and check_bo:
#print(f"\tAC2BO: return AC", check_len, check_bo)
print(f"\tAC2BO: {formula=} return AC", check_len, check_bo, f"{charge=} {count=}")
return AC, atomic_valence_electrons

UA_pairs_list = get_UA_pairs(UA, AC, use_graph=use_graph)
for UA_pairs in UA_pairs_list:
BO = get_BO(AC, UA, DU_from_AC, valences, UA_pairs, use_graph=use_graph)
Expand All @@ -589,15 +591,21 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
)

if status:
#print(f"\tAC2BO: status", status)
print(f"\tAC2BO: {formula=} status", status, f"{charge=} {count=}")
return BO, atomic_valence_electrons
elif (
BO.sum() >= best_BO.sum()
and valences_not_too_large(BO, valences)
and charge_OK
):
# print(f"\tAC2BO: status", status, "BO.sum()", BO.sum(), "best_BO.sum()", best_BO.sum())
print(f"\tAC2BO: status", status, "BO.sum()", BO.sum(), "best_BO.sum()", best_BO.sum())
best_BO = BO.copy()

count += 1
if count > 1000:
print(f"Failing AC2BO: {formula=} {charge=} {count=}")
return best_BO, atomic_valence_electrons

# if status:
# return BO, atomic_valence_electrons
# if status:
Expand All @@ -611,6 +619,7 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True):
# print("best bo", best_BO)
#print(f"\tAC2BO: return best bo")
#print("AC2BO: return best bo", best_BO)
print(f"Failing AC2BO: {formula=} {charge=} {count=}")
return best_BO, atomic_valence_electrons


Expand Down

0 comments on commit 23f4b09

Please sign in to comment.