Skip to content

Commit

Permalink
enable digress in workflow interface api
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <kevin.ta@intel.com>
  • Loading branch information
kta-intel committed Sep 20, 2024
1 parent 3881a48 commit 0993e21
Show file tree
Hide file tree
Showing 29 changed files with 5,784 additions and 0 deletions.
554 changes: 554 additions & 0 deletions openfl-tutorials/experimental/DiGress/Workflow_Interface_DiGress.ipynb

Large diffs are not rendered by default.

Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
# Copyright (c) 2012-2022 Clement Vignac, Igor Krawczuk, Antoine Siraudin
# source: https://github.com/cvignac/DiGress/

import numpy as np
import torch
import re
# import wandb
try:
from rdkit import Chem
print("Found rdkit, all good")
except ModuleNotFoundError as e:
use_rdkit = False
from warnings import warn
warn("Didn't find rdkit, this will fail")
assert use_rdkit, "Didn't find rdkit"

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3, 'Si': 4, 'P': [3, 5],
'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2], 'Bi': [3, 5], 'Se': [2, 4, 6]}
bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC]
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}


class BasicMolecularMetrics(object):
def __init__(self, dataset_info, train_smiles=None):
self.atom_decoder = dataset_info.atom_decoder
self.dataset_info = dataset_info

# Retrieve dataset smiles only for qm9 currently.
self.dataset_smiles_list = train_smiles

def compute_validity(self, generated):
""" generated: list of couples (positions, atom_types)"""
valid = []
num_components = []
all_smiles = []
for graph in generated:
atom_types, edge_types = graph
mol = build_molecule(atom_types, edge_types, self.dataset_info.atom_decoder)
smiles = mol2smiles(mol)
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
num_components.append(len(mol_frags))
except:
pass
if smiles is not None:
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
smiles = mol2smiles(largest_mol)
valid.append(smiles)
all_smiles.append(smiles)
except Chem.rdchem.AtomValenceException:
print("Valence error in GetmolFrags")
all_smiles.append(None)
except Chem.rdchem.KekulizeException:
print("Can't kekulize molecule")
all_smiles.append(None)
else:
all_smiles.append(None)

return valid, len(valid) / len(generated), np.array(num_components), all_smiles

def compute_uniqueness(self, valid):
""" valid: list of SMILES strings."""
return list(set(valid)), len(set(valid)) / len(valid)

def compute_novelty(self, unique):
num_novel = 0
novel = []
if self.dataset_smiles_list is None:
print("Dataset smiles is None, novelty computation skipped")
return 1, 1
for smiles in unique:
if smiles not in self.dataset_smiles_list:
novel.append(smiles)
num_novel += 1
return novel, num_novel / len(unique)

def compute_relaxed_validity(self, generated):
valid = []
for graph in generated:
atom_types, edge_types = graph
mol = build_molecule_with_partial_charges(atom_types, edge_types, self.dataset_info.atom_decoder)
smiles = mol2smiles(mol)
if smiles is not None:
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
smiles = mol2smiles(largest_mol)
valid.append(smiles)
except Chem.rdchem.AtomValenceException:
print("Valence error in GetmolFrags")
except Chem.rdchem.KekulizeException:
print("Can't kekulize molecule")
return valid, len(valid) / len(generated)

def evaluate(self, generated):
""" generated: list of pairs (positions: n x 3, atom_types: n [int])
the positions and atom types should already be masked. """
valid, validity, num_components, all_smiles = self.compute_validity(generated)
nc_mu = num_components.mean() if len(num_components) > 0 else 0
nc_min = num_components.min() if len(num_components) > 0 else 0
nc_max = num_components.max() if len(num_components) > 0 else 0
# print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}%")
# print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")

relaxed_valid, relaxed_validity = self.compute_relaxed_validity(generated)
# print(f"Relaxed validity over {len(generated)} molecules: {relaxed_validity * 100 :.2f}%")
if relaxed_validity > 0:
unique, uniqueness = self.compute_uniqueness(relaxed_valid)
# print(f"Uniqueness over {len(relaxed_valid)} valid molecules: {uniqueness * 100 :.2f}%")

if self.dataset_smiles_list is not None:
_, novelty = self.compute_novelty(unique)
# print(f"Novelty over {len(unique)} unique valid molecules: {novelty * 100 :.2f}%")
else:
novelty = -1.0
else:
novelty = -1.0
uniqueness = 0.0
unique = []
return ([validity, relaxed_validity, uniqueness, novelty], unique,
dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles)


def mol2smiles(mol):
try:
Chem.SanitizeMol(mol)
except ValueError:
return None
return Chem.MolToSmiles(mol)


def build_molecule(atom_types, edge_types, atom_decoder, verbose=False):
if verbose:
print("building new molecule")

mol = Chem.RWMol()
for atom in atom_types:
a = Chem.Atom(atom_decoder[atom.item()])
mol.AddAtom(a)
if verbose:
print("Atom added: ", atom.item(), atom_decoder[atom.item()])

edge_types = torch.triu(edge_types)
all_bonds = torch.nonzero(edge_types)
for i, bond in enumerate(all_bonds):
if bond[0].item() != bond[1].item():
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
if verbose:
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
bond_dict[edge_types[bond[0], bond[1]].item()] )
return mol


def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
if verbose:
print("\nbuilding new molecule")

mol = Chem.RWMol()
for atom in atom_types:
a = Chem.Atom(atom_decoder[atom.item()])
mol.AddAtom(a)
if verbose:
print("Atom added: ", atom.item(), atom_decoder[atom.item()])
edge_types = torch.triu(edge_types)
all_bonds = torch.nonzero(edge_types)

for i, bond in enumerate(all_bonds):
if bond[0].item() != bond[1].item():
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
if verbose:
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
bond_dict[edge_types[bond[0], bond[1]].item()])
# add formal charge to atom: e.g. [O+], [N+], [S+]
# not support [O-], [N-], [S-], [NH+] etc.
flag, atomid_valence = check_valency(mol)
if verbose:
print("flag, valence", flag, atomid_valence)
if flag:
continue
else:
assert len(atomid_valence) == 2
idx = atomid_valence[0]
v = atomid_valence[1]
an = mol.GetAtomWithIdx(idx).GetAtomicNum()
if verbose:
print("atomic num of atom with a large valence", an)
if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
mol.GetAtomWithIdx(idx).SetFormalCharge(1)
# print("Formal charge added")
return mol


# Functions from GDSS
def check_valency(mol):
try:
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
return True, None
except ValueError as e:
e = str(e)
p = e.find('#')
e_sub = e[p:]
atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
return False, atomid_valence


def correct_mol(m):
# xsm = Chem.MolToSmiles(x, isomericSmiles=True)
mol = m

#####
no_correct = False
flag, _ = check_valency(mol)
if flag:
no_correct = True

while True:
flag, atomid_valence = check_valency(mol)
if flag:
break
else:
assert len(atomid_valence) == 2
idx = atomid_valence[0]
v = atomid_valence[1]
queue = []
check_idx = 0
for b in mol.GetAtomWithIdx(idx).GetBonds():
type = int(b.GetBondType())
queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
if type == 12:
check_idx += 1
queue.sort(key=lambda tup: tup[1], reverse=True)

if queue[-1][1] == 12:
return None, no_correct
elif len(queue) > 0:
start = queue[check_idx][2]
end = queue[check_idx][3]
t = queue[check_idx][1] - 1
mol.RemoveBond(start, end)
if t >= 1:
mol.AddBond(start, end, bond_dict[t])
return mol, no_correct


def valid_mol_can_with_seg(m, largest_connected_comp=True):
if m is None:
return None
sm = Chem.MolToSmiles(m, isomericSmiles=True)
if largest_connected_comp and '.' in sm:
vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.')
vsm.sort(key=lambda tup: tup[1], reverse=True)
mol = Chem.MolFromSmiles(vsm[0][0])
else:
mol = Chem.MolFromSmiles(sm)
return mol


if __name__ == '__main__':
smiles_mol = 'C1CCC1'
print("Smiles mol %s" % smiles_mol)
chem_mol = Chem.MolFromSmiles(smiles_mol)
block_mol = Chem.MolToMolBlock(chem_mol)
print("Block mol:")
print(block_mol)

use_rdkit = True


def check_stability(atom_types, edge_types, dataset_info, debug=False,atom_decoder=None):
if atom_decoder is None:
atom_decoder = dataset_info.atom_decoder

n_bonds = np.zeros(len(atom_types), dtype='int')

for i in range(len(atom_types)):
for j in range(i + 1, len(atom_types)):
n_bonds[i] += abs((edge_types[i, j] + edge_types[j, i])/2)
n_bonds[j] += abs((edge_types[i, j] + edge_types[j, i])/2)
n_stable_bonds = 0
for atom_type, atom_n_bond in zip(atom_types, n_bonds):
possible_bonds = allowed_bonds[atom_decoder[atom_type]]
if type(possible_bonds) == int:
is_stable = possible_bonds == atom_n_bond
else:
is_stable = atom_n_bond in possible_bonds
if not is_stable and debug:
print("Invalid bonds for molecule %s with %d bonds" % (atom_decoder[atom_type], atom_n_bond))
n_stable_bonds += int(is_stable)

molecule_stable = n_stable_bonds == len(atom_types)
return molecule_stable, n_stable_bonds, len(atom_types)


def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
""" molecule_list: (dict) """

if not dataset_info.remove_h:
print(f'Analyzing molecule stability...')

molecule_stable = 0
nr_stable_bonds = 0
n_atoms = 0
n_molecules = len(molecule_list)

for i, mol in enumerate(molecule_list):
atom_types, edge_types = mol

validity_results = check_stability(atom_types, edge_types, dataset_info)

molecule_stable += int(validity_results[0])
nr_stable_bonds += int(validity_results[1])
n_atoms += int(validity_results[2])

# Validity
fraction_mol_stable = molecule_stable / float(n_molecules)
fraction_atm_stable = nr_stable_bonds / float(n_atoms)
validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
# if wandb.run:
# wandb.log(validity_dict)
else:
validity_dict = {'mol_stable': -1, 'atm_stable': -1}

metrics = BasicMolecularMetrics(dataset_info, train_smiles)
rdkit_metrics = metrics.evaluate(molecule_list)
all_smiles = rdkit_metrics[-1]
# if wandb.run:
# nc = rdkit_metrics[-2]
# dic = {'Validity': rdkit_metrics[0][0], 'Relaxed Validity': rdkit_metrics[0][1],
# 'Uniqueness': rdkit_metrics[0][2], 'Novelty': rdkit_metrics[0][3],
# 'nc_max': nc['nc_max'], 'nc_mu': nc['nc_mu']}
# wandb.log(dic)

return validity_dict, rdkit_metrics, all_smiles
Loading

0 comments on commit 0993e21

Please sign in to comment.