Skip to content

Commit

Permalink
Improved typing
Browse files Browse the repository at this point in the history
  • Loading branch information
pschwllr committed Apr 21, 2024
1 parent 77b5459 commit 15a05c9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
24 changes: 12 additions & 12 deletions src/rxn_insight/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -31,7 +31,7 @@ class ReactionClassifier:
def __init__(
self,
reaction: str,
rxn_mapper: RXNMapper | None = None,
rxn_mapper: Optional[RXNMapper] = None,
keep_mapping: bool = False,
):
if keep_mapping:
Expand Down Expand Up @@ -176,7 +176,7 @@ def get_template_smiles(self) -> str | None:
return rxn

def get_functional_group_smarts(
self, molecule: Mol, matrix: npt.NDArray[Any], map_dict: Dict[int, int]
self, molecule: Mol, matrix: npt.NDArray[Any], map_dict: dict[int, int]
) -> tuple[str, ...]:
maps = self.transformation_mapping
matrix_indices = [self.atom_mapping_index[atom_map] for atom_map in maps]
Expand Down Expand Up @@ -263,14 +263,14 @@ def get_functional_group_smarts(
return tuple(functional_groups)

def get_functional_groups(
self, mol: Mol, map_dict: Dict[int, int], df: pd.DataFrame
self, mol: Mol, map_dict: dict[int, int], df: pd.DataFrame
) -> list[str]:
maps = self.transformation_mapping
atom_indices = np.array(
[map_dict[atom_map] for atom_map in maps if atom_map in map_dict]
)
fg = []
visited_atoms: List[List[int]] = []
visited_atoms: list[list[int]] = []
for i in df.index:
if len(np.in1d(visited_atoms, atom_indices)) != 0:
if len(visited_atoms[np.in1d(visited_atoms, atom_indices)]) == len(
Expand Down Expand Up @@ -303,7 +303,7 @@ def get_functional_groups(
return fg

def get_ring_type(
self, mol: Mol, map_dict: Optional[Dict[int, int]] = None
self, mol: Mol, map_dict: Optional[dict[int, int]] = None
) -> list[str]:
try:
rs = get_ring_systems(mol, include_spiro=True)
Expand Down Expand Up @@ -343,7 +343,7 @@ def balance_reaction(self, fgr: list[str], fgp: list[str]) -> list[str]:
mp = self.be_matrix_products
lost_heavy = self.mol_reactant.GetNumAtoms() - self.mol_product.GetNumAtoms()
if lost_heavy == 0:
return [""]
return []
negative_values = np.where(d < 0)[0]
metals = np.array([3, 5, 11, 12, 29, 30, 34, 47, 50])
metal_indices = np.where(np.in1d(self.reaction_center_atoms, metals))[0]
Expand Down Expand Up @@ -444,8 +444,8 @@ def balance_reaction(self, fgr: list[str], fgp: list[str]) -> list[str]:

return small_molecules

def get_reaction_center_info(self, df: pd.DataFrame) -> Dict[str, List[str] | str]:
reaction_center: Dict[str, list[str] | str] = dict()
def get_reaction_center_info(self, df: pd.DataFrame) -> dict[str, list[str] | str]:
reaction_center: dict[str, list[str] | str] = dict()
reaction_center["REACTION"] = self.sanitized_reaction
reaction_center["MAPPED_REACTION"] = self.sanitized_mapped_reaction
reaction_center["N_REACTANTS"] = self.num_reactants
Expand Down Expand Up @@ -474,7 +474,7 @@ def get_reaction_center_info(self, df: pd.DataFrame) -> Dict[str, List[str] | st

def get_atom_mapping_indices(
self,
) -> tuple[Dict[int, int], npt.NDArray[Any], npt.NDArray[Any], int]:
) -> tuple[dict[int, int], npt.NDArray[Any], npt.NDArray[Any], int]:
"""Make a dictionary that gives a unique index to all atoms in reactants and products.
Necessary since reactions are not balanced.
:return: Dictionary that links atom map and index. Size of BE-matrix
Expand Down Expand Up @@ -1261,8 +1261,8 @@ def is_heteroatom_alkylation(self) -> bool:
)[0]
carbon_bonds = maps_1[np.in1d(maps_1, carbons)]
# rcid = np.array(self.reaction_center_idx) <--- seems to be unused
carbonyls_r: List[int] = []
carbonyls_p: List[int] = []
carbonyls_r: list[int] = []
carbonyls_p: list[int] = []
if len(o_indices) != 0:
for carbon in carbons:
carbonyls_r += list(
Expand Down
14 changes: 7 additions & 7 deletions src/rxn_insight/reaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib
import warnings
from typing import Dict, List, Optional
from typing import Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
self.mapped_reaction = self.classifier.sanitized_mapped_reaction
self.reaction_class = ""
self.template = self.classifier.template
self.reaction_info: Dict[str, tuple[str, ...] | str] = dict()
self.reaction_info: dict[str, tuple[str, ...] | str] = dict()
self.tag = ""
self.name = ""
self.byproducts: tuple[str, ...] = tuple()
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_rings_in_reaction_center(
]
)

def get_functional_groups(self) -> tuple[List[str], ...]:
def get_functional_groups(self) -> tuple[list[str], ...]:
if self.fg_db is None:
from importlib import resources

Expand All @@ -137,7 +137,7 @@ def get_functional_groups(self) -> tuple[List[str], ...]:
]
)

def get_byproducts(self) -> List[str]:
def get_byproducts(self) -> list[str]:
fg_r, fg_p = self.get_functional_groups()
byproducts = self.classifier.balance_reaction(fg_r, fg_p)
self.byproducts = byproducts
Expand All @@ -157,7 +157,7 @@ def get_name(self) -> str:
self.name = self.classifier.name_reaction(self.smirks_db)
return self.name

def get_reaction_info(self) -> Dict[str, list[str] | str]:
def get_reaction_info(self) -> dict[str, list[str] | str]:
if self.fg_db is None:
from importlib import resources

Expand Down Expand Up @@ -326,7 +326,7 @@ def give_broad_tag(self) -> str:
hashtag = hashlib.sha256(tag_bytes).hexdigest()
return str(hashtag)

def suggest_conditions(self, df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
def suggest_conditions(self, df: pd.DataFrame) -> dict[str, pd.DataFrame]:
if self.neighbors is None or len(self.neighbors.index) == 0:
nbs = self.find_neighbors(df, max_return=5000, threshold=0.3, broaden=True)
else:
Expand Down Expand Up @@ -467,7 +467,7 @@ def get_functional_groups(self, df: pd.DataFrame = None) -> list[str]:
mol = self.mol
atom_indices = np.array([atom.GetIdx() for atom in mol.GetAtoms()])
fg = []
visited_atoms: List[List[int]] = []
visited_atoms: list[list[int]] = []
for i in df.index:
if len(np.in1d(visited_atoms, atom_indices)) != 0:
if len(visited_atoms[np.in1d(visited_atoms, atom_indices)]) == len(
Expand Down
8 changes: 4 additions & 4 deletions src/rxn_insight/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import hashlib
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -98,7 +98,7 @@ def get_atom_mapping(rxn: str, rxn_mapper: Optional[RXNMapper] = None) -> str:
return mapped_rxn


def sanitize_mapped_reaction(rxn: str) -> tuple[str, str, List[str]]:
def sanitize_mapped_reaction(rxn: str) -> tuple[str, str, list[str]]:
"""Remove reactants that are unmapped from the reactants.
:param rxn: Reaction SMILES with atom mapping
:return: Mapped and unmapped reaction SMILES without reagents.
Expand Down Expand Up @@ -571,12 +571,12 @@ def get_scaffold(mol: Mol) -> str | None:
"""
[a.SetAtomMapNum(0) for a in mol.GetAtoms()]
scaffold = GetScaffoldForMol(mol)
smi: str | None = Chem.MolToSmiles(scaffold)
smi: Optional[str] = Chem.MolToSmiles(scaffold)

return smi


def tag_reaction(rxn_info: Dict[str, List[str] | str]) -> str:
def tag_reaction(rxn_info: dict[str, list[str] | str]) -> str:
tag = f"{rxn_info['CLASS']} "
fg_r = sorted(list(rxn_info["FG_REACTANTS"]))
fg_p = sorted(list(rxn_info["FG_PRODUCTS"]))
Expand Down

0 comments on commit 15a05c9

Please sign in to comment.