Skip to content

Commit

Permalink
Merge pull request #7 from BAMeScience/dev
Browse files Browse the repository at this point in the history
Small hotfixes
  • Loading branch information
ch4perone authored Aug 29, 2024
2 parents 2cbaf18 + ecb156d commit ab2b7e9
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 51 deletions.
21 changes: 20 additions & 1 deletion fiora/GNN/GNNModules.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def load_from_state_dict(cls, PATH: str) -> 'GNNCompiler':
with open(PARAMS_PATH, 'r') as fp:
params = json.load(fp)
model = GNNCompiler(params)
model.load_state_dict(torch.load(STATE_PATH))
model.load_state_dict(torch.load(STATE_PATH, map_location=torch.serialization.default_restore_location, weights_only=True))

if not isinstance(model, cls):
raise ValueError(f'file {PATH} contains incorrect model class {type(model)}')
Expand All @@ -291,3 +291,22 @@ def save(self, PATH: str, dev: str="cpu") -> None:

#Reset to previous device
self.to(prev_device)

# def save(self, PATH: str, dev: str="cpu") -> None:

# # Set device to cpu for saving
# prev_device = next(self.parameters()).device
# self.to(dev)
# with open(PATH, 'wb') as f:
# dill.dump(self.to(dev), f)

# params_path = '.'.join(PATH.split('.')[:-1]) + '_params.json'
# with open(params_path, 'w') as fp:
# json.dump(self.model_params, fp)

# state_dict_path = params_path.replace("_params.json", "_state.pt")
# self.to(dev)

# torch.save(self.state_dict(), state_dict_path, _use_new_zipfile_serialization=False)

# self.to(prev_device)
Empty file added models/__init__.py
Empty file.
69 changes: 40 additions & 29 deletions notebooks/test_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"[10:21:18] WARNING: not removing hydrogen atom without neighbors\n",
"[10:21:18] WARNING: not removing hydrogen atom without neighbors\n",
"[10:21:18] WARNING: not removing hydrogen atom without neighbors\n"
"[11:46:34] WARNING: not removing hydrogen atom without neighbors\n",
"[11:46:34] WARNING: not removing hydrogen atom without neighbors\n",
"[11:46:34] WARNING: not removing hydrogen atom without neighbors\n"
]
},
{
Expand Down Expand Up @@ -68,7 +68,16 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ynowatzk/miniforge3/envs/fiora/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
}
],
"source": [
"depth = 6\n",
"\n",
Expand All @@ -85,16 +94,12 @@
"from fiora.GNN.GNNModules import GNNCompiler\n",
"from fiora.MS.SimulationFramework import SimulationFramework\n",
"\n",
"try: \n",
" model = GNNCompiler.load(MODEL_PATH)\n",
"\n",
"try:\n",
" model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n",
"except:\n",
" try:\n",
" print(f\"Warning: Failed loading the model {MODEL_PATH}. Fall back: Loading the model from state dictionary.\")\n",
" model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n",
" print(\"Model loaded from state dict without further errors.\")\n",
" except:\n",
" raise NameError(\"Error: Failed loading from state dict.\")\n",
" "
" raise NameError(\"Error: Failed loading from state dict.\")\n",
" "
]
},
{
Expand Down Expand Up @@ -289,13 +294,13 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_937446/3302879299.py:24: SettingWithCopyWarning: \n",
"/tmp/ipykernel_3323114/3302879299.py:24: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" df_train[\"Metabolite\"] = df_train[\"SMILES\"].apply(Metabolite) # TRAIN Metabolites are only tracked for tanimoto distance\n",
"/tmp/ipykernel_937446/3302879299.py:25: SettingWithCopyWarning: \n",
"/tmp/ipykernel_3323114/3302879299.py:25: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
Expand Down Expand Up @@ -655,7 +660,13 @@
"output_type": "stream",
"text": [
"/home/ynowatzk/repos/fiora/fiora/MS/SimulationFramework.py:170: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
" df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in [\"cosine_similarity\", \"kl_div\", \"sim_peaks\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_sqrt_bias_wo_prec\", \"spectral_refl_bias\", \"steins_cosine\", \"steins_bias\", \"RT_pred\", \"RT_dif\", \"CCS_pred\"]])])\n",
" df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in [\"cosine_similarity\", \"kl_div\", \"sim_peaks\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_sqrt_bias_wo_prec\", \"spectral_refl_bias\", \"steins_cosine\", \"steins_bias\", \"RT_pred\", \"RT_dif\", \"CCS_pred\"]])])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ynowatzk/repos/fiora/fiora/MS/spectral_scores.py:26: RuntimeWarning: divide by zero encountered in divide\n",
" vec = vec / np.linalg.norm(vec)\n",
"/home/ynowatzk/repos/fiora/fiora/MS/spectral_scores.py:27: RuntimeWarning: divide by zero encountered in divide\n",
Expand Down Expand Up @@ -1650,7 +1661,7 @@
" <td>0.717736</td>\n",
" <td>0.780847</td>\n",
" <td>0.273686</td>\n",
" <td>0.307230</td>\n",
" <td>0.307231</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
Expand Down Expand Up @@ -1687,7 +1698,7 @@
"2 ICEBERG 0.716473 NaN 0.575536 NaN 0.709359 NaN \n",
"\n",
" CASMI22+ CASMI22- \n",
"0 0.273686 0.307230 \n",
"0 0.273686 0.307231 \n",
"1 0.376012 0.292884 \n",
"2 0.358571 0.000000 "
]
Expand Down Expand Up @@ -2011,17 +2022,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"cos: 0.7399321531311543 0.5158647138450783\n",
"cos: 0.7399321531311543 0.3008378811624405\n",
"cos: 0.7399321531311543 0.3014740423593587\n",
"cos: 0.5824634669347201 0.5158647138450783\n",
"cos: 0.5824634669347201 0.3008378811624405\n",
"cos: 0.5824634669347201 0.3014740423593587\n",
"cos: 0.5824634893228796 0.5158647138450783\n",
"cos: 0.5824634893228796 0.3008378811624405\n",
"cos: 0.5824634893228796 0.3014740423593587\n",
"cos: 0.5780898248582669 0.4830928581051953\n",
"cos: 0.49608554398224664 0.4830928581051953\n"
"cos: 0.7399321707609304 0.5158647125171295\n",
"cos: 0.7399321707609304 0.3008378766012895\n",
"cos: 0.7399321707609304 0.30147404234591146\n",
"cos: 0.5824634973655437 0.5158647125171295\n",
"cos: 0.5824634973655437 0.3008378766012895\n",
"cos: 0.5824634973655437 0.30147404234591146\n",
"cos: 0.5824634973655437 0.5158647125171295\n",
"cos: 0.5824634973655437 0.3008378766012895\n",
"cos: 0.5824634973655437 0.30147404234591146\n",
"cos: 0.5780898227574808 0.48309285488696774\n",
"cos: 0.4960855687939471 0.48309285488696774\n"
]
}
],
Expand Down
28 changes: 14 additions & 14 deletions scripts/fiora-predict
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#! /usr/bin/env python
import pandas as pd
import os
import warnings
from rdkit import RDLogger
RDLogger.DisableLog("rdApp.*")
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")

import argparse
import importlib.resources
import fiora.IO.mgfWriter as mgfWriter
import fiora.IO.mspWriter as mspWriter

Expand All @@ -15,7 +18,6 @@ from fiora.GNN.AtomFeatureEncoder import AtomFeatureEncoder
from fiora.GNN.BondFeatureEncoder import BondFeatureEncoder
from fiora.GNN.SetupFeatureEncoder import SetupFeatureEncoder


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(prog='fiora-predict',
description='Fiora is an in silico fragmentation framework, which predicts peaks and simulates tandem mass spectra including features such as retention time and collision cross sections. Use this script for spectrum predictions with a (pre-)trained model.',
Expand All @@ -28,7 +30,7 @@ def parse_args() -> argparse.Namespace:


parser.add_argument('--rt', action=argparse.BooleanOptionalAction, help="predict retention time", default=False)
parser.add_argument('--ccs', action=argparse.BooleanOptionalAction, help="predict collison cross section", default=True)
parser.add_argument('--ccs', action=argparse.BooleanOptionalAction, help="predict collison cross section", default=False)
parser.add_argument("--annotation", action=argparse.BooleanOptionalAction, help="annotate predicted peaks with SMILES strings", default=False)
args = parser.parse_args()

Expand Down Expand Up @@ -102,19 +104,17 @@ def main():
print(f'Running fiora prediction with the following parameters: {args}\n')

# Load model
if args.model == "default":
args.model = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../models/fiora_OS_v0.1.0.pt")
if args.model == "default":
with importlib.resources.path('models', 'fiora_OS_v0.1.0.pt') as model_path:
args.model = str(model_path)
#os.path.join(os.path.dirname(os.path.abspath(__file__)), "../models/fiora_OS_v0.1.0.pt")

try:
model = GNNCompiler.load(args.model)
except:
try:
print(f"Warning: Failed loading the model {args.model}. Fallback: Loading the model from state dictionary.")
model = GNNCompiler.load_from_state_dict(args.model)
print("Model loaded from state dict without further errors.")
except Exception as e:
print(f"Error: Failed loading from state dict. Caused by: {e}.")
exit(1)
print(args.model)
try:
model = GNNCompiler.load_from_state_dict(args.model)
except Exception as e:
print(f"Error: Failed loading from model from state dict. Caused by: {e}.")
exit(1)

print_model_messages(model.model_params)
args = update_args_with_model_params(args, model.model_params)
Expand Down
13 changes: 6 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

setup(name='Fiora',
setup(name='fiora',
version='0.0.1',
long_description='file: README.md',
author='Yannek Nowatzky',
Expand All @@ -10,14 +10,13 @@
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
],
# entry_points={
# "console_scripts": [
# "fiora-predict = scripts.predict:main",
# ],
# },

scripts=["scripts/fiora-predict"],
packages=['fiora', 'fiora.GNN', 'fiora.IO', 'fiora.MOL', 'fiora.MS', 'fiora.visualization', 'models'],
include_package_data=True,
packages=['fiora', 'fiora.GNN', 'fiora.IO', 'fiora.MOL', 'fiora.MS', 'fiora.visualization'],
package_data={
'models': ['fiora_OS_v0.1.0.pt', 'fiora_OS_v0.1.0_state.pt', 'fiora_OS_v0.1.0_params.json'],
},
install_requires=['numpy', 'seaborn', 'torch', 'torch_geometric', 'dill', 'rdkit', 'treelib', 'spectrum_utils', 'setuptools>=24.2.0'],
python_requires='>=3.10.8',
# Developers may also want to install: jupyter torchmetrics umap umap-learn pytest
Expand Down

0 comments on commit ab2b7e9

Please sign in to comment.