From e7a7e9a529c2c484316c571816311ee055d9c868 Mon Sep 17 00:00:00 2001 From: SeonghwanSeo Date: Wed, 25 Dec 2024 17:03:30 +0900 Subject: [PATCH] linting, add offline setting --- .gitignore | 3 +- README.md | 26 +++-- modeling.py | 60 ++++++++--- pyproject.toml | 14 ++- src/pmnet/__init__.py | 8 +- src/pmnet/data/constant.py | 6 +- src/pmnet/data/extract_pocket.py | 3 +- src/pmnet/data/objects/atom_classes.py | 15 ++- src/pmnet/data/objects/objects.py | 35 ++++--- src/pmnet/data/objects/utils.py | 12 +-- src/pmnet/data/pointcloud.py | 6 +- src/pmnet/data/token_inference.py | 11 +-- src/pmnet/module.py | 93 ++++++++++++----- src/pmnet/network/backbones/swin.py | 2 +- src/pmnet/network/backbones/swinv2.py | 2 +- src/pmnet/network/backbones/timm.py | 27 +++-- src/pmnet/network/builder.py | 3 +- src/pmnet/network/cavity_head.py | 25 +++-- src/pmnet/network/decoders/fpn_decoder.py | 75 +++++++++----- src/pmnet/network/detector.py | 22 +++-- src/pmnet/network/feature_embedding.py | 12 ++- src/pmnet/network/mask_head.py | 99 +++++++++++++------ src/pmnet/network/nn/__init__.py | 2 +- src/pmnet/network/nn/layers.py | 49 +++++---- src/pmnet/network/token_head.py | 42 +++++--- src/pmnet/network/utils/registry.py | 14 ++- src/pmnet/scoring/match_utils.py | 2 +- src/pmnet/utils/density_map.py | 2 +- src/pmnet/utils/smoothing.py | 4 +- .../network/pharmacophore_encoder.py | 2 +- src/pmnet_appl/sbddreward/proxy.py | 64 +++++++++--- src/pmnet_appl/tacogfn_reward/proxy.py | 2 +- utils/parse_rcsb_pdb.py | 31 +++--- utils/visualize.py | 10 +- 34 files changed, 512 insertions(+), 271 deletions(-) diff --git a/.gitignore b/.gitignore index 07f3350..5393df3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,12 @@ # USER .DS_Store -weights +weights/ run.sh result/ examples/library/ nogit/ maintain_test/ +uv.lock # Byte-compiled / optimized / DLL files diff --git a/README.md b/README.md index d387c1e..e5ab38e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ - # PharmacoNet: Open-source Protein-based Pharmacophore Modeling -[![DOI](https://zenodo.org/badge/699273873.svg)](https://zenodo.org/doi/10.5281/zenodo.12168474) +[![DOI](https://zenodo.org/badge/699273873.svg)](https://zenodo.org/doi/10.5281/zenodo.12168474) **Before using PharmacoNet, consider using OpenPharmaco - GUI powered by PharmacoNet.** @@ -32,7 +31,6 @@ If you have any problems or need help with the code, please add an github issue - [Pre-trained Docking Proxy](#pretrained-docking-proxy) - [Citation](#citation) - ## Quick Start ```bash @@ -50,6 +48,7 @@ python feature_extraction.py --protein --center --out ``` ## Installation + - Using `environment.yml` For various environment including Linux, MacOS and Window, the script installs **cpu-only version of PyTorch** by default. You can install a cuda-available version by modifying `environment.yml` or installing PyTorch manually. @@ -61,8 +60,9 @@ python feature_extraction.py --protein --center --out ``` - Manual Installation + ```bash - # Required python>=3.9, Best Performance at higher version. (3.9, 3.10, 3.11, 3.12 - best) + # Required python>=3.9, Best Performance at higher version. 
(3.9, 3.10, 3.11, 3.12(best)) conda create --name openph python=3.10 openbabel=3.1.1 pymol-open-source=3.0.0 numpy=1.26.4 conda activate pmnet @@ -140,6 +140,15 @@ INFO:root:Save Pharmacophore Model to result/6OIM/6OIM_2.0_-8.0_-1.0_model.pm INFO:root:Save Pymol Visualization Session to result/6OIM/6OIM_2.0_-8.0_-1.0_model.pse ``` +#### Example with custom model weight file (offline) + +PharmacoNet's weight file is automatically downloaded during `modeling.py`. +If your environment is offline, you can download the weight files from [Google Drive](https://drive.google.com/uc?id=1gzjdM7bD3jPm23LBcDXtkSk18nETL04p). + +```bash +> python modeling.py --pdb 6oim --weight_path +``` + ## Virtual Screening We provide the simple script for screening. @@ -175,7 +184,7 @@ score = model.scoring_smiles(, ) ## Pharmacophore Feature Extraction -***See: [`./developer/`](/developer/), [`./src/pmnet_appl/`](/src/pmnet_appl/).*** +**_See: [`./developer/`](/developer/), [`./src/pmnet_appl/`](/src/pmnet_appl/)._** For deep learning researcher who want to use PharmacoNet as pre-trained model for feature extraction, we provide the python API. @@ -225,7 +234,8 @@ pmnet_attr = (multi_scale_features, hotspot_infos) ``` ## Pretrained Docking Proxy -***See: [`./src/pmnet_appl/`](/src/pmnet_appl/).*** + +**_See: [`./src/pmnet_appl/`](/src/pmnet_appl/)._** We provide pre-trained docking proxy models which predict docking score against arbitrary protein using PharmacoNet. We hope this implementation prompts the molecule optimization. @@ -233,11 +243,12 @@ We hope this implementation prompts the molecule optimization. If you use this implementation, please cite PharmacoNet with original papers. Implementation List: + - TacoGFN: Target-conditioned GFlowNet for Structure-based Drug Design [[paper](https://arxiv.org/abs/2310.03223)] Related Works: -- RxnFlow: Generative Flows on Synthetic Pathway for Drug Design [paper] +- RxnFlow: Generative Flows on Synthetic Pathway for Drug Design [paper] ## Citation @@ -252,4 +263,3 @@ Paper on [arxiv](https://arxiv.org/abs/2310.00681) url = {https://arxiv.org/abs/2310.00681}, } ``` - diff --git a/modeling.py b/modeling.py index 92e6e68..2ed840c 100644 --- a/modeling.py +++ b/modeling.py @@ -24,9 +24,13 @@ def __init__(self): cfg_args = self.add_argument_group("config") cfg_args.add_argument("--pdb", type=str, help="RCSB PDB code") cfg_args.add_argument("-l", "--ligand_id", type=str, help="RCSB ligand code") - cfg_args.add_argument("-p", "--protein", type=str, help="custom path of protein pdb file (.pdb)") + cfg_args.add_argument( + "-p", "--protein", type=str, help="custom path of protein pdb file (.pdb)" + ) cfg_args.add_argument("-c", "--chain", type=str, help="Chain") - cfg_args.add_argument("-a", "--all", action="store_true", help="use all binding sites") + cfg_args.add_argument( + "-a", "--all", action="store_true", help="use all binding sites" + ) cfg_args.add_argument( "--out_dir", type=str, @@ -43,8 +47,15 @@ def __init__(self): # system config env_args = self.add_argument_group("environment") - env_args.add_argument("--cuda", action="store_true", help="use gpu acceleration with CUDA") - env_args.add_argument("--force", action="store_true", help="force to save the pharmacophore model") + env_args.add_argument( + "--weight_path", type=str, help="(Optional) custom pharmaconet weight path" + ) + env_args.add_argument( + "--cuda", action="store_true", help="use gpu acceleration with CUDA" + ) + env_args.add_argument( + "--force", action="store_true", help="force to save the 
pharmacophore model" + ) env_args.add_argument("-v", "--verbose", action="store_true", help="verbose") # config @@ -54,12 +65,16 @@ def __init__(self): type=str, help="path of ligand to define the center of box (.sdf, .pdb, .mol2)", ) - adv_args.add_argument("--center", nargs="+", type=float, help="coordinate of the center") + adv_args.add_argument( + "--center", nargs="+", type=float, help="coordinate of the center" + ) def main(args): logging.info(pmnet.__description__) - assert args.prefix is not None or args.pdb is not None, "MISSING PREFIX: `--prefix` or `--pdb`" + assert ( + args.prefix is not None or args.pdb is not None + ), "MISSING PREFIX: `--prefix` or `--pdb`" PREFIX = args.prefix if args.prefix else args.pdb # NOTE: Setting @@ -70,19 +85,20 @@ def main(args): SAVE_DIR.mkdir(exist_ok=True, parents=True) # NOTE: Load PharmacoNet - module = PharmacoNet("cuda" if args.cuda else "cpu") - logging.info(f"Load PharmacoNet finish") + module = PharmacoNet("cuda" if args.cuda else "cpu", weight_path=args.weight_path) + logging.info("Load PharmacoNet finish") # NOTE: Set Protein + protein_path: str if isinstance(args.pdb, str): - protein_path: str = str(SAVE_DIR / f"{PREFIX}.pdb") + protein_path = str(SAVE_DIR / f"{PREFIX}.pdb") if not os.path.exists(protein_path): logging.info(f"Download {args.pdb} to {protein_path}") download_pdb(args.pdb, protein_path) else: logging.info(f"Load {protein_path}") elif isinstance(args.protein, str): - protein_path: str = args.protein + protein_path = args.protein assert os.path.exists(protein_path) logging.info(f"Load {protein_path}") else: @@ -96,13 +112,17 @@ def run_pmnet(filename, ligand_path=None, center=None) -> PharmacophoreModel: logging.warning(f"Modeling Pass - {model_path} exists") pharmacophore_model = PharmacophoreModel.load(str(model_path)) else: - pharmacophore_model = module.run(protein_path, ref_ligand_path=ligand_path, center=center) + pharmacophore_model = module.run( + protein_path, ref_ligand_path=ligand_path, center=center + ) pharmacophore_model.save(str(model_path)) logging.info(f"Save Pharmacophore Model to {model_path}") if (not args.force) and os.path.exists(pymol_path): logging.warning(f"Visualizing Pass - {pymol_path} exists\n") else: - visualize.visualize_single(pharmacophore_model, protein_path, ligand_path, PREFIX, str(pymol_path)) + visualize.visualize_single( + pharmacophore_model, protein_path, ligand_path, PREFIX, str(pymol_path) + ) logging.info(f"Save Pymol Visualization Session to {pymol_path}\n") return pharmacophore_model @@ -173,16 +193,22 @@ def run_pmnet_manual_center(): if (not args.force) and os.path.exists(pymol_path): logging.warning(f"Visualizing Pass - {pymol_path} exists\n") else: - visualize.visualize_multiple(model_dict, protein_path, PREFIX, str(pymol_path)) + visualize.visualize_multiple( + model_dict, protein_path, PREFIX, str(pymol_path) + ) logging.info(f"Save Pymol Visualization Session to {pymol_path}\n") return inform_list_text = "\n\n".join(str(inform) for inform in inform_list) - logging.info(f"A total of {len(inform_list)} ligand(s) are detected!\n{inform_list_text}\n") + logging.info( + f"A total of {len(inform_list)} ligand(s) are detected!\n{inform_list_text}\n" + ) # NOTE: Case 3-3: pattern matching if args.ligand_id is not None or args.chain is not None: - logging.info(f"Filtering with matching pattern - ligand id: {args.ligand_id}, chain: {args.chain}") + logging.info( + f"Filtering with matching pattern - ligand id: {args.ligand_id}, chain: {args.chain}" + ) filtered_inform_list = [] 
for inform in inform_list: if args.ligand_id is not None and args.ligand_id.upper() != inform.id: @@ -201,7 +227,9 @@ def run_pmnet_manual_center(): return FAIL if len(inform_list) > 1: inform_list_text = "\n\n".join(str(inform) for inform in inform_list) - logging.info(f"A total of {len(inform_list)} ligands are selected!\n{inform_list_text}\n") + logging.info( + f"A total of {len(inform_list)} ligands are selected!\n{inform_list_text}\n" + ) if len(inform_list) == 1: run_pmnet_inform(inform_list[0]) diff --git a/pyproject.toml b/pyproject.toml index b77ec2e..eda4775 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pharmaconet" -version = "2.1.0" +version = "2.1.1" description = "PharmacoNet: Open-Source Software for Protein-based Pharmacophore Modeling and Virtual Screening" license = { text = "MIT" } authors = [{ name = "Seonghwan Seo", email = "shwan0106@kaist.ac.kr" }] @@ -31,7 +31,8 @@ dependencies = [ "omegaconf>=2.3.0", "molvoxel>=0.1.3", "gdown>=5.1.0", - "biopython>=1.83" + "biopython>=1.83", + "ruff>=0.8.4", ] [project.optional-dependencies] @@ -68,19 +69,22 @@ line-length = 120 select = ["E", "F", "B", "UP", "T203",] ignore = ["E501"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = [ "F401", # imported but unused "E402", # Module level import not at top of file ] -[tool.pyright] +[tool.basedpyright] pythonVersion = "3.10" typeCheckingMode = "standard" diagnosticMode = "openFilesOnly" reportImplicitStringConcatenation = false +useLibraryCodeForTypes = true reportGeneralTypeIssues = "warning" reportDeprecated = "warning" reportUnusedVariable = false reportUnusedImport = false - +venvPath = '.' +venv = '.venv' +exclude = [".venv/"] diff --git a/src/pmnet/__init__.py b/src/pmnet/__init__.py index b6e84a2..5a1c2b1 100644 --- a/src/pmnet/__init__.py +++ b/src/pmnet/__init__.py @@ -1,11 +1,7 @@ from .pharmacophore_model import PharmacophoreModel -__version__ = "2.1.0" -__citation_information__ = ( - "Seo, S., & Kim, W. Y. (2023, December). " - "PharmacoNet: Accelerating Large-Scale Virtual Screening by Deep Pharmacophore Modeling. " - "In NeurIPS 2023 Workshop on New Frontiers of AI for Drug Discovery and Development." -) +__version__ = "2.1.1" +__citation_information__ = "Seo, S., & Kim, W. Y. (2024). PharmacoNet: deep learning-guided pharmacophore modeling for ultra-large-scale virtual screening. Chemical Science, 15(46), 19473-19487." 
__maintainer__ = "https://github.com/SeonghwanSeo/PharmacoNet" __description__ = ( diff --git a/src/pmnet/data/constant.py b/src/pmnet/data/constant.py index 947b1e9..cdd9aa4 100644 --- a/src/pmnet/data/constant.py +++ b/src/pmnet/data/constant.py @@ -1,4 +1,4 @@ -from typing import Sequence, Set +from collections.abc import Sequence INTERACTION_LIST: Sequence[str] = ( "Hydrophobic", @@ -40,7 +40,7 @@ XBOND: 4.5, # 4.0 + 0.5 } -LONG_INTERACTION: Set[int] = { +LONG_INTERACTION: set[int] = { PISTACKING_P, PISTACKING_T, PICATION_PRING, @@ -49,7 +49,7 @@ SALTBRIDGE_PNEG, } -SHORT_INTERACTION: Set[int] = { +SHORT_INTERACTION: set[int] = { HYDROPHOBIC, HBOND_LDON, HBOND_PDON, diff --git a/src/pmnet/data/extract_pocket.py b/src/pmnet/data/extract_pocket.py index fb847ae..b3b1a7a 100644 --- a/src/pmnet/data/extract_pocket.py +++ b/src/pmnet/data/extract_pocket.py @@ -5,7 +5,6 @@ from Bio.PDB import PDBParser, PDBIO from Bio.PDB.PDBIO import Select -from typing import Union from numpy.typing import ArrayLike from pathlib import Path @@ -87,7 +86,7 @@ def accept_residue(self, residue): def extract_pocket( - protein_pdb_path: Union[str, Path], out_pocket_pdb_path: str, center: ArrayLike, cutoff: float = DEFAULT_CUTOFF + protein_pdb_path: str | Path, out_pocket_pdb_path: str, center: ArrayLike, cutoff: float = DEFAULT_CUTOFF ): parser = PDBParser() structure = parser.get_structure("protein", str(protein_pdb_path)) diff --git a/src/pmnet/data/objects/atom_classes.py b/src/pmnet/data/objects/atom_classes.py index 7481cec..a7d8bc0 100644 --- a/src/pmnet/data/objects/atom_classes.py +++ b/src/pmnet/data/objects/atom_classes.py @@ -3,13 +3,12 @@ import numpy as np from collections.abc import Sequence -from typing import Sequence, List, Tuple from numpy.typing import NDArray from functools import cached_property from . 
import utils -Tuple3D = Tuple[float, float, float] +Tuple3D = tuple[float, float, float] @dataclass @@ -49,7 +48,7 @@ def __len__(self): @dataclass -class BaseInteractablePart(): +class BaseInteractablePart: @property def small(self): @@ -58,7 +57,7 @@ def small(self): return self.__small def to_small(self): - raise NotImplemented + raise NotImplementedError @dataclass @@ -123,7 +122,7 @@ def __post_init__(self): self.normal = utils.normalize(np.cross(v1, v2)) @property - def indices(self) -> List[int]: + def indices(self) -> list[int]: return [obatom.GetIdx() - 1 for obatom in self.obatoms] @@ -141,7 +140,7 @@ def __post_init__(self): self.center = Point3D.from_array(np.mean(coords_list, axis=0)) @property - def indices(self) -> List[int]: + def indices(self) -> list[int]: return [obatom.GetIdx() - 1 for obatom in self.obatoms] @@ -180,7 +179,7 @@ def C_index(self) -> int: return self.C.GetIdx() - 1 @property - def indices(self) -> List[int]: + def indices(self) -> list[int]: return [self.X_index, self.C_index] @@ -204,7 +203,7 @@ def Y_index(self) -> int: return self.Y.GetIdx() - 1 @property - def indices(self) -> List[int]: + def indices(self) -> list[int]: return [self.O_index, self.Y_index] diff --git a/src/pmnet/data/objects/objects.py b/src/pmnet/data/objects/objects.py index 989ff41..5d7a4ff 100644 --- a/src/pmnet/data/objects/objects.py +++ b/src/pmnet/data/objects/objects.py @@ -2,7 +2,6 @@ from openbabel import pybel from openbabel.pybel import ob -from typing import List, Tuple from .atom_classes import ( HydrophobicAtom_P, @@ -35,7 +34,7 @@ def __init__( self.pbmol = pbmol.clone self.pbmol.removeh() self.obmol = self.pbmol.OBMol - self.obatoms: List[ob.OBAtom] = list(ob.OBMolAtomIter(self.obmol)) + self.obatoms: list[ob.OBAtom] = list(ob.OBMolAtomIter(self.obmol)) self.num_heavyatoms = len(self.obatoms) self.pbmol_hyd: pybel.Molecule @@ -45,21 +44,21 @@ def __init__( else: self.pbmol_hyd = pbmol self.obmol_hyd = self.pbmol_hyd.OBMol - self.obatoms_hyd: List[ob.OBAtom] = list(ob.OBMolAtomIter(self.obmol_hyd))[: self.num_heavyatoms] - self.obatoms_hyd_nonwater: List[ob.OBAtom] = [ + self.obatoms_hyd: list[ob.OBAtom] = list(ob.OBMolAtomIter(self.obmol_hyd))[: self.num_heavyatoms] + self.obatoms_hyd_nonwater: list[ob.OBAtom] = [ obatom for obatom in self.obatoms_hyd if obatom.GetResidue().GetName() != "HOH" and obatom.GetAtomicNum() in [6, 7, 8, 16] ] - self.obresidues_hyd: List[ob.OBResidue] = list(ob.OBResidueIter(self.obmol_hyd)) + self.obresidues_hyd: list[ob.OBResidue] = list(ob.OBResidueIter(self.obmol_hyd)) - self.hydrophobic_atoms_all: List[HydrophobicAtom_P] - self.rings_all: List[Ring_P] - self.pos_charged_atoms_all: List[PosCharged_P] - self.neg_charged_atoms_all: List[NegCharged_P] - self.hbond_donors_all: List[HBondDonor_P] - self.hbond_acceptors_all: List[HBondAcceptor_P] - self.xbond_acceptors_all: List[XBondAcceptor_P] + self.hydrophobic_atoms_all: list[HydrophobicAtom_P] + self.rings_all: list[Ring_P] + self.pos_charged_atoms_all: list[PosCharged_P] + self.neg_charged_atoms_all: list[NegCharged_P] + self.hbond_donors_all: list[HBondDonor_P] + self.hbond_acceptors_all: list[HBondAcceptor_P] + self.xbond_acceptors_all: list[XBondAcceptor_P] self.hydrophobic_atoms_all = self.__find_hydrophobic_atoms() self.rings_all = self.__find_rings() @@ -74,7 +73,7 @@ def from_pdbfile(cls, path, addh=True, **kwargs): return cls(pbmol, addh, **kwargs) # Search Interactable Part - def __find_hydrophobic_atoms(self) -> List[HydrophobicAtom_P]: + def __find_hydrophobic_atoms(self) -> 
list[HydrophobicAtom_P]: hydrophobics = [ HydrophobicAtom_P(obatom) for obatom in self.obatoms_hyd_nonwater @@ -82,15 +81,15 @@ def __find_hydrophobic_atoms(self) -> List[HydrophobicAtom_P]: ] return hydrophobics - def __find_hbond_acceptors(self) -> List[HBondAcceptor_P]: + def __find_hbond_acceptors(self) -> list[HBondAcceptor_P]: acceptors = [HBondAcceptor_P(obatom) for obatom in self.obatoms_hyd_nonwater if obatom.IsHbondAcceptor()] return acceptors - def __find_hbond_donors(self) -> List[HBondDonor_P]: + def __find_hbond_donors(self) -> list[HBondDonor_P]: donors = [HBondDonor_P(obatom) for obatom in self.obatoms_hyd_nonwater if obatom.IsHbondDonor()] return donors - def __find_rings(self) -> List[Ring_P]: + def __find_rings(self) -> list[Ring_P]: rings = [] ring_candidates = self.pbmol_hyd.sssr for ring in ring_candidates: @@ -103,7 +102,7 @@ def __find_rings(self) -> List[Ring_P]: rings.append(Ring_P(obatoms)) return rings - def __find_charged_atoms(self) -> Tuple[List[PosCharged_P], List[NegCharged_P]]: + def __find_charged_atoms(self) -> tuple[list[PosCharged_P], list[NegCharged_P]]: pos_charged = [] neg_charged = [] @@ -129,7 +128,7 @@ def __find_charged_atoms(self) -> Tuple[List[PosCharged_P], List[NegCharged_P]]: return pos_charged, neg_charged - def __find_xbond_acceptors(self) -> List[XBondAcceptor_P]: + def __find_xbond_acceptors(self) -> list[XBondAcceptor_P]: """Look for halogen bond acceptors (Y-{O|N|S}, with Y=N,C)""" acceptors = [] for obatom in self.obatoms_hyd_nonwater: diff --git a/src/pmnet/data/objects/utils.py b/src/pmnet/data/objects/utils.py index c66e544..5d3615d 100644 --- a/src/pmnet/data/objects/utils.py +++ b/src/pmnet/data/objects/utils.py @@ -3,7 +3,7 @@ from openbabel.pybel import ob -from typing import Sequence, Tuple, Union +from collections.abc import Sequence from numpy.typing import NDArray @@ -34,12 +34,12 @@ def angle_btw_vectors(vec1: NDArray, vec2: NDArray, degree=True, normalized=Fals return math.degrees(angle) if degree else angle -def vector(p1: Union[Sequence[float], NDArray], p2: Union[Sequence[float], NDArray]) -> NDArray: +def vector(p1: Sequence[float] | NDArray, p2: Sequence[float] | NDArray) -> NDArray: return np.subtract(p2, p1) -def euclidean3d(p1: Union[Sequence[float], NDArray], p2: Union[Sequence[float], NDArray]) -> float: - return math.sqrt(sum([(a - b) ** 2 for a, b in zip(p1, p2)])) +def euclidean3d(p1: Sequence[float] | NDArray, p2: Sequence[float] | NDArray) -> float: + return math.sqrt(sum([(a - b) ** 2 for a, b in zip(p1, p2, strict=False)])) def normalize(vec: NDArray) -> NDArray: @@ -48,7 +48,7 @@ def normalize(vec: NDArray) -> NDArray: return vec / norm -def projection(point: Union[Sequence[float], NDArray], origin: Union[Sequence[float], NDArray], +def projection(point: Sequence[float] | NDArray, origin: Sequence[float] | NDArray, normal: NDArray) -> NDArray: """ point: point to be projected @@ -58,5 +58,5 @@ def projection(point: Union[Sequence[float], NDArray], origin: Union[Sequence[fl return np.subtract(point, c * normal) -def ob_coords(obatom: ob.OBAtom) -> Tuple[float, float, float]: +def ob_coords(obatom: ob.OBAtom) -> tuple[float, float, float]: return (obatom.x(), obatom.y(), obatom.z()) diff --git a/src/pmnet/data/pointcloud.py b/src/pmnet/data/pointcloud.py index 949da2b..c3ff44c 100644 --- a/src/pmnet/data/pointcloud.py +++ b/src/pmnet/data/pointcloud.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Tuple, Sequence +from collections.abc import Sequence from openbabel.pybel import ob from 
numpy.typing import NDArray @@ -23,7 +23,7 @@ NUM_PROTEIN_CHANNEL = len(PROTEIN_CHANNEL_LIST) -def get_position(obatom: ob.OBAtom) -> Tuple[float, float, float]: +def get_position(obatom: ob.OBAtom) -> tuple[float, float, float]: return (obatom.x(), obatom.y(), obatom.z()) @@ -41,7 +41,7 @@ def protein_atom_function(atom: ob.OBAtom, out: NDArray, **kwargs) -> NDArray[np return out -def get_protein_pointcloud(pocket_obj: Protein) -> Tuple[NDArray[np.float32], NDArray[np.float32]]: +def get_protein_pointcloud(pocket_obj: Protein) -> tuple[NDArray[np.float32], NDArray[np.float32]]: positions = np.array( [(obatom.x(), obatom.y(), obatom.z()) for obatom in pocket_obj.obatoms], dtype=np.float32 diff --git a/src/pmnet/data/token_inference.py b/src/pmnet/data/token_inference.py index 70b0a61..59380a8 100644 --- a/src/pmnet/data/token_inference.py +++ b/src/pmnet/data/token_inference.py @@ -1,14 +1,13 @@ import numpy as np import math -from typing import Tuple, List from numpy.typing import NDArray, ArrayLike from .objects import Protein from . import constant as C -def get_token_informations(protein_obj: Protein) -> Tuple[NDArray[np.float32], NDArray[np.int16]]: +def get_token_informations(protein_obj: Protein) -> tuple[NDArray[np.float32], NDArray[np.int16]]: """get token information Args: @@ -28,8 +27,8 @@ def get_token_informations(protein_obj: Protein) -> Tuple[NDArray[np.float32], N + len(protein_obj.xbond_acceptors_all) ) - positions: List[Tuple[float, float, float]] = [] - classes: List[int] = [] + positions: list[tuple[float, float, float]] = [] + classes: list[int] = [] # NOTE: Hydrophobic positions.extend(tuple(hydrop.coords) for hydrop in protein_obj.hydrophobic_atoms_all) @@ -82,7 +81,7 @@ def get_token_and_filter( positions: NDArray[np.float32], classes: NDArray[np.int16], center: NDArray[np.float32], -) -> Tuple[NDArray[np.int16], NDArray[np.int16]]: +) -> tuple[NDArray[np.int16], NDArray[np.int16]]: """Create token and Filtering valid instances Args: @@ -103,7 +102,7 @@ def get_token_and_filter( x_start = x_center - (dimension / 2) * resolution y_start = y_center - (dimension / 2) * resolution z_start = z_center - (dimension / 2) * resolution - for i, ((x, y, z), c) in enumerate(zip(positions, classes)): + for i, ((x, y, z), c) in enumerate(zip(positions, classes, strict=False)): _x = int((x - x_start) // resolution) _y = int((y - y_start) // resolution) _z = int((z - z_start) // resolution) diff --git a/src/pmnet/module.py b/src/pmnet/module.py index 33f94e4..79ca7b3 100644 --- a/src/pmnet/module.py +++ b/src/pmnet/module.py @@ -21,7 +21,11 @@ from pmnet.data.parser import ProteinParser from pmnet.utils.smoothing import GaussianSmoothing from pmnet.utils.download_weight import download_pretrained_model -from pmnet.pharmacophore_model import PharmacophoreModel, INTERACTION_TO_PHARMACOPHORE, INTERACTION_TO_HOTSPOT +from pmnet.pharmacophore_model import ( + PharmacophoreModel, + INTERACTION_TO_PHARMACOPHORE, + INTERACTION_TO_HOTSPOT, +) DEFAULT_FOCUS_THRESHOLD = 0.5 DEFAULT_BOX_THRESHOLD = 0.5 @@ -49,6 +53,7 @@ def __init__( score_threshold: float | dict[str, float] | None = DEFAULT_SCORE_THRESHOLD, verbose: bool = True, molvoxel_library: str = "numba", + weight_path: str | Path | None = None, ): """ device: 'cpu' | 'cuda' @@ -64,9 +69,12 @@ def __init__( self.parser: ProteinParser = ProteinParser(molvoxel_library=molvoxel_library) running_path = Path(__file__) - weight_path = running_path.parent / "weights" / "model.tar" - if not weight_path.exists(): - 
download_pretrained_model(weight_path, verbose) + if weight_path is None: + weight_path = running_path.parent / "weights" / "model.tar" + if not weight_path.exists(): + download_pretrained_model(weight_path, verbose) + else: + weight_path = Path(weight_path) checkpoint = torch.load(weight_path, map_location="cpu") config = OmegaConf.create(checkpoint["config"]) model = build_model(config.MODEL) @@ -78,7 +86,8 @@ def __init__( self.smoothing = GaussianSmoothing(kernel_size=5, sigma=0.5).to(device) self.score_distributions = { - typ: np.array(distribution["focus"]) for typ, distribution in checkpoint["score_distributions"].items() + typ: np.array(distribution["focus"]) + for typ, distribution in checkpoint["score_distributions"].items() } del checkpoint @@ -132,12 +141,20 @@ def run_extraction( tokens = tokens.to(device=self.device) mask = mask.to(device=self.device) - multi_scale_features = self.model.forward_feature(protein_image.unsqueeze(0)) # List[[1, D, H, W, F]] - token_scores, token_features = self.model.forward_token_prediction(multi_scale_features[-1], [tokens]) + multi_scale_features = self.model.forward_feature( + protein_image.unsqueeze(0) + ) # List[[1, D, H, W, F]] + token_scores, token_features = self.model.forward_token_prediction( + multi_scale_features[-1], [tokens] + ) token_scores = token_scores[0].sigmoid() # [Ntoken,] token_features = token_features[0] # [Ntoken, F] - cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(multi_scale_features[-1]) - cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] + cavity_narrow, cavity_wide = self.model.forward_cavity_extraction( + multi_scale_features[-1] + ) + cavity_narrow = ( + cavity_narrow[0].sigmoid() > self.focus_threshold + ) # [1, D, H, W] cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] indices = [] @@ -146,7 +163,12 @@ def run_extraction( x, y, z, typ = tokens[i].tolist() # NOTE: Check the token score absolute_score = token_scores[i].item() - relative_score = float((self.score_distributions[C.INTERACTION_LIST[int(typ)]] < absolute_score).mean()) + relative_score = float( + ( + self.score_distributions[C.INTERACTION_LIST[int(typ)]] + < absolute_score + ).mean() + ) if relative_score < self.score_threshold[C.INTERACTION_LIST[int(typ)]]: continue # NOTE: Check the token exists in cavity @@ -161,7 +183,9 @@ def run_extraction( del protein_image, mask, token_pos, tokens hotspot_infos = [] - for hotspot, score, position, feature in zip(hotspots, rel_scores, hotpsot_pos, hotspot_features, strict=True): + for hotspot, score, position, feature in zip( + hotspots, rel_scores, hotpsot_pos, hotspot_features, strict=True + ): interaction_type = C.INTERACTION_LIST[int(hotspot[3])] hotspot_infos.append( { @@ -196,7 +220,9 @@ def get_center( extension = os.path.splitext(ref_ligand_path)[1] assert extension in [".sdf", ".pdb", ".mol2"] ref_ligand = next(pybel.readfile(extension[1:], str(ref_ligand_path))) - x, y, z = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32).tolist() + x, y, z = np.mean( + [atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32 + ).tolist() return float(x), float(y), float(z) @torch.no_grad() @@ -211,12 +237,20 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor "debug", f"Protein-based Pharmacophore Modeling... 
(device: {self.device})", ) - multi_scale_features = self.model.forward_feature(protein_image.unsqueeze(0)) # List[[1, D, H, W, F]] - token_scores, token_features = self.model.forward_token_prediction(multi_scale_features[-1], [tokens]) + multi_scale_features = self.model.forward_feature( + protein_image.unsqueeze(0) + ) # List[[1, D, H, W, F]] + token_scores, token_features = self.model.forward_token_prediction( + multi_scale_features[-1], [tokens] + ) token_scores = token_scores[0].sigmoid() # [Ntoken,] token_features = token_features[0] # [Ntoken, F] - cavity_narrow, cavity_wide = self.model.forward_cavity_extraction(multi_scale_features[-1]) - cavity_narrow = cavity_narrow[0].sigmoid() > self.focus_threshold # [1, D, H, W] + cavity_narrow, cavity_wide = self.model.forward_cavity_extraction( + multi_scale_features[-1] + ) + cavity_narrow = ( + cavity_narrow[0].sigmoid() > self.focus_threshold + ) # [1, D, H, W] cavity_wide = cavity_wide[0].sigmoid() > self.focus_threshold # [1, D, H, W] num_tokens = tokens.shape[0] @@ -226,7 +260,12 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor x, y, z, typ = tokens[i].tolist() # NOTE: Check the token score absolute_score = token_scores[i].item() - relative_score = float((self.score_distributions[C.INTERACTION_LIST[int(typ)]] < absolute_score).mean()) + relative_score = float( + ( + self.score_distributions[C.INTERACTION_LIST[int(typ)]] + < absolute_score + ).mean() + ) if relative_score < self.score_threshold[C.INTERACTION_LIST[int(typ)]]: continue # NOTE: Check the token exists in cavity @@ -238,7 +277,9 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor continue indices.append(i) rel_scores.append(relative_score) - selected_indices = torch.tensor(indices, device=self.device, dtype=torch.long) # [Ntoken',] + selected_indices = torch.tensor( + indices, device=self.device, dtype=torch.long + ) # [Ntoken',] hotspots = tokens[selected_indices] # [Ntoken',] hotpsot_pos = token_pos[selected_indices] # [Ntoken', 3] hotspot_features = token_features[selected_indices] # [Ntoken', F] @@ -256,8 +297,12 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor disable=(self.logger is None), ) as pbar: for idx in range(0, hotspots.size(0), step): - _hotspots, _hotspot_features = [hotspots[idx : idx + step]], [hotspot_features[idx : idx + step]] - density_maps = self.model.forward_segmentation(multi_scale_features, _hotspots, _hotspot_features)[0] + _hotspots, _hotspot_features = [hotspots[idx : idx + step]], [ + hotspot_features[idx : idx + step] + ] + density_maps = self.model.forward_segmentation( + multi_scale_features, _hotspots, _hotspot_features + )[0] density_maps = density_maps[0].sigmoid() # [4, D, H, W] density_maps_list.append(density_maps) pbar.update(len(_hotspots)) @@ -265,7 +310,9 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor density_maps = torch.cat(density_maps_list, dim=0) # [Ntoken', D, H, W] box_area = token_inference.get_box_area(hotspots) - box_area = torch.from_numpy(box_area).to(device=self.device, dtype=torch.bool) # [Ntoken', D, H, W] + box_area = torch.from_numpy(box_area).to( + device=self.device, dtype=torch.bool + ) # [Ntoken', D, H, W] unavailable_area = ~(box_area & mask & cavity_narrow) # [Ntoken', D, H, W] # NOTE: masking should be performed before smoothing - masked area is not trained. 
@@ -275,7 +322,9 @@ def create_density_maps(self, protein_data: tuple[Tensor, Tensor, Tensor, Tensor density_maps[density_maps < self.box_threshold] = 0.0 hotspot_infos = [] - for hotspot, score, position, map in zip(hotspots, rel_scores, hotpsot_pos, density_maps, strict=True): + for hotspot, score, position, map in zip( + hotspots, rel_scores, hotpsot_pos, density_maps, strict=True + ): if torch.all(map < 1e-6): continue interaction_type = C.INTERACTION_LIST[int(hotspot[3])] diff --git a/src/pmnet/network/backbones/swin.py b/src/pmnet/network/backbones/swin.py index 4b3143c..7231de2 100644 --- a/src/pmnet/network/backbones/swin.py +++ b/src/pmnet/network/backbones/swin.py @@ -440,7 +440,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, int, int, int, Tensor, int, int, i mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0) for blk in self.blocks: x = blk(x, attn_mask) diff --git a/src/pmnet/network/backbones/swinv2.py b/src/pmnet/network/backbones/swinv2.py index e36d5bc..c73a05d 100644 --- a/src/pmnet/network/backbones/swinv2.py +++ b/src/pmnet/network/backbones/swinv2.py @@ -237,7 +237,7 @@ def __init__( mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0) else: attn_mask = None diff --git a/src/pmnet/network/backbones/timm.py b/src/pmnet/network/backbones/timm.py index 77711dc..456ba21 100644 --- a/src/pmnet/network/backbones/timm.py +++ b/src/pmnet/network/backbones/timm.py @@ -31,12 +31,12 @@ def norm_cdf(x): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) + u = norm_cdf((a - mean) / std) + v = norm_cdf((b - mean) / std) - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + # Uniformly fill tensor with values from [u, v], then translate to + # [2u-1, 2v-1]. + tensor.uniform_(2 * u - 1, 2 * v - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal @@ -51,8 +51,13 @@ def norm_cdf(x): return tensor -def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - # type: (Tensor, float, float, float, float) -> Tensor +def trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: r"""Fills the input Tensor with values drawn from a truncated normal distribution. 
The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -78,7 +83,9 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): return _trunc_normal_(tensor, mean, std, a, b) -def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, @@ -91,7 +98,9 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0 and scale_by_keep: random_tensor.div_(keep_prob) diff --git a/src/pmnet/network/builder.py b/src/pmnet/network/builder.py index d7c4c00..a5f64e3 100644 --- a/src/pmnet/network/builder.py +++ b/src/pmnet/network/builder.py @@ -1,5 +1,4 @@ from torch import nn -from typing import Dict from .utils.registry import Registry @@ -13,7 +12,7 @@ MODEL = Registry("Model") -def build_model(config: Dict) -> nn.Module: +def build_model(config: dict) -> nn.Module: registry_key = "registry" module_key = "name" return Registry.build_from_config(config, registry_key, module_key, convert_key_to_lower_case=True, safe_build=True) diff --git a/src/pmnet/network/cavity_head.py b/src/pmnet/network/cavity_head.py index 4321a1b..a5c7bb4 100644 --- a/src/pmnet/network/cavity_head.py +++ b/src/pmnet/network/cavity_head.py @@ -2,7 +2,6 @@ from functools import partial -from typing import Tuple, Type, Optional from torch import Tensor from .builder import HEAD @@ -15,17 +14,29 @@ def __init__( self, feature_dim: int, hidden_dim: int, - norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm3d, - act_layer: Optional[Type[nn.Module]] = partial(nn.ReLU, inplace=True), + norm_layer: type[nn.Module] | None = nn.BatchNorm3d, + act_layer: type[nn.Module] | None = partial(nn.ReLU, inplace=True), ): - super(CavityHead, self).__init__() + super().__init__() self.short_head = nn.Sequential( - BaseConv3d(feature_dim, hidden_dim, kernel_size=3, norm_layer=norm_layer, act_layer=act_layer), + BaseConv3d( + feature_dim, + hidden_dim, + kernel_size=3, + norm_layer=norm_layer, + act_layer=act_layer, + ), BaseConv3d(hidden_dim, 1, kernel_size=1, norm_layer=None, act_layer=None), ) self.long_head = nn.Sequential( - BaseConv3d(feature_dim, hidden_dim, kernel_size=3, norm_layer=norm_layer, act_layer=act_layer), + BaseConv3d( + feature_dim, + hidden_dim, + kernel_size=3, + norm_layer=norm_layer, + act_layer=act_layer, + ), BaseConv3d(hidden_dim, 1, kernel_size=1, norm_layer=None, act_layer=None), ) @@ -38,7 +49,7 @@ def initialize_weights(self): def forward( self, features: Tensor, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Pocket Extraction Function Args: diff --git a/src/pmnet/network/decoders/fpn_decoder.py b/src/pmnet/network/decoders/fpn_decoder.py index 529e660..159b613 100644 --- a/src/pmnet/network/decoders/fpn_decoder.py +++ b/src/pmnet/network/decoders/fpn_decoder.py @@ -1,7 +1,7 @@ import torch.nn as nn import torch.nn.functional as F -from typing import Sequence, Optional, 
Type, List +from collections.abc import Sequence from torch import Tensor from ..nn.layers import BaseConv3d @@ -20,15 +20,15 @@ def __init__( feature_channels: Sequence[int], num_convs: Sequence[int], channels: int = 64, - interpolate_mode: str = 'nearest', + interpolate_mode: str = "nearest", align_corners: bool = False, - norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm3d, - act_layer: Optional[Type[nn.Module]] = nn.ReLU, + norm_layer: type[nn.Module] | None = nn.BatchNorm3d, + act_layer: type[nn.Module] | None = nn.ReLU, ): - super(FPNDecoder, self).__init__() + super().__init__() self.feature_channels = feature_channels self.interpolate_mode = interpolate_mode - if interpolate_mode == 'trilinear': + if interpolate_mode == "trilinear": self.align_corners = align_corners else: self.align_corners = None @@ -36,23 +36,45 @@ def __init__( lateral_conv_list = [] fpn_convs_list = [] - for level, (channels, num_conv) in enumerate(zip(self.feature_channels, num_convs)): - if level == (len(self.feature_channels) - 1): # Lowest-Resolution Channels (Top) + for level, (channels, num_conv) in enumerate( + zip(self.feature_channels, num_convs, strict=False) + ): + if level == ( + len(self.feature_channels) - 1 + ): # Lowest-Resolution Channels (Top) lateral_conv = nn.Identity() - fpn_convs = nn.Sequential(*[ - BaseConv3d( - channels if i == 0 else self.channels, self.channels, - kernel_size=3, norm_layer=norm_layer, act_layer=act_layer, - ) for i in range(num_conv) - ]) + fpn_convs = nn.Sequential( + *[ + BaseConv3d( + channels if i == 0 else self.channels, + self.channels, + kernel_size=3, + norm_layer=norm_layer, + act_layer=act_layer, + ) + for i in range(num_conv) + ] + ) else: - lateral_conv = BaseConv3d(channels, self.channels, kernel_size=1, norm_layer=norm_layer, act_layer=act_layer) - fpn_convs = nn.Sequential(*[ - BaseConv3d( - self.channels, self.channels, - kernel_size=3, norm_layer=norm_layer, act_layer=act_layer, - ) for _ in range(num_conv) - ]) + lateral_conv = BaseConv3d( + channels, + self.channels, + kernel_size=1, + norm_layer=norm_layer, + act_layer=act_layer, + ) + fpn_convs = nn.Sequential( + *[ + BaseConv3d( + self.channels, + self.channels, + kernel_size=3, + norm_layer=norm_layer, + act_layer=act_layer, + ) + for _ in range(num_conv) + ] + ) lateral_conv_list.append(lateral_conv) fpn_convs_list.append(fpn_convs) @@ -67,7 +89,7 @@ def initialize_weights(self): for m in seqm.children(): m.initialize_weights() - def forward(self, features: Sequence[Tensor]) -> List[Tensor]: + def forward(self, features: Sequence[Tensor]) -> list[Tensor]: """Forward function. 
Args: features: Bottom-Up, [Highest-Resolution Feature Map, ..., Lowest-Resolution Feature Map] @@ -83,12 +105,17 @@ def forward(self, features: Sequence[Tensor]) -> List[Tensor]: lateral_conv = self.lateral_conv_list[level] fpn_convs = self.fpn_convs_list[level] current_fpn = lateral_conv(feature) - if level == (num_levels - 1): # Top + if level == (num_levels - 1): # Top assert fpn is None fpn = current_fpn else: assert fpn is not None - fpn = current_fpn + F.interpolate(fpn, size=current_fpn.size()[-3:], mode=self.interpolate_mode, align_corners=self.align_corners) + fpn = current_fpn + F.interpolate( + fpn, + size=current_fpn.size()[-3:], + mode=self.interpolate_mode, + align_corners=self.align_corners, + ) fpn = fpn_convs(fpn) multi_scale_features.append(fpn) return multi_scale_features diff --git a/src/pmnet/network/detector.py b/src/pmnet/network/detector.py index 1ee6ffb..351b9cf 100644 --- a/src/pmnet/network/detector.py +++ b/src/pmnet/network/detector.py @@ -1,6 +1,6 @@ import torch.nn as nn -from typing import Sequence, List, Optional, Tuple, Optional +from collections.abc import Sequence from torch import Tensor, IntTensor from .feature_embedding import FeaturePyramidNetwork @@ -20,7 +20,7 @@ def __init__( mask_head: MaskHead, num_interactions: int, ): - super(PharmacoFormer, self).__init__() + super().__init__() self.num_interactions = num_interactions self.embedding = embedding self.cavity_head = cavity_head @@ -35,7 +35,7 @@ def initialize_weights(self): def setup_train(self, criterion: nn.Module): self.criterion = criterion - def forward_feature(self, in_image: Tensor) -> Tuple[Tensor, ...]: + def forward_feature(self, in_image: Tensor) -> tuple[Tensor, ...]: """Feature Embedding Args: in_image: FloatTensor [N, C, Din, Hin, Win] @@ -44,7 +44,7 @@ def forward_feature(self, in_image: Tensor) -> Tuple[Tensor, ...]: """ return tuple(self.embedding.forward(in_image)) - def forward_cavity_extraction(self, features: Tensor) -> Tuple[Tensor, Tensor]: + def forward_cavity_extraction(self, features: Tensor) -> tuple[Tensor, Tensor]: """Cavity Extraction Args: features: FloatTensor [N, F, Dout, Hout, Wout] @@ -58,7 +58,7 @@ def forward_token_prediction( self, features: Tensor, tokens_list: Sequence[IntTensor], - ) -> Tuple[List[Tensor], List[Tensor]]: + ) -> tuple[list[Tensor], list[Tensor]]: """token Selection Network Args: @@ -69,16 +69,18 @@ def forward_token_prediction( token_scores_list: List[FloatTensor [Ntoken,] $\\in$ [0, 1]] token_features_list: List[FloatTensor [Ntoken, F]] """ - token_scores_list, token_features_list = self.token_head.forward(features, tokens_list) + token_scores_list, token_features_list = self.token_head.forward( + features, tokens_list + ) return token_scores_list, token_features_list def forward_segmentation( self, - multi_scale_features: Tuple[Tensor, ...], + multi_scale_features: tuple[Tensor, ...], box_tokens_list: Sequence[IntTensor], box_token_features_list: Sequence[Tensor], return_aux: bool = False, - ) -> Tuple[List[Tensor], Optional[List[List[Tensor]]]]: + ) -> tuple[list[Tensor], list[list[Tensor]] | None]: """Mask Prediction Args: @@ -90,4 +92,6 @@ def forward_segmentation( box_masks_list: List[FloatTensor [Nbox, D, H, W]] aux_box_masks_list: List[List[FloatTensor [Nbox, D_scale, H_scale, W_scale]]] """ - return self.mask_head.forward(multi_scale_features, box_tokens_list, box_token_features_list, return_aux) + return self.mask_head.forward( + multi_scale_features, box_tokens_list, box_token_features_list, return_aux + ) diff --git 
a/src/pmnet/network/feature_embedding.py b/src/pmnet/network/feature_embedding.py index b7280ab..95670e7 100644 --- a/src/pmnet/network/feature_embedding.py +++ b/src/pmnet/network/feature_embedding.py @@ -1,6 +1,6 @@ from torch import nn -from typing import Sequence, Optional +from collections.abc import Sequence from torch import Tensor from .builder import EMBEDDING @@ -12,11 +12,11 @@ def __init__( self, backbone: nn.Module, decoder: nn.Module, - neck: Optional[nn.Module] = None, - feature_indices: Optional[Sequence[int]] = None, + neck: nn.Module | None = None, + feature_indices: Sequence[int] | None = None, set_input_to_bottom: bool = True, ): - super(FeaturePyramidNetwork, self).__init__() + super().__init__() self.backbone = backbone self.decoder = decoder self.feature_indices = feature_indices @@ -43,7 +43,9 @@ def forward(self, in_image: Tensor) -> Sequence[Tensor]: """ bottom_up_features: Sequence[Tensor] = self.backbone(in_image) if self.feature_indices is not None: - bottom_up_features = [bottom_up_features[index] for index in self.feature_indices] + bottom_up_features = [ + bottom_up_features[index] for index in self.feature_indices + ] if self.input_is_bottom: bottom_up_features = [in_image, *bottom_up_features] if self.with_neck: diff --git a/src/pmnet/network/mask_head.py b/src/pmnet/network/mask_head.py index d710de6..9cc973d 100644 --- a/src/pmnet/network/mask_head.py +++ b/src/pmnet/network/mask_head.py @@ -1,7 +1,7 @@ import torch from torch import nn -from typing import Sequence, List, Optional, Tuple +from collections.abc import Sequence from torch import Tensor from .builder import HEAD @@ -14,23 +14,30 @@ def __init__( decoder: nn.Module, token_feature_dim: int, ): - super(MaskHead, self).__init__() - feature_channels_list: List[int] = decoder.feature_channels + super().__init__() + feature_channels_list: list[int] = decoder.feature_channels self.point_mlp_list = nn.ModuleList( - [nn.Linear(token_feature_dim, channels) for channels in feature_channels_list] + [ + nn.Linear(token_feature_dim, channels) + for channels in feature_channels_list + ] ) self.background_mlp_list = nn.ModuleList( - [nn.Linear(token_feature_dim, channels) for channels in feature_channels_list] + [ + nn.Linear(token_feature_dim, channels) + for channels in feature_channels_list + ] ) self.decoder = decoder self.conv_logits = nn.Conv3d(decoder.channels, 1, kernel_size=1) def initialize_weights(self): def _init_weight(m): - if isinstance(m, (nn.Linear, nn.Conv3d)): + if isinstance(m, nn.Linear | nn.Conv3d): nn.init.normal_(m.weight, std=0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) + for m in [self.conv_logits, self.point_mlp_list, self.background_mlp_list]: m.apply(_init_weight) @@ -42,7 +49,7 @@ def forward( tokens_list: Sequence[Tensor], token_features_list: Sequence[Tensor], return_aux: bool = False, - ) -> Tuple[List[Tensor], Optional[List[List[Tensor]]]]: + ) -> tuple[list[Tensor], list[list[Tensor]] | None]: """Box Predicting Function Args: @@ -56,13 +63,16 @@ def forward( """ num_images = len(tokens_list) assert len(multi_scale_features[0]) == num_images - multi_scale_features = multi_scale_features[::-1] # Top-Down -> Bottom-Up + multi_scale_features = multi_scale_features[::-1] # Top-Down -> Bottom-Up if return_aux: out_masks_list = [] aux_masks_list = [] for image_idx in range(num_images): out = self.do_predict_w_aux( - [multi_scale_features[level][image_idx] for level in range(len(multi_scale_features))], + [ + multi_scale_features[level][image_idx] + for level in 
range(len(multi_scale_features)) + ], tokens_list[image_idx], token_features_list[image_idx], ) @@ -72,10 +82,14 @@ def forward( aux_masks_list = None out_masks_list = [ self.do_predict_single( - [multi_scale_features[level][image_idx] for level in range(len(multi_scale_features))], + [ + multi_scale_features[level][image_idx] + for level in range(len(multi_scale_features)) + ], tokens_list[image_idx], token_features_list[image_idx], - ) for image_idx in range(num_images) + ) + for image_idx in range(num_images) ] return out_masks_list, aux_masks_list @@ -84,7 +98,7 @@ def do_predict_w_aux( multi_scale_features: Sequence[Tensor], tokens: Tensor, token_features: Tensor, - ) -> List[Tensor]: + ) -> list[Tensor]: """Box Predicting Function Args: @@ -99,24 +113,37 @@ def do_predict_w_aux( multi_scale_size = [features.size()[1:] for features in multi_scale_features] if Nbox > 0: Dout, Hout, Wout = multi_scale_size[0] - token_indices = torch.split(tokens, 1, dim=1) # (x_list, y_list, z_list, i_list) + token_indices = torch.split( + tokens, 1, dim=1 + ) # (x_list, y_list, z_list, i_list) xs, ys, zs, _ = token_indices bottom_up_box_features = [] for level in range(len(multi_scale_features)): features = multi_scale_features[level] _, D, H, W = features.shape - _xs = torch.div(xs, Dout // D, rounding_mode='trunc') - _ys = torch.div(ys, Hout // H, rounding_mode='trunc') - _zs = torch.div(zs, Wout // W, rounding_mode='trunc') - box_features = self.get_box_features(features, (_xs, _ys, _zs), token_features, level) + _xs = torch.div(xs, Dout // D, rounding_mode="trunc") + _ys = torch.div(ys, Hout // H, rounding_mode="trunc") + _zs = torch.div(zs, Wout // W, rounding_mode="trunc") + box_features = self.get_box_features( + features, (_xs, _ys, _zs), token_features, level + ) bottom_up_box_features.append(box_features) top_down_features = self.decoder(bottom_up_box_features) - top_down_box_masks = [self.conv_logits(features).squeeze(1) for features in top_down_features] + top_down_box_masks = [ + self.conv_logits(features).squeeze(1) for features in top_down_features + ] return top_down_box_masks else: - return [torch.empty((0, *size), dtype=multi_scale_features[0].dtype, device=tokens.device) for size in multi_scale_size[::-1]] + return [ + torch.empty( + (0, *size), + dtype=multi_scale_features[0].dtype, + device=tokens.device, + ) + for size in multi_scale_size[::-1] + ] def do_predict_single( self, @@ -138,30 +165,38 @@ def do_predict_single( multi_scale_size = [features.size()[1:] for features in multi_scale_features] Dout, Hout, Wout = multi_scale_size[0] if Nbox > 0: - token_indices = torch.split(tokens, 1, dim=1) # (x_list, y_list, z_list, i_list) + token_indices = torch.split( + tokens, 1, dim=1 + ) # (x_list, y_list, z_list, i_list) xs, ys, zs, _ = token_indices bottom_up_box_features = [] for level in range(len(multi_scale_features)): features = multi_scale_features[level] _, D, H, W = features.shape - _xs = torch.div(xs, Dout // D, rounding_mode='trunc') - _ys = torch.div(ys, Hout // H, rounding_mode='trunc') - _zs = torch.div(zs, Wout // W, rounding_mode='trunc') - box_features = self.get_box_features(features, (_xs, _ys, _zs), token_features, level) + _xs = torch.div(xs, Dout // D, rounding_mode="trunc") + _ys = torch.div(ys, Hout // H, rounding_mode="trunc") + _zs = torch.div(zs, Wout // W, rounding_mode="trunc") + box_features = self.get_box_features( + features, (_xs, _ys, _zs), token_features, level + ) bottom_up_box_features.append(box_features) top_down_features = 
self.decoder(bottom_up_box_features) return self.conv_logits(top_down_features[-1]).squeeze(1) else: - return torch.empty((0, Dout, Hout, Wout), dtype=multi_scale_features[0].dtype, device=tokens.device) + return torch.empty( + (0, Dout, Hout, Wout), + dtype=multi_scale_features[0].dtype, + device=tokens.device, + ) def get_box_features( self, features: Tensor, - token_indices: Tuple[Tensor, Tensor, Tensor], + token_indices: tuple[Tensor, Tensor, Tensor], token_features: Tensor, - level: int + level: int, ) -> Tensor: """Extract token features @@ -177,9 +212,13 @@ def get_box_features( xs, ys, zs = token_indices Nbox = token_features.size(0) Nboxs = torch.arange(Nbox, dtype=xs.dtype, device=xs.device) - background_features = self.background_mlp_list[level](token_features) # [Nbox, F] - point_features = self.point_mlp_list[level](token_features) # [Nbox, F] - box_features = background_features.view(Nbox, F, 1, 1, 1).repeat(1, 1, D, H, W) # [Nbox, F, D, H, W] + background_features = self.background_mlp_list[level]( + token_features + ) # [Nbox, F] + point_features = self.point_mlp_list[level](token_features) # [Nbox, F] + box_features = background_features.view(Nbox, F, 1, 1, 1).repeat( + 1, 1, D, H, W + ) # [Nbox, F, D, H, W] box_features[Nboxs, :, xs, ys, zs] += point_features features = features.unsqueeze(0) + box_features return features diff --git a/src/pmnet/network/nn/__init__.py b/src/pmnet/network/nn/__init__.py index 69a388d..f1874f8 100644 --- a/src/pmnet/network/nn/__init__.py +++ b/src/pmnet/network/nn/__init__.py @@ -1 +1 @@ -from .layers import * +from .layers import BaseConv3d diff --git a/src/pmnet/network/nn/layers.py b/src/pmnet/network/nn/layers.py index 5c64f45..cde4541 100644 --- a/src/pmnet/network/nn/layers.py +++ b/src/pmnet/network/nn/layers.py @@ -1,37 +1,50 @@ from torch import nn -from typing import Optional, Type - class BaseConv3d(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = None, - dilation: int = 1, - groups: int = 1, - norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm3d, - act_layer: Optional[Type[nn.Module]] = nn.ReLU, + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = None, + dilation: int = 1, + groups: int = 1, + norm_layer: type[nn.Module] | None = nn.BatchNorm3d, + act_layer: type[nn.Module] | None = nn.ReLU, ): super().__init__() if padding is None: padding = (kernel_size - 1) // 2 - bias = (norm_layer is None) - self._conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) - self._norm = norm_layer(out_channels) if norm_layer is not None else nn.Identity() + bias = norm_layer is None + self._conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self._norm = ( + norm_layer(out_channels) if norm_layer is not None else nn.Identity() + ) self._act = act_layer() if act_layer is not None else nn.Identity() def initialize_weights(self): if isinstance(self._act, nn.LeakyReLU): a = self._act.negative_slope - nn.init.kaiming_normal_(self._conv.weight, a, mode='fan_out', nonlinearity='leaky_relu') + nn.init.kaiming_normal_( + self._conv.weight, a, mode="fan_out", nonlinearity="leaky_relu" + ) else: - nn.init.kaiming_normal_(self._conv.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_( + self._conv.weight, mode="fan_out", nonlinearity="relu" + ) if self._conv.bias is not None: - 
nn.init.constant_(self._conv.bias, 0.) + nn.init.constant_(self._conv.bias, 0.0) if not isinstance(self._norm, nn.Identity): nn.init.constant_(self._norm.weight, 1.0) diff --git a/src/pmnet/network/token_head.py b/src/pmnet/network/token_head.py index dd26148..2459559 100644 --- a/src/pmnet/network/token_head.py +++ b/src/pmnet/network/token_head.py @@ -1,7 +1,7 @@ import torch from torch import nn -from typing import Sequence, List, Tuple +from collections.abc import Sequence from torch import Tensor from .builder import HEAD @@ -17,7 +17,7 @@ def __init__( num_feature_mlp_layers: int, num_score_mlp_layers: int, ): - super(TokenHead, self).__init__() + super().__init__() self.interaction_embedding = nn.Embedding(num_interactions, feature_dim) self.token_feature_dim = token_feature_dim @@ -43,7 +43,7 @@ def __init__( def initialize_weights(self): def _init_weight(m): if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=.01) + nn.init.normal_(m.weight, std=0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) @@ -51,7 +51,9 @@ def _init_weight(m): for m in [self.feature_mlp, self.score_mlp, self.skip]: m.apply(_init_weight) - def forward(self, features: Tensor, tokens_list: Sequence[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + def forward( + self, features: Tensor, tokens_list: Sequence[Tensor] + ) -> tuple[list[Tensor], list[Tensor]]: """Token Scoring Function Args: @@ -63,8 +65,14 @@ def forward(self, features: Tensor, tokens_list: Sequence[Tensor]) -> Tuple[List token_features_list: List[FloatTensor [Ntoken, F]] """ num_images = len(tokens_list) - token_features_list = [self.extract_token_features(features[idx], tokens_list[idx]) for idx in range(num_images)] - token_scores_list = [self.score_mlp(token_features).squeeze(-1) for token_features in token_features_list] + token_features_list = [ + self.extract_token_features(features[idx], tokens_list[idx]) + for idx in range(num_images) + ] + token_scores_list = [ + self.score_mlp(token_features).squeeze(-1) + for token_features in token_features_list + ] return token_scores_list, token_features_list def extract_token_features(self, features: Tensor, tokens: Tensor) -> Tensor: @@ -78,11 +86,19 @@ def extract_token_features(self, features: Tensor, tokens: Tensor) -> Tensor: token_features: FloatTensor [Ntoken, Fh] """ if tokens.size(0) == 0: - return torch.empty([0, self.token_feature_dim], dtype=torch.float, device=features.device) + return torch.empty( + [0, self.token_feature_dim], dtype=torch.float, device=features.device + ) else: - features = features.permute(1, 2, 3, 0).contiguous() # [D, H, W, F] - x_list, y_list, z_list, i_list = torch.split(tokens, 1, dim=1) # (x_list, y_list, z_list, i_list) - token_features = features[x_list, y_list, z_list].squeeze(1) # [Ntoken, F] - embeddings = self.interaction_embedding(i_list).squeeze(1) # [Ntoken, F] - token_features = torch.cat([token_features, embeddings], dim=1) # [Ntoken, 2F] - return self.skip(token_features) + self.feature_mlp(token_features) # [Ntoken, Fh] + features = features.permute(1, 2, 3, 0).contiguous() # [D, H, W, F] + x_list, y_list, z_list, i_list = torch.split( + tokens, 1, dim=1 + ) # (x_list, y_list, z_list, i_list) + token_features = features[x_list, y_list, z_list].squeeze(1) # [Ntoken, F] + embeddings = self.interaction_embedding(i_list).squeeze(1) # [Ntoken, F] + token_features = torch.cat( + [token_features, embeddings], dim=1 + ) # [Ntoken, 2F] + return self.skip(token_features) + self.feature_mlp( + token_features + ) # [Ntoken, Fh] diff --git 
diff --git a/src/pmnet/network/token_head.py b/src/pmnet/network/token_head.py
index dd26148..2459559 100644
--- a/src/pmnet/network/token_head.py
+++ b/src/pmnet/network/token_head.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
-from typing import Sequence, List, Tuple
+from collections.abc import Sequence
 
 from torch import Tensor
 
 from .builder import HEAD
@@ -17,7 +17,7 @@ def __init__(
         num_feature_mlp_layers: int,
         num_score_mlp_layers: int,
     ):
-        super(TokenHead, self).__init__()
+        super().__init__()
         self.interaction_embedding = nn.Embedding(num_interactions, feature_dim)
 
         self.token_feature_dim = token_feature_dim
@@ -43,7 +43,7 @@ def __init__(
 
     def initialize_weights(self):
         def _init_weight(m):
            if isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, std=.01)
+                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
@@ -51,7 +51,9 @@ def _init_weight(m):
         for m in [self.feature_mlp, self.score_mlp, self.skip]:
             m.apply(_init_weight)
 
-    def forward(self, features: Tensor, tokens_list: Sequence[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+    def forward(
+        self, features: Tensor, tokens_list: Sequence[Tensor]
+    ) -> tuple[list[Tensor], list[Tensor]]:
         """Token Scoring Function
 
         Args:
@@ -63,8 +65,14 @@ def forward(self, features: Tensor, tokens_list: Sequence[Tensor]) -> Tuple[List
             token_features_list: List[FloatTensor [Ntoken, F]]
         """
         num_images = len(tokens_list)
-        token_features_list = [self.extract_token_features(features[idx], tokens_list[idx]) for idx in range(num_images)]
-        token_scores_list = [self.score_mlp(token_features).squeeze(-1) for token_features in token_features_list]
+        token_features_list = [
+            self.extract_token_features(features[idx], tokens_list[idx])
+            for idx in range(num_images)
+        ]
+        token_scores_list = [
+            self.score_mlp(token_features).squeeze(-1)
+            for token_features in token_features_list
+        ]
         return token_scores_list, token_features_list
 
     def extract_token_features(self, features: Tensor, tokens: Tensor) -> Tensor:
@@ -78,11 +86,19 @@ def extract_token_features(self, features: Tensor, tokens: Tensor) -> Tensor:
             token_features: FloatTensor [Ntoken, Fh]
         """
         if tokens.size(0) == 0:
-            return torch.empty([0, self.token_feature_dim], dtype=torch.float, device=features.device)
+            return torch.empty(
+                [0, self.token_feature_dim], dtype=torch.float, device=features.device
+            )
         else:
-            features = features.permute(1, 2, 3, 0).contiguous()                # [D, H, W, F]
-            x_list, y_list, z_list, i_list = torch.split(tokens, 1, dim=1)      # (x_list, y_list, z_list, i_list)
-            token_features = features[x_list, y_list, z_list].squeeze(1)        # [Ntoken, F]
-            embeddings = self.interaction_embedding(i_list).squeeze(1)          # [Ntoken, F]
-            token_features = torch.cat([token_features, embeddings], dim=1)     # [Ntoken, 2F]
-            return self.skip(token_features) + self.feature_mlp(token_features) # [Ntoken, Fh]
+            features = features.permute(1, 2, 3, 0).contiguous()  # [D, H, W, F]
+            x_list, y_list, z_list, i_list = torch.split(
+                tokens, 1, dim=1
+            )  # (x_list, y_list, z_list, i_list)
+            token_features = features[x_list, y_list, z_list].squeeze(1)  # [Ntoken, F]
+            embeddings = self.interaction_embedding(i_list).squeeze(1)  # [Ntoken, F]
+            token_features = torch.cat(
+                [token_features, embeddings], dim=1
+            )  # [Ntoken, 2F]
+            return self.skip(token_features) + self.feature_mlp(
+                token_features
+            )  # [Ntoken, Fh]
diff --git a/src/pmnet/network/utils/registry.py b/src/pmnet/network/utils/registry.py
index 13c9a71..814bfc7 100644
--- a/src/pmnet/network/utils/registry.py
+++ b/src/pmnet/network/utils/registry.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, Optional, Callable
+from typing import Any
 from collections.abc import MutableMapping
 import copy
 
 
-def convert_dict_key_to_lower_case(dictionary: MutableMapping) -> Dict:
+def convert_dict_key_to_lower_case(dictionary: MutableMapping) -> dict:
     out = {}
     for key, value in dictionary.items():
         if isinstance(value, MutableMapping):
@@ -12,12 +12,12 @@ def convert_dict_key_to_lower_case(dictionary: MutableMapping) ->
     return out
 
 
-class Registry():
+class Registry:
     __OBJ_DICT__ = {}
 
     def __init__(self, name: str):
         self._name = name
-        self._module_dict: Dict[str, Any] = dict()
+        self._module_dict: dict[str, Any] = dict()
         self.__OBJ_DICT__[name] = self
 
     def __len__(self):
@@ -42,12 +42,10 @@ def get(self, name: str) -> Any:
     def _do_register(self, name: str, obj: Any) -> None:
         assert (
             name not in self._module_dict
-        ), "An object named '{}' was already registered in '{}' registry!".format(
-            name, self._name
-        )
+        ), f"An object named '{name}' was already registered in '{self._name}' registry!"
         self._module_dict[name] = obj
 
-    def register(self, module: Optional[Any] = None) -> Any:
+    def register(self, module: Any | None = None) -> Any:
         if module is None:
             def deco(_module: Any) -> Any:
                 name = _module.__name__
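The `Registry` changes above are modernization only; the register-by-decorator pattern it supports is unchanged. A minimal sketch of that pattern follows. The `MODELS` registry and `TinyHead` class are hypothetical examples, and the final `get` call assumes lookup by the registered class name.

```python
from torch import nn

from pmnet.network.utils.registry import Registry

MODELS = Registry("toy_models")  # hypothetical registry, not part of PharmacoNet


@MODELS.register()  # decorator form: registers the class under TinyHead.__name__
class TinyHead(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, 1)


# Registering a second object named "TinyHead" would trip the f-string assertion above.
head_cls = MODELS.get("TinyHead")  # assumption: get() resolves the registered name
head = head_cls(dim=16)
```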
diff --git a/src/pmnet/scoring/match_utils.py b/src/pmnet/scoring/match_utils.py
index e5ced16..91f7028 100644
--- a/src/pmnet/scoring/match_utils.py
+++ b/src/pmnet/scoring/match_utils.py
@@ -79,7 +79,7 @@ def scoring_matching_pair(
 
     return tuple(
         float(score) if num_fail <= match_threshold else -1
-        for score, num_fail in zip(match_scores, num_fails)
+        for score, num_fail in zip(match_scores, num_fails, strict=False)
     )
 
diff --git a/src/pmnet/utils/density_map.py b/src/pmnet/utils/density_map.py
index 2624c82..e547df2 100644
--- a/src/pmnet/utils/density_map.py
+++ b/src/pmnet/utils/density_map.py
@@ -85,7 +85,7 @@ def __extract_pharmacophores(
             grid_scores: list[float] - [score]
         """
         x_indices, y_indices, z_indices = np.where(mask > 0.0)
-        points = {(int(x), int(y), int(z)) for x, y, z in zip(x_indices, y_indices, z_indices)}
+        points = {(int(x), int(y), int(z)) for x, y, z in zip(x_indices, y_indices, z_indices, strict=False)}
         while len(points) > 0:
             point = x, y, z = points.pop()
             cluster = [point]
diff --git a/src/pmnet/utils/smoothing.py b/src/pmnet/utils/smoothing.py
index b6074a0..3073116 100644
--- a/src/pmnet/utils/smoothing.py
+++ b/src/pmnet/utils/smoothing.py
@@ -20,7 +20,7 @@ def __init__(
         kernel_size: int | tuple[int, int, int],
         sigma: float | tuple[float, float, float],
     ):
-        super(GaussianSmoothing, self).__init__()
+        super().__init__()
 
         kernel_size = to_3tuple(kernel_size)
         sigma = to_3tuple(sigma)
@@ -31,7 +31,7 @@ def __init__(
             indexing="ij",
         )
         kernel: torch.Tensor = None
-        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+        for size, std, mgrid in zip(kernel_size, sigma, meshgrids, strict=False):
             mean = (size - 1) / 2
             # _kernel = 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
             _kernel = torch.exp(
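The `GaussianSmoothing` hunks above only touch formatting and the `zip(..., strict=False)` calls. For orientation, here is a small standalone sketch of how a 3D Gaussian kernel of that general shape can be built from per-axis meshgrids and applied with `conv3d`. It is not PharmacoNet's module, and the unnormalized-then-renormalized kernel is only one common choice.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_3d(kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    # Product of per-axis Gaussians on an isotropic meshgrid, renormalized to sum to 1.
    coords = torch.arange(kernel_size, dtype=torch.float32)
    mean = (kernel_size - 1) / 2
    grids = torch.meshgrid(coords, coords, coords, indexing="ij")
    kernel = torch.ones(kernel_size, kernel_size, kernel_size)
    for mgrid in grids:
        kernel = kernel * torch.exp(-(((mgrid - mean) / sigma) ** 2) / 2)
    return kernel / kernel.sum()


density = torch.rand(1, 1, 32, 32, 32)           # [N, C, D, H, W]
weight = gaussian_kernel_3d(5, 1.0).view(1, 1, 5, 5, 5)
smoothed = F.conv3d(density, weight, padding=2)  # same spatial size for kernel_size 5
```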
diff --git a/src/pmnet_appl/sbddreward/network/pharmacophore_encoder.py b/src/pmnet_appl/sbddreward/network/pharmacophore_encoder.py
index 97daf21..6bb77d3 100644
--- a/src/pmnet_appl/sbddreward/network/pharmacophore_encoder.py
+++ b/src/pmnet_appl/sbddreward/network/pharmacophore_encoder.py
@@ -39,7 +39,7 @@ def forward(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> tu
             hotspot_positions = torch.zeros((0, 3), device=dev)
             hotspot_features = torch.zeros((0, self.hidden_dim), device=dev)
         pocket_features: Tensor = torch.cat(
-            [mlp(feat.squeeze(0)).mean((-1, -2, -3)) for mlp, feat in zip(self.pocket_mlp_list, multi_scale_features)],
+            [mlp(feat.squeeze(0)).mean((-1, -2, -3)) for mlp, feat in zip(self.pocket_mlp_list, multi_scale_features, strict=False)],
             dim=-1,
         )
         pocket_features = self.pocket_layer(pocket_features)
diff --git a/src/pmnet_appl/sbddreward/proxy.py b/src/pmnet_appl/sbddreward/proxy.py
index 7b8e2ae..c03a8c6 100644
--- a/src/pmnet_appl/sbddreward/proxy.py
+++ b/src/pmnet_appl/sbddreward/proxy.py
@@ -22,7 +22,11 @@
 from pmnet.api.typing import HotspotInfo, MultiScaleFeature
 
 from pmnet_appl.base.proxy import BaseProxy
-from pmnet_appl.sbddreward.network import PharmacophoreEncoder, GraphEncoder, AffinityHead
+from pmnet_appl.sbddreward.network import (
+    PharmacophoreEncoder,
+    GraphEncoder,
+    AffinityHead,
+)
 from pmnet_appl.sbddreward.data import NUM_ATOM_FEATURES, NUM_BOND_FEATURES, smi2graph
 
@@ -35,30 +39,44 @@ class SBDDReward_Proxy(BaseProxy):
     def _setup_model(self):
         self.model = _RewardNetwork()
 
-    def _get_cache(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> Cache:
+    def _get_cache(
+        self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]
+    ) -> Cache:
         return self.model.get_cache(pmnet_attr)
 
     @torch.no_grad()
-    def _scoring_list(self, cache: Cache, smiles_list: list[str], return_sigma: bool = False) -> Tensor:
-        cache = (cache[0].to(self.device), cache[1].to(self.device), cache[2].to(self.device), cache[3], cache[4])
+    def _scoring_list(
+        self, cache: Cache, smiles_list: list[str], return_sigma: bool = False
+    ) -> Tensor:
+        cache = (
+            cache[0].to(self.device),
+            cache[1].to(self.device),
+            cache[2].to(self.device),
+            cache[3],
+            cache[4],
+        )
         ligand_graphs = []
         flag = []
         for smi in smiles_list:
             try:
                 graph = smi2graph(smi)
-            except:
+            except Exception:
                 flag.append(False)
             else:
                 flag.append(True)
                 ligand_graphs.append(graph)
         if not any(flag):
-            return torch.zeros(len(smiles_list), dtype=torch.float32, device=self.device)
+            return torch.zeros(
+                len(smiles_list), dtype=torch.float32, device=self.device
+            )
         ligand_batch: gd.Batch = gd.Batch.from_data_list(ligand_graphs).to(self.device)
         if all(flag):
             return self.model.scoring(cache, ligand_batch, return_sigma)
         else:
-            result = torch.zeros(len(smiles_list), dtype=torch.float32, device=self.device)
+            result = torch.zeros(
+                len(smiles_list), dtype=torch.float32, device=self.device
+            )
             result[flag] = self.model.scoring(cache, ligand_batch, return_sigma)
             return result
@@ -110,7 +128,9 @@ def scoring(self, target: str, smiles: str, return_sigma: bool = False) -> Tenso
         """
         return self._scoring_list(self._cache[target], [smiles], return_sigma)
 
-    def scoring_list(self, target: str, smiles_list: list[str], return_sigma: bool = False) -> Tensor:
+    def scoring_list(
+        self, target: str, smiles_list: list[str], return_sigma: bool = False
+    ) -> Tensor:
         """Scoring multiple molecules with their SMILES
 
         Parameters
@@ -139,27 +159,43 @@ class _RewardNetwork(nn.Module):
     def __init__(self):
         super().__init__()
         self.pharmacophore_encoder: PharmacophoreEncoder = PharmacophoreEncoder(128)
-        self.ligand_encoder: GraphEncoder = GraphEncoder(NUM_ATOM_FEATURES, NUM_BOND_FEATURES, 128, 128, 4)
+        self.ligand_encoder: GraphEncoder = GraphEncoder(
+            NUM_ATOM_FEATURES, NUM_BOND_FEATURES, 128, 128, 4
+        )
         self.head: AffinityHead = AffinityHead(128, 3)
 
     def get_cache(self, pmnet_attr) -> Cache:
-        X_protein, pos_protein, Z_protein = self.pharmacophore_encoder.forward(pmnet_attr)
+        X_protein, pos_protein, Z_protein = self.pharmacophore_encoder.forward(
+            pmnet_attr
+        )
         mu, std = self.head.cal_mu(Z_protein), self.head.cal_std(Z_protein)
-        return X_protein.cpu(), pos_protein.cpu(), Z_protein.cpu(), mu.item(), std.item()
+        return (
+            X_protein.cpu(),
+            pos_protein.cpu(),
+            Z_protein.cpu(),
+            mu.item(),
+            std.item(),
+        )
 
     def scoring(self, cache: Cache, ligand_batch: gd.Batch, return_sigma: bool = False):
         X_protein, pos_protein, Z_protein, mu, std = cache
         X_ligand, Z_ligand = self.ligand_encoder.forward(ligand_batch)
-        sigma = self.head.cal_sigma(X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch)
+        sigma = self.head.cal_sigma(
+            X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch
+        )
         if return_sigma:
             return sigma
         else:
             return sigma * std + mu
 
-    def get_info(self, cache: Cache, ligand_batch: gd.Batch) -> tuple[float, float, Tensor]:
+    def get_info(
+        self, cache: Cache, ligand_batch: gd.Batch
+    ) -> tuple[float, float, Tensor]:
         X_protein, pos_protein, Z_protein, mu, std = cache
         X_ligand, Z_ligand = self.ligand_encoder.forward(ligand_batch)
-        sigma = self.head.cal_sigma(X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch)
+        sigma = self.head.cal_sigma(
+            X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch
+        )
         return mu, std, sigma
diff --git a/src/pmnet_appl/tacogfn_reward/proxy.py b/src/pmnet_appl/tacogfn_reward/proxy.py
index 958484a..63b39c0 100644
--- a/src/pmnet_appl/tacogfn_reward/proxy.py
+++ b/src/pmnet_appl/tacogfn_reward/proxy.py
@@ -201,7 +201,7 @@ def ready_to_calculate(
             token_features_list: List[FloatTensor [Nbox, F_hidden]]
         """
         multi_scale_features = multi_scale_features[::-1]  # Top-Down -> Bottom-Up
-        multi_scale_features = [layer(feature) for layer, feature in zip(self.pocket_mlp_list, multi_scale_features)]
+        multi_scale_features = [layer(feature) for layer, feature in zip(self.pocket_mlp_list, multi_scale_features, strict=False)]
         pocket_features: Tensor = self.pocket_mlp(
             torch.cat([feature.mean(dim=(-1, -2, -3)) for feature in multi_scale_features], dim=-1)
         )  # [N, Fh]
diff --git a/utils/parse_rcsb_pdb.py b/utils/parse_rcsb_pdb.py
index a18c720..05fa28c 100644
--- a/utils/parse_rcsb_pdb.py
+++ b/utils/parse_rcsb_pdb.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from openbabel import pybel
 import pymol
 import numpy as np
@@ -5,9 +6,7 @@
 from dataclasses import dataclass
 from urllib.request import urlopen
 
-
-from os import PathLike
-from typing import List, Tuple, Optional
+PathLike = str | Path
 
 
 @dataclass
@@ -17,10 +16,10 @@ class LigandInform:
     pdbchain: str
     authchain: str
     residx: int
-    center: Tuple[float, float, float]
-    file_path: PathLike[str]
-    name: Optional[str]
-    synonyms: Optional[str]
+    center: tuple[float, float, float]
+    file_path: PathLike
+    name: str | None
+    synonyms: str | None
 
     def __str__(self) -> str:
         x, y, z = self.center
@@ -36,7 +35,7 @@ def __str__(self) -> str:
         return string
 
 
-def download_pdb(pdb_code: str, output_file: PathLike[str]):
+def download_pdb(pdb_code: str, output_file: PathLike):
     url = f"https://files.rcsb.org/download/{pdb_code.lower()}.pdb"
     try:
         with urlopen(url) as response:
@@ -47,13 +46,15 @@ def download_pdb(pdb_code: str, output_file: PathLike[str]):
         print(f"Error downloading PDB file: {e}")
 
 
-def parse_pdb(pdb_code: str, protein_path: PathLike[str], save_dir: PathLike[str]) -> List[LigandInform]:
+def parse_pdb(
+    pdb_code: str, protein_path: PathLike, save_dir: PathLike
+) -> list[LigandInform]:
     protein: pybel.Molecule = next(pybel.readfile("pdb", str(protein_path)))
     if "HET" not in protein.data.keys():
         return []
-    het_lines = protein.data["HET"].split("\n")
-    hetnam_lines = protein.data["HETNAM"].split("\n")
+    het_lines: list[str] = protein.data["HET"].split("\n")
+    hetnam_lines: list[str] = protein.data["HETNAM"].split("\n")
     if "HETSYN" in protein.data.keys():
         hetsyn_lines = protein.data["HETSYN"].split("\n")
     else:
@@ -63,7 +64,7 @@ def parse_pdb(pdb_code: str, protein_path: PathLike[str], save_dir: PathLike[str
 
     ligand_name_dict = {}
     for line in hetnam_lines:
-        line: str = line.strip()
+        line = line.strip()
         if line.startswith(het_id_list):
             key, *strings = line.split()
             assert key not in ligand_name_dict
@@ -105,7 +106,11 @@ def parse_pdb(pdb_code: str, protein_path: PathLike[str], save_dir: PathLike[str
             ligid,
             authchain,
             residue_idx,
-        ) = vs[0], vs[1][0], vs[1][1:]
+        ) = (
+            vs[0],
+            vs[1][0],
+            vs[1][1:],
+        )
         pdbchain = chr(ord(last_chain) + idx + 1)
         identify_key = f"{pdb_code}_{pdbchain}_{ligid}"
         ligand_path = os.path.join(save_dir, f"{identify_key}.pdb")
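With `PathLike` now a plain `str | Path` alias, the helpers above accept either type directly. A usage sketch, assuming the repository root is on `sys.path`; the `6oim` code and output directory are only examples.

```python
from pathlib import Path

from utils.parse_rcsb_pdb import download_pdb, parse_pdb

save_dir = Path("./result/6oim")
save_dir.mkdir(parents=True, exist_ok=True)

protein_path = save_dir / "6oim.pdb"
download_pdb("6oim", protein_path)

# parse_pdb returns a list of LigandInform dataclasses; __str__ gives a readable summary.
for ligand in parse_pdb("6oim", protein_path, save_dir):
    print(ligand)
    print(ligand.center, ligand.file_path)
```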
diff --git a/utils/visualize.py b/utils/visualize.py
index e877a43..1f52848 100644
--- a/utils/visualize.py
+++ b/utils/visualize.py
@@ -1,6 +1,5 @@
 import os
 import argparse
-import logging
 import tempfile
 import pymol
 
@@ -9,7 +8,6 @@
 
 from pmnet import PharmacophoreModel
 
-from typing import Optional, Dict, Tuple
 
 
 class Visualize_ArgParser(argparse.ArgumentParser):
@@ -49,8 +47,8 @@ def __init__(self):
 
 def visualize_single(
     model: PharmacophoreModel,
-    protein_path: Optional[str],
-    ligand_path: Optional[str],
+    protein_path: str | None,
+    ligand_path: str | None,
     prefix: str,
     save_path: str,
 ):
@@ -108,7 +106,7 @@ def visualize_single(
     cmd.color('gray90', f'{prefix}Protein and (name C*)')
 
     cmd.set('sphere_scale', 0.3, '*hotspot*')
-    cmd.set('sphere_transparency', 0.2, f'*point*')
+    cmd.set('sphere_transparency', 0.2, '*point*')
     cmd.set('dash_gap', 0.2, '*interaction*')
     cmd.set('dash_length', 0.4, '*interaction*')
     cmd.hide('label', '*interaction*')
@@ -124,7 +122,7 @@ def visualize_single(
 
 
 def visualize_multiple(
-    model_dict: Dict[str, Tuple[PharmacophoreModel, str]],
+    model_dict: dict[str, tuple[PharmacophoreModel, str]],
     protein_path: str,
     pdb: str,
     save_path: str,