-
Notifications
You must be signed in to change notification settings - Fork 190
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enable digress in workflow interface api
Signed-off-by: kta-intel <kevin.ta@intel.com>
- Loading branch information
Showing
29 changed files
with
5,784 additions
and
0 deletions.
There are no files selected for viewing
554 changes: 554 additions & 0 deletions
554
openfl-tutorials/experimental/DiGress/Workflow_Interface_DiGress.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
Empty file.
339 changes: 339 additions & 0 deletions
339
openfl-tutorials/experimental/DiGress/digress/analysis/rdkit_functions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,339 @@ | ||
# Copyright (c) 2012-2022 Clement Vignac, Igor Krawczuk, Antoine Siraudin | ||
# source: https://github.com/cvignac/DiGress/ | ||
|
||
import numpy as np | ||
import torch | ||
import re | ||
# import wandb | ||
try: | ||
from rdkit import Chem | ||
print("Found rdkit, all good") | ||
except ModuleNotFoundError as e: | ||
use_rdkit = False | ||
from warnings import warn | ||
warn("Didn't find rdkit, this will fail") | ||
assert use_rdkit, "Didn't find rdkit" | ||
|
||
from rdkit import RDLogger | ||
RDLogger.DisableLog('rdApp.*') | ||
|
||
# Permitted total bond order per element symbol. check_stability() looks an
# atom's symbol up here and compares the entry with the atom's summed bond
# orders: an int is the single accepted value, a list holds the accepted
# alternatives.
allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3, 'Si': 4, 'P': [3, 5],
                 'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2], 'Bi': [3, 5], 'Se': [2, 4, 6]}
# RDKit bond type indexed by the integer edge type stored in the adjacency
# matrix (see build_molecule); index 0 is None, i.e. there is no bond type
# for edge value 0.
bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
             Chem.rdchem.BondType.AROMATIC]
# Valence per atomic number; build_molecule_with_partial_charges() compares
# the valence reported by RDKit against these values for N (7), O (8), S (16).
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
|
||
|
||
class BasicMolecularMetrics(object):
    """Validity, uniqueness and novelty metrics for generated molecular graphs.

    Each generated graph is a pair ``(atom_types, edge_types)`` that is turned
    into an RDKit molecule via ``build_molecule`` (strict) or
    ``build_molecule_with_partial_charges`` (relaxed).
    """

    def __init__(self, dataset_info, train_smiles=None):
        """Store the dataset description and the optional reference SMILES.

        Args:
            dataset_info: object exposing ``atom_decoder``.
            train_smiles: collection of training-set SMILES used for novelty,
                or ``None`` to skip novelty computation.
        """
        self.atom_decoder = dataset_info.atom_decoder
        self.dataset_info = dataset_info

        # Retrieve dataset smiles only for qm9 currently.
        self.dataset_smiles_list = train_smiles

    @staticmethod
    def _largest_fragment_smiles(mol):
        """Return the SMILES of the largest connected fragment of ``mol``.

        Returns ``None`` (after printing a diagnostic) when RDKit raises a
        valence or kekulization error while splitting ``mol`` into sanitized
        fragments, or when the largest fragment cannot be converted to SMILES.
        """
        try:
            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
            largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            return mol2smiles(largest_mol)
        except Chem.rdchem.AtomValenceException:
            print("Valence error in GetmolFrags")
        except Chem.rdchem.KekulizeException:
            print("Can't kekulize molecule")
        return None

    def compute_validity(self, generated):
        """Compute strict validity of the generated graphs.

        Args:
            generated: list of pairs ``(atom_types, edge_types)``.

        Returns:
            Tuple ``(valid, validity, num_components, all_smiles)``:
            the SMILES of the valid molecules (largest fragment only), the
            fraction of valid molecules (``0.0`` for empty input), a numpy
            array with the fragment count of every molecule whose fragments
            could be computed, and a list aligned with ``generated`` holding
            the SMILES or ``None`` for invalid molecules.
        """
        valid = []
        num_components = []
        all_smiles = []
        for atom_types, edge_types in generated:
            mol = build_molecule(atom_types, edge_types, self.dataset_info.atom_decoder)
            smiles = mol2smiles(mol)
            try:
                mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                num_components.append(len(mol_frags))
            except Exception:
                # Best effort: molecules whose fragments cannot be sanitized
                # simply do not contribute to the component statistics.
                pass
            if smiles is not None:
                smiles = self._largest_fragment_smiles(mol)
                if smiles is not None:
                    valid.append(smiles)
            all_smiles.append(smiles)

        validity = len(valid) / len(generated) if len(generated) > 0 else 0.0
        return valid, validity, np.array(num_components), all_smiles

    def compute_uniqueness(self, valid):
        """Deduplicate ``valid`` (list of SMILES strings).

        Returns:
            Tuple ``(unique, uniqueness)`` where ``unique`` keeps the first
            occurrence of every SMILES in input order and ``uniqueness`` is
            ``len(unique) / len(valid)`` (``0.0`` for empty input).
        """
        if len(valid) == 0:
            return [], 0.0
        # dict.fromkeys gives a deterministic order, unlike list(set(...)).
        unique = list(dict.fromkeys(valid))
        return unique, len(unique) / len(valid)

    def compute_novelty(self, unique):
        """Find the SMILES in ``unique`` that are absent from the training set.

        Returns:
            ``(1, 1)`` when no training SMILES were provided (computation
            skipped); otherwise ``(novel, novelty)`` with the list of novel
            SMILES and their fraction of ``unique`` (``0.0`` when ``unique``
            is empty).
        """
        if self.dataset_smiles_list is None:
            print("Dataset smiles is None, novelty computation skipped")
            return 1, 1
        # Build the lookup set once instead of scanning a list per query.
        known = set(self.dataset_smiles_list)
        novel = [smiles for smiles in unique if smiles not in known]
        if len(unique) == 0:
            return novel, 0.0
        return novel, len(novel) / len(unique)

    def compute_relaxed_validity(self, generated):
        """Compute validity after allowing formal charges on N, O and S.

        Args:
            generated: list of pairs ``(atom_types, edge_types)``.

        Returns:
            Tuple ``(valid, validity)``: SMILES of the valid molecules
            (largest fragment only) and their fraction of ``generated``
            (``0.0`` for empty input).
        """
        valid = []
        for atom_types, edge_types in generated:
            mol = build_molecule_with_partial_charges(atom_types, edge_types, self.dataset_info.atom_decoder)
            if mol2smiles(mol) is None:
                continue
            smiles = self._largest_fragment_smiles(mol)
            if smiles is not None:
                valid.append(smiles)
        validity = len(valid) / len(generated) if len(generated) > 0 else 0.0
        return valid, validity

    def evaluate(self, generated):
        """Run all metrics on ``generated``.

        Args:
            generated: list of pairs ``(atom_types, edge_types)``.

        Returns:
            Tuple ``([validity, relaxed_validity, uniqueness, novelty],
            unique, dict(nc_min, nc_max, nc_mu), all_smiles)``. Uniqueness and
            novelty are computed on the relaxed-valid molecules; novelty is
            ``-1.0`` when it cannot be computed (no training SMILES or no
            relaxed-valid molecule).
        """
        _valid, validity, num_components, all_smiles = self.compute_validity(generated)
        if len(num_components) > 0:
            nc_mu = num_components.mean()
            nc_min = num_components.min()
            nc_max = num_components.max()
        else:
            nc_mu = nc_min = nc_max = 0

        relaxed_valid, relaxed_validity = self.compute_relaxed_validity(generated)
        if relaxed_validity > 0:
            unique, uniqueness = self.compute_uniqueness(relaxed_valid)

            if self.dataset_smiles_list is not None:
                _, novelty = self.compute_novelty(unique)
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []
        return ([validity, relaxed_validity, uniqueness, novelty], unique,
                dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles)
|
||
|
||
def mol2smiles(mol):
    """Convert an RDKit molecule to a SMILES string.

    ``Chem.SanitizeMol`` is run on ``mol`` first; if it raises ``ValueError``
    the molecule is considered invalid and ``None`` is returned.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        smiles = None
    else:
        smiles = Chem.MolToSmiles(mol)
    return smiles
|
||
|
||
def build_molecule(atom_types, edge_types, atom_decoder, verbose=False):
    """Assemble an RDKit ``RWMol`` from atom type indices and an edge matrix.

    Args:
        atom_types: iterable of tensor scalars; each ``.item()`` indexes
            ``atom_decoder``.
        edge_types: tensor of integer edge types; only the upper triangle is
            read and each non-zero off-diagonal value indexes ``bond_dict``.
        atom_decoder: maps an atom type index to the value passed to
            ``Chem.Atom``.
        verbose: print every atom and bond as it is added.

    Returns:
        The (unsanitized) ``Chem.RWMol``.
    """
    if verbose:
        print("building new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        type_idx = atom.item()
        mol.AddAtom(Chem.Atom(atom_decoder[type_idx]))
        if verbose:
            print("Atom added: ", type_idx, atom_decoder[type_idx])

    # Keep each undirected edge once by discarding the lower triangle.
    upper = torch.triu(edge_types)
    for bond in torch.nonzero(upper):
        begin, end = bond[0].item(), bond[1].item()
        if begin == end:
            # Diagonal entries are self-loops, not bonds.
            continue
        order = upper[bond[0], bond[1]].item()
        mol.AddBond(begin, end, bond_dict[order])
        if verbose:
            print("bond added:", begin, end, order,
                  bond_dict[order])
    return mol
|
||
|
||
def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
    """Build an RDKit ``RWMol`` and add +1 formal charges where valence is exceeded by one.

    Works like ``build_molecule`` but, after every bond is added, runs
    ``check_valency``. When the reported atom is N, O or S and its valence
    exceeds ``ATOM_VALENCY`` by exactly one, the atom gets formal charge +1
    (e.g. [O+], [N+], [S+]). Negative charges such as [O-], [N-], [S-] and
    forms like [NH+] are not supported.

    Args:
        atom_types: iterable of tensor scalars; each ``.item()`` indexes
            ``atom_decoder``.
        edge_types: tensor of integer edge types; only the upper triangle is
            read and each non-zero off-diagonal value indexes ``bond_dict``.
        atom_decoder: maps an atom type index to the value passed to
            ``Chem.Atom``.
        verbose: print progress information.

    Returns:
        The resulting ``Chem.RWMol``.
    """
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        type_idx = atom.item()
        mol.AddAtom(Chem.Atom(atom_decoder[type_idx]))
        if verbose:
            print("Atom added: ", type_idx, atom_decoder[type_idx])

    upper = torch.triu(edge_types)
    for bond in torch.nonzero(upper):
        begin, end = bond[0].item(), bond[1].item()
        if begin == end:
            continue
        order = upper[bond[0], bond[1]].item()
        mol.AddBond(begin, end, bond_dict[order])
        if verbose:
            print("bond added:", begin, end, order,
                  bond_dict[order])

        # Re-check valence after each new bond so the offending atom can be
        # charged as soon as it appears.
        flag, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", flag, atomid_valence)
        if flag:
            continue

        assert len(atomid_valence) == 2
        idx = atomid_valence[0]
        v = atomid_valence[1]
        offender = mol.GetAtomWithIdx(idx)
        an = offender.GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", an)
        if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
            mol.GetAtomWithIdx(idx).SetFormalCharge(1)
    return mol
|
||
|
||
# Functions from GDSS | ||
def check_valency(mol):
    """Check atom valences of ``mol`` with RDKit's property sanitization.

    Returns:
        ``(True, None)`` when ``Chem.SanitizeMol`` with
        ``SANITIZE_PROPERTIES`` succeeds. Otherwise ``(False, numbers)`` where
        ``numbers`` is the list of integers found in the ``ValueError``
        message from the first ``'#'`` onwards (callers expect
        ``[atom_index, valence]``). If the message contains no ``'#'``,
        only its last character is scanned.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None
|
||
|
||
def correct_mol(m):
    """Repair valence errors in ``m`` by repeatedly lowering bond orders.

    While ``check_valency`` reports an offending atom, the highest-order
    non-aromatic bond of that atom is reduced by one order (and removed
    altogether when it was a single bond). ``m`` is modified in place.

    Args:
        m: editable RDKit molecule (``RemoveBond``/``AddBond`` are called on it).

    Returns:
        ``(mol, no_correct)`` where ``no_correct`` is ``True`` when ``m``
        already passed the valence check on entry. ``mol`` is ``None`` when
        the molecule cannot be corrected: the offending atom has only
        aromatic bonds, or has no bonds at all.
    """
    mol = m

    # Remember whether the input was already valid before any edit.
    no_correct, _ = check_valency(mol)

    while True:
        flag, atomid_valence = check_valency(mol)
        if flag:
            break

        assert len(atomid_valence) == 2
        idx = atomid_valence[0]

        queue = []
        # Number of aromatic bonds; after the descending sort they occupy the
        # first `check_idx` slots, so queue[check_idx] is the highest-order
        # non-aromatic bond.
        check_idx = 0
        for b in mol.GetAtomWithIdx(idx).GetBonds():
            bond_order = int(b.GetBondType())
            queue.append((b.GetIdx(), bond_order, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
            if bond_order == 12:  # 12 == int(Chem.BondType.AROMATIC)
                check_idx += 1
        queue.sort(key=lambda tup: tup[1], reverse=True)

        # Nothing to lower: either the atom has no bonds (previously an
        # IndexError on queue[-1]) or every bond is aromatic.
        if not queue or queue[-1][1] == 12:
            return None, no_correct

        start = queue[check_idx][2]
        end = queue[check_idx][3]
        t = queue[check_idx][1] - 1
        mol.RemoveBond(start, end)
        if t >= 1:
            mol.AddBond(start, end, bond_dict[t])
    return mol, no_correct
|
||
|
||
def valid_mol_can_with_seg(m, largest_connected_comp=True):
    """Round-trip ``m`` through SMILES, optionally keeping one fragment.

    Args:
        m: RDKit molecule, or ``None``.
        largest_connected_comp: when true and the SMILES contains several
            ``'.'``-separated fragments, keep only the fragment with the
            longest SMILES string (the first one on ties).

    Returns:
        ``None`` if ``m`` is ``None``; otherwise the result of
        ``Chem.MolFromSmiles`` on the selected SMILES.
    """
    if m is None:
        return None
    smiles = Chem.MolToSmiles(m, isomericSmiles=True)
    if largest_connected_comp and '.' in smiles:
        # e.g. 'C.CC.CCc1ccc(N)cc1CCC=O' -> 'CCc1ccc(N)cc1CCC=O'
        smiles = max(smiles.split('.'), key=len)
    return Chem.MolFromSmiles(smiles)
|
||
|
||
if __name__ == '__main__':
    # Smoke test: parse cyclobutane and print its mol block.
    smiles_mol = 'C1CCC1'
    print(f"Smiles mol {smiles_mol}")
    block_mol = Chem.MolToMolBlock(Chem.MolFromSmiles(smiles_mol))
    print("Block mol:")
    print(block_mol)

# Module-level flag, set unconditionally once this point of the module is reached.
use_rdkit = True
|
||
|
||
def check_stability(atom_types, edge_types, dataset_info, debug=False, atom_decoder=None):
    """Check every atom's total bond order against ``allowed_bonds``.

    Args:
        atom_types: sequence of atom type indices (each indexes the decoder).
        edge_types: 2-D indexable of edge values; the bond order between
            atoms ``i`` and ``j`` is ``abs((e[i, j] + e[j, i]) / 2)``.
        dataset_info: supplies ``atom_decoder`` when none is passed.
        debug: print a message for every atom with a disallowed bond count.
        atom_decoder: optional decoder overriding ``dataset_info.atom_decoder``.

    Returns:
        Tuple ``(molecule_stable, n_stable_atoms, n_atoms)``; the molecule is
        stable when every atom is.
    """
    decoder = dataset_info.atom_decoder if atom_decoder is None else atom_decoder

    n_atoms = len(atom_types)
    n_bonds = np.zeros(n_atoms, dtype='int')

    # Accumulate each pair's bond order on both of its endpoints.
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            pair_order = abs((edge_types[i, j] + edge_types[j, i])/2)
            n_bonds[i] += pair_order
            n_bonds[j] += pair_order

    n_stable_bonds = 0
    for atom_type, atom_n_bond in zip(atom_types, n_bonds):
        possible_bonds = allowed_bonds[decoder[atom_type]]
        if type(possible_bonds) is int:
            is_stable = possible_bonds == atom_n_bond
        else:
            # Several bond counts are acceptable for this element.
            is_stable = atom_n_bond in possible_bonds
        if debug and not is_stable:
            print("Invalid bonds for molecule %s with %d bonds" % (decoder[atom_type], atom_n_bond))
        n_stable_bonds += int(is_stable)

    return n_stable_bonds == n_atoms, n_stable_bonds, n_atoms
|
||
|
||
def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
    """Compute stability and RDKit-based metrics for generated molecules.

    Args:
        molecule_list: sequence of ``(atom_types, edge_types)`` pairs.
        train_smiles: training-set SMILES forwarded to
            ``BasicMolecularMetrics`` (may be ``None``).
        dataset_info: dataset description; ``remove_h`` decides whether the
            stability check runs.

    Returns:
        Tuple ``(validity_dict, rdkit_metrics, all_smiles)``.
        ``validity_dict`` holds the fractions ``'mol_stable'`` and
        ``'atm_stable'`` (both ``-1`` when ``dataset_info.remove_h`` is
        true), ``rdkit_metrics`` is the result of
        ``BasicMolecularMetrics.evaluate`` and ``all_smiles`` is its last
        element.
    """
    if dataset_info.remove_h:
        # Stability is not evaluated in this case; -1 marks "not computed".
        validity_dict = {'mol_stable': -1, 'atm_stable': -1}
    else:
        print('Analyzing molecule stability...')

        stable_molecules = 0
        stable_atoms = 0
        total_atoms = 0
        for atom_types, edge_types in molecule_list:
            is_stable, n_stable, n_total = check_stability(atom_types, edge_types, dataset_info)
            stable_molecules += int(is_stable)
            stable_atoms += int(n_stable)
            total_atoms += int(n_total)

        # Validity
        validity_dict = {
            'mol_stable': stable_molecules / float(len(molecule_list)),
            'atm_stable': stable_atoms / float(total_atoms),
        }

    # (wandb logging of validity_dict and the rdkit metrics is disabled here.)
    rdkit_metrics = BasicMolecularMetrics(dataset_info, train_smiles).evaluate(molecule_list)
    all_smiles = rdkit_metrics[-1]

    return validity_dict, rdkit_metrics, all_smiles
Oops, something went wrong.