diff --git a/fiora/GNN/GNNModules.py b/fiora/GNN/GNNModules.py index de35b02..2f22181 100644 --- a/fiora/GNN/GNNModules.py +++ b/fiora/GNN/GNNModules.py @@ -265,7 +265,7 @@ def load_from_state_dict(cls, PATH: str) -> 'GNNCompiler': with open(PARAMS_PATH, 'r') as fp: params = json.load(fp) model = GNNCompiler(params) - model.load_state_dict(torch.load(STATE_PATH)) + model.load_state_dict(torch.load(STATE_PATH, map_location=torch.serialization.default_restore_location, weights_only=True)) if not isinstance(model, cls): raise ValueError(f'file {PATH} contains incorrect model class {type(model)}') @@ -291,3 +291,22 @@ def save(self, PATH: str, dev: str="cpu") -> None: #Reset to previous device self.to(prev_device) + + # def save(self, PATH: str, dev: str="cpu") -> None: + + # # Set device to cpu for saving + # prev_device = next(self.parameters()).device + # self.to(dev) + # with open(PATH, 'wb') as f: + # dill.dump(self.to(dev), f) + + # params_path = '.'.join(PATH.split('.')[:-1]) + '_params.json' + # with open(params_path, 'w') as fp: + # json.dump(self.model_params, fp) + + # state_dict_path = params_path.replace("_params.json", "_state.pt") + # self.to(dev) + + # torch.save(self.state_dict(), state_dict_path, _use_new_zipfile_serialization=False) + + # self.to(prev_device) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/notebooks/test_model.ipynb b/notebooks/test_model.ipynb index ced1d25..43d0046 100644 --- a/notebooks/test_model.ipynb +++ b/notebooks/test_model.ipynb @@ -9,9 +9,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "[10:21:18] WARNING: not removing hydrogen atom without neighbors\n", - "[10:21:18] WARNING: not removing hydrogen atom without neighbors\n", - "[10:21:18] WARNING: not removing hydrogen atom without neighbors\n" + "[11:46:34] WARNING: not removing hydrogen atom without neighbors\n", + "[11:46:34] WARNING: not removing hydrogen atom without neighbors\n", + "[11:46:34] WARNING: not removing hydrogen atom without neighbors\n" ] }, { @@ -68,7 +68,16 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ynowatzk/miniforge3/envs/fiora/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n" + ] + } + ], "source": [ "depth = 6\n", "\n", @@ -85,16 +94,12 @@ "from fiora.GNN.GNNModules import GNNCompiler\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "try: \n", - " model = GNNCompiler.load(MODEL_PATH)\n", + "\n", + "try:\n", + " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", "except:\n", - " try:\n", - " print(f\"Warning: Failed loading the model {MODEL_PATH}. Fall back: Loading the model from state dictionary.\")\n", - " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", - " print(\"Model loaded from state dict without further errors.\")\n", - " except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")\n", - " " + " raise NameError(\"Error: Failed loading from state dict.\")\n", + " " ] }, { @@ -289,13 +294,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_937446/3302879299.py:24: SettingWithCopyWarning: \n", + "/tmp/ipykernel_3323114/3302879299.py:24: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " df_train[\"Metabolite\"] = df_train[\"SMILES\"].apply(Metabolite) # TRAIN Metabolites are only tracked for tanimoto distance\n", - "/tmp/ipykernel_937446/3302879299.py:25: SettingWithCopyWarning: \n", + "/tmp/ipykernel_3323114/3302879299.py:25: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", @@ -655,7 +660,13 @@ "output_type": "stream", "text": [ "/home/ynowatzk/repos/fiora/fiora/MS/SimulationFramework.py:170: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n", - " df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in [\"cosine_similarity\", \"kl_div\", \"sim_peaks\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_sqrt_bias_wo_prec\", \"spectral_refl_bias\", \"steins_cosine\", \"steins_bias\", \"RT_pred\", \"RT_dif\", \"CCS_pred\"]])])\n", + " df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in [\"cosine_similarity\", \"kl_div\", \"sim_peaks\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_sqrt_bias_wo_prec\", \"spectral_refl_bias\", \"steins_cosine\", \"steins_bias\", \"RT_pred\", \"RT_dif\", \"CCS_pred\"]])])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ "/home/ynowatzk/repos/fiora/fiora/MS/spectral_scores.py:26: RuntimeWarning: divide by zero encountered in divide\n", " vec = vec / np.linalg.norm(vec)\n", "/home/ynowatzk/repos/fiora/fiora/MS/spectral_scores.py:27: RuntimeWarning: divide by zero encountered in divide\n", @@ -1650,7 +1661,7 @@ " 0.717736\n", " 0.780847\n", " 0.273686\n", - " 0.307230\n", + " 0.307231\n", " \n", " \n", " 1\n", @@ -1687,7 +1698,7 @@ "2 ICEBERG 0.716473 NaN 0.575536 NaN 0.709359 NaN \n", "\n", " CASMI22+ CASMI22- \n", - "0 0.273686 0.307230 \n", + "0 0.273686 0.307231 \n", "1 0.376012 0.292884 \n", "2 0.358571 0.000000 " ] @@ -2011,17 +2022,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "cos: 0.7399321531311543 0.5158647138450783\n", - "cos: 0.7399321531311543 0.3008378811624405\n", - "cos: 0.7399321531311543 0.3014740423593587\n", - "cos: 0.5824634669347201 0.5158647138450783\n", - "cos: 0.5824634669347201 0.3008378811624405\n", - "cos: 0.5824634669347201 0.3014740423593587\n", - "cos: 0.5824634893228796 0.5158647138450783\n", - "cos: 0.5824634893228796 0.3008378811624405\n", - "cos: 0.5824634893228796 0.3014740423593587\n", - "cos: 0.5780898248582669 0.4830928581051953\n", - "cos: 0.49608554398224664 0.4830928581051953\n" + "cos: 0.7399321707609304 0.5158647125171295\n", + "cos: 0.7399321707609304 0.3008378766012895\n", + "cos: 0.7399321707609304 0.30147404234591146\n", + "cos: 0.5824634973655437 0.5158647125171295\n", + "cos: 0.5824634973655437 0.3008378766012895\n", + "cos: 0.5824634973655437 0.30147404234591146\n", + "cos: 0.5824634973655437 0.5158647125171295\n", + "cos: 0.5824634973655437 0.3008378766012895\n", + "cos: 0.5824634973655437 0.30147404234591146\n", + "cos: 0.5780898227574808 0.48309285488696774\n", + "cos: 0.4960855687939471 0.48309285488696774\n" ] } ], diff --git a/scripts/fiora-predict b/scripts/fiora-predict index e32b990..cbf18e1 100644 --- a/scripts/fiora-predict +++ b/scripts/fiora-predict @@ -1,10 +1,13 @@ #! /usr/bin/env python import pandas as pd import os +import warnings from rdkit import RDLogger RDLogger.DisableLog("rdApp.*") +warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") import argparse +import importlib.resources import fiora.IO.mgfWriter as mgfWriter import fiora.IO.mspWriter as mspWriter @@ -15,7 +18,6 @@ from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder - def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(prog='fiora-predict', description='Fiora is an in silico fragmentation framework, which predicts peaks and simulates tandem mass spectra including features such as retention time and collision cross sections. Use this script for spectrum predictions with a (pre-)trained model.', @@ -28,7 +30,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--rt', action=argparse.BooleanOptionalAction, help="predict retention time", default=False) - parser.add_argument('--ccs', action=argparse.BooleanOptionalAction, help="predict collison cross section", default=True) + parser.add_argument('--ccs', action=argparse.BooleanOptionalAction, help="predict collison cross section", default=False) parser.add_argument("--annotation", action=argparse.BooleanOptionalAction, help="annotate predicted peaks with SMILES strings", default=False) args = parser.parse_args() @@ -102,19 +104,17 @@ def main(): print(f'Running fiora prediction with the following parameters: {args}\n') # Load model - if args.model == "default": - args.model = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../models/fiora_OS_v0.1.0.pt") + if args.model == "default": + with importlib.resources.path('models', 'fiora_OS_v0.1.0.pt') as model_path: + args.model = str(model_path) + #os.path.join(os.path.dirname(os.path.abspath(__file__)), "../models/fiora_OS_v0.1.0.pt") - try: - model = GNNCompiler.load(args.model) - except: - try: - print(f"Warning: Failed loading the model {args.model}. Fallback: Loading the model from state dictionary.") - model = GNNCompiler.load_from_state_dict(args.model) - print("Model loaded from state dict without further errors.") - except Exception as e: - print(f"Error: Failed loading from state dict. Caused by: {e}.") - exit(1) + print(args.model) + try: + model = GNNCompiler.load_from_state_dict(args.model) + except Exception as e: + print(f"Error: Failed loading from model from state dict. Caused by: {e}.") + exit(1) print_model_messages(model.model_params) args = update_args_with_model_params(args, model.model_params) diff --git a/setup.py b/setup.py index 58f2e63..167b9ed 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -setup(name='Fiora', +setup(name='fiora', version='0.0.1', long_description='file: README.md', author='Yannek Nowatzky', @@ -10,14 +10,13 @@ 'Operating System :: OS Independent', 'Programming Language :: Python :: 3', ], - # entry_points={ - # "console_scripts": [ - # "fiora-predict = scripts.predict:main", - # ], - # }, + scripts=["scripts/fiora-predict"], + packages=['fiora', 'fiora.GNN', 'fiora.IO', 'fiora.MOL', 'fiora.MS', 'fiora.visualization', 'models'], include_package_data=True, - packages=['fiora', 'fiora.GNN', 'fiora.IO', 'fiora.MOL', 'fiora.MS', 'fiora.visualization'], + package_data={ + 'models': ['fiora_OS_v0.1.0.pt', 'fiora_OS_v0.1.0_state.pt', 'fiora_OS_v0.1.0_params.json'], + }, install_requires=['numpy', 'seaborn', 'torch', 'torch_geometric', 'dill', 'rdkit', 'treelib', 'spectrum_utils', 'setuptools>=24.2.0'], python_requires='>=3.10.8', # Developers may also want to install: jupyter torchmetrics umap umap-learn pytest