diff --git a/fiora/GNN/GNNModules.py b/fiora/GNN/GNNModules.py index de35b02..2f22181 100644 --- a/fiora/GNN/GNNModules.py +++ b/fiora/GNN/GNNModules.py @@ -265,7 +265,7 @@ def load_from_state_dict(cls, PATH: str) -> 'GNNCompiler': with open(PARAMS_PATH, 'r') as fp: params = json.load(fp) model = GNNCompiler(params) - model.load_state_dict(torch.load(STATE_PATH)) + model.load_state_dict(torch.load(STATE_PATH, map_location=torch.serialization.default_restore_location, weights_only=True)) if not isinstance(model, cls): raise ValueError(f'file {PATH} contains incorrect model class {type(model)}') @@ -291,3 +291,22 @@ def save(self, PATH: str, dev: str="cpu") -> None: #Reset to previous device self.to(prev_device) + + # def save(self, PATH: str, dev: str="cpu") -> None: + + # # Set device to cpu for saving + # prev_device = next(self.parameters()).device + # self.to(dev) + # with open(PATH, 'wb') as f: + # dill.dump(self.to(dev), f) + + # params_path = '.'.join(PATH.split('.')[:-1]) + '_params.json' + # with open(params_path, 'w') as fp: + # json.dump(self.model_params, fp) + + # state_dict_path = params_path.replace("_params.json", "_state.pt") + # self.to(dev) + + # torch.save(self.state_dict(), state_dict_path, _use_new_zipfile_serialization=False) + + # self.to(prev_device) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/notebooks/test_model.ipynb b/notebooks/test_model.ipynb index ced1d25..43d0046 100644 --- a/notebooks/test_model.ipynb +++ b/notebooks/test_model.ipynb @@ -9,9 +9,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "[10:21:18] WARNING: not removing hydrogen atom without neighbors\n", - "[10:21:18] WARNING: not removing hydrogen atom without neighbors\n", - "[10:21:18] WARNING: not removing hydrogen atom without neighbors\n" + "[11:46:34] WARNING: not removing hydrogen atom without neighbors\n", + "[11:46:34] WARNING: not removing hydrogen atom without neighbors\n", + "[11:46:34] WARNING: not removing hydrogen atom without neighbors\n" ] }, { @@ -68,7 +68,16 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ynowatzk/miniforge3/envs/fiora/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n" + ] + } + ], "source": [ "depth = 6\n", "\n", @@ -85,16 +94,12 @@ "from fiora.GNN.GNNModules import GNNCompiler\n", "from fiora.MS.SimulationFramework import SimulationFramework\n", "\n", - "try: \n", - " model = GNNCompiler.load(MODEL_PATH)\n", + "\n", + "try:\n", + " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", "except:\n", - " try:\n", - " print(f\"Warning: Failed loading the model {MODEL_PATH}. Fall back: Loading the model from state dictionary.\")\n", - " model = GNNCompiler.load_from_state_dict(MODEL_PATH)\n", - " print(\"Model loaded from state dict without further errors.\")\n", - " except:\n", - " raise NameError(\"Error: Failed loading from state dict.\")\n", - " " + " raise NameError(\"Error: Failed loading from state dict.\")\n", + " " ] }, { @@ -289,13 +294,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_937446/3302879299.py:24: SettingWithCopyWarning: \n", + "/tmp/ipykernel_3323114/3302879299.py:24: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " df_train[\"Metabolite\"] = df_train[\"SMILES\"].apply(Metabolite) # TRAIN Metabolites are only tracked for tanimoto distance\n", - "/tmp/ipykernel_937446/3302879299.py:25: SettingWithCopyWarning: \n", + "/tmp/ipykernel_3323114/3302879299.py:25: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", @@ -655,7 +660,13 @@ "output_type": "stream", "text": [ "/home/ynowatzk/repos/fiora/fiora/MS/SimulationFramework.py:170: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n", - " df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in [\"cosine_similarity\", \"kl_div\", \"sim_peaks\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_sqrt_bias_wo_prec\", \"spectral_refl_bias\", \"steins_cosine\", \"steins_bias\", \"RT_pred\", \"RT_dif\", \"CCS_pred\"]])])\n", + " df = pd.concat([df, pd.DataFrame(columns=[x + suffix for x in [\"cosine_similarity\", \"kl_div\", \"sim_peaks\", \"spectral_cosine\", \"spectral_sqrt_cosine\", \"spectral_sqrt_cosine_wo_prec\", \"spectral_refl_cosine\", \"spectral_bias\", \"spectral_sqrt_bias\", \"spectral_sqrt_bias_wo_prec\", \"spectral_refl_bias\", \"steins_cosine\", \"steins_bias\", \"RT_pred\", \"RT_dif\", \"CCS_pred\"]])])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ "/home/ynowatzk/repos/fiora/fiora/MS/spectral_scores.py:26: RuntimeWarning: divide by zero encountered in divide\n", " vec = vec / np.linalg.norm(vec)\n", "/home/ynowatzk/repos/fiora/fiora/MS/spectral_scores.py:27: RuntimeWarning: divide by zero encountered in divide\n", @@ -1650,7 +1661,7 @@ "