From ccccc6936fd95bd7295dc847abf81f16ff2c5b6d Mon Sep 17 00:00:00 2001 From: John Peters Date: Fri, 30 Aug 2024 15:19:21 -0500 Subject: [PATCH] Initial validation wrapper completed --- metl/main.py | 20 +++-- metl/model_encoder.py | 114 +++++++++++++++++++++++++ metl/models.py | 15 ---- notebooks/test_validation.ipynb | 144 +++++++++----------------------- 4 files changed, 167 insertions(+), 126 deletions(-) create mode 100644 metl/model_encoder.py diff --git a/metl/main.py b/metl/main.py index a500c58..4109650 100644 --- a/metl/main.py +++ b/metl/main.py @@ -3,6 +3,7 @@ import metl.models as models from metl.encode import DataEncoder, Encoding +from metl.model_encoder import ModelEncoder UUID_URL_MAP = { # global source models @@ -106,34 +107,37 @@ def _get_data_encoding(hparams): return encoding -def load_model_and_data_encoder(state_dict, hparams): +def load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing): model = models.Model[hparams["model_name"]].cls(**hparams) model.load_state_dict(state_dict) data_encoder = DataEncoder(_get_data_encoding(hparams)) - return model, data_encoder + if raw: + return model, data_encoder + else: + return ModelEncoder(model, data_encoder, strict, indexing) -def get_from_uuid(uuid): +def get_from_uuid(uuid, strict=True, raw=False, indexing=0): if uuid in UUID_URL_MAP: state_dict, hparams = download_checkpoint(uuid) - return load_model_and_data_encoder(state_dict, hparams) + return load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing) else: raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP") -def get_from_ident(ident): +def get_from_ident(ident, strict=True, raw=False, indexing=0): ident = ident.lower() if ident in IDENT_UUID_MAP: state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident]) - return load_model_and_data_encoder(state_dict, hparams) + return load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing) else: raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP") -def get_from_checkpoint(ckpt_fn): +def get_from_checkpoint(ckpt_fn, strict=False, raw=False, indexing=0): ckpt = torch.load(ckpt_fn, map_location="cpu") state_dict = ckpt["state_dict"] hyper_parameters = ckpt["hyper_parameters"] - return load_model_and_data_encoder(state_dict, hyper_parameters) + return load_model_and_data_encoder(state_dict, hyper_parameters, strict, raw, indexing) diff --git a/metl/model_encoder.py b/metl/model_encoder.py new file mode 100644 index 0000000..03b5e8d --- /dev/null +++ b/metl/model_encoder.py @@ -0,0 +1,114 @@ +import torch +from typing import Literal +from biopandas.pdb import PandasPdb +import metl.relative_attention as ra +from Bio.SeqUtils import seq1 +import os + +class ModelEncoder(torch.nn.Module): + def __init__(self, model, encoder, strict=True, indexing:Literal[0,1] = 0) -> None: + super(ModelEncoder, self).__init__() + + if indexing != 0 and indexing != 1: + raise Exception("Indexing must be equal to 0 or to 1.") + + self.model = model + self.encoder = encoder + + self.indexing = indexing + self.strict = strict + + self.needs_pdb = self.check_if_pdb_needed(model) + + def check_if_pdb_needed(self, model): + sequential = next(model.children()) + + for layer in sequential: + if isinstance(layer, ra.RelativeTransformerEncoder): + return True + return False + + def validate_pdb(self, pdb_file, wt): + try: + ppdb = PandasPdb().read_pdb(pdb_file) + except Exception as e: + raise Exception(f"{str(e)} \n\n PDB file could not be read by PandasPDB. It may be incorrectly formatted.") + + groups = ppdb.df['ATOM'].groupby('residue_number') + wt_seq = [] + for group_name, group_data in groups: + wt_seq.append(seq1(group_data.iloc[0]['residue_name'])) + wildtype = ''.join(wt_seq) + + if self.strict: + err_str = "Strict mode is on because a METL model that we trained was used. Wildtype and PDB sequeunces must match." + err_str += " If this is expected behavior, pass strict=False to the load function you used." + assert wildtype == wt, err_str + + def validate_variants(self, variants, wt): + wt_len = len(wt) + for index, variant in enumerate(variants): + split = variant.split(',') + for mutation in split: + from_amino_acid = mutation[0] + to_amino_acid = mutation[-1] + location = int(mutation[1:-1]) + + errors = [] + + if location <= 0 or location >= wt_len-1: + error_str = f"The position for the mutation is {location} but it needs to be between 0 " + error_str += f"and {len(wt)-1} if 0-based and 1 and {len(wt)} if 1-based." + errors.append(error_str) + + if wt[location] != from_amino_acid: + errors.append(f"Wildtype at position {location} is {wt[location]} but variant had {from_amino_acid}. Check the variant input.") + + if len(errors) != 0: + if self.indexing == 1: + mutation = f"{from_amino_acid}{location+1}{to_amino_acid}" + one_based_variants = self.change_indexing_to(1, variants) + + raise Exception(f"Invalid mutation {mutation} that is inside variant {one_based_variants[index]}. Errors: {', '.join(errors)}") + + def change_indexing_to(self, indexing, variants): + changed_based_variants = [] + for variant in variants: + split = variant.split(',') + variant_strings = [] + for mutation in split: + from_amino_acid = mutation[0] + to_amino_acid = mutation[-1] + location = int(mutation[1:-1]) + + if indexing == 0: + location = location-1 + else: + location = location + 1 + + variant_strings.append(f'{from_amino_acid}{location}{to_amino_acid}') + changed_based_variants.append(",".join(variant_strings)) + + return changed_based_variants + + def forward(self, wt:str, variants:list[str], pdb_fn:str=None): + if self.needs_pdb and pdb_fn is None: + raise Exception("PDB path is required but it was not given. Do you have a PDB file?") + + if pdb_fn: + pdb_fn = os.path.abspath(pdb_fn) + self.validate_pdb(pdb_fn, wt) + + if self.indexing == 1: + variants = self.change_indexing_to(0, variants) + + self.validate_variants(variants, wt) + + encoded_variants = self.encoder.encode_variants(wt, variants) + + if pdb_fn: + pred = self.model(torch.Tensor(encoded_variants), pdb_fn=pdb_fn) + else: + pred = self.model(torch.Tensor(encoded_variants)) + + return pred \ No newline at end of file diff --git a/metl/models.py b/metl/models.py index 2848c64..16ac972 100644 --- a/metl/models.py +++ b/metl/models.py @@ -42,27 +42,12 @@ def forward(self, x, **kwargs): for module in self: if isinstance(module, ra.RelativeTransformerEncoder) or isinstance(module, SequentialWithArgs): # for relative transformer encoders, pass in kwargs (pdb_fn) - if 'pdb_fn' not in kwargs: - raise Exception('The model loaded requires the PDB function kwarg.') - - kwargs['pdb_fn'] = os.path.abspath(kwargs['pdb_fn']) - - self.validate_pdb(kwargs['pdb_fn']) - - # ppdb = PandasPdb().read_pdb() x = module(x, **kwargs) else: # for all modules, don't pass in kwargs x = module(x) return x - def validate_pdb(self, pdb_fn): - ppdb = PandasPdb().read_pdb(pdb_fn) - - def validate_protein_seq(self): - pass - - class PositionalEncoding(nn.Module): # originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html # they have since updated their implementation, but it is functionally equivalent diff --git a/notebooks/test_validation.ipynb b/notebooks/test_validation.ipynb index 28f6c4b..04086b0 100644 --- a/notebooks/test_validation.ipynb +++ b/notebooks/test_validation.ipynb @@ -2,154 +2,92 @@ "cells": [ { "cell_type": "code", - "execution_count": 82, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import metl\n", "import torch\n", "from biopandas.pdb import PandasPdb\n", - "from Bio.SeqUtils import seq1\n" + "from Bio.SeqUtils import seq1\n", + "\n", + "import metl.relative_attention as ra" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "C:\\Users\\johng\\AppData\\Local\\Temp\\ipykernel_6520\\3440960159.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", - " model = torch.load(open('./model.pt', 'rb'))\n", - "C:\\Users\\johng\\AppData\\Local\\Temp\\ipykernel_6520\\3440960159.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", - " encoder = torch.load(open('./encoder.pt', 'rb'))\n" + "Initialized PDB bucket matrices in: 0.000\n", + "Initialized PDB bucket matrices in: 0.000\n" ] } ], "source": [ - "# model, encoder = metl.get_from_ident(\"metl-l-2m-3d-gb1\")\n", - "model = torch.load(open('./model.pt', 'rb'))\n", - "encoder = torch.load(open('./encoder.pt', 'rb'))" + "model_needs_pdb = metl.get_from_ident(\"metl-l-2m-3d-gb1\")\n", + "model_no_pdb = metl.get_from_uuid(uuid=\"YoQkzoLD\") # no relative attention" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "wt = \"MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE\"\n", "pdb_fn = '../pdbs/2qmt_p.pdb'\n", - "variants = [\"T17P,T54F\", \"V28L,F51A\"]" + "variants = [\"T17P,T54F\", \"V28L,D46A\"]" ] }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 4, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_keys(['ATOM', 'HETATM', 'ANISOU', 'OTHERS'])\n", - "Index(['record_name', 'atom_number', 'blank_1', 'atom_name', 'alt_loc',\n", - " 'residue_name', 'blank_2', 'chain_id', 'residue_number', 'insertion',\n", - " 'blank_3', 'x_coord', 'y_coord', 'z_coord', 'occupancy', 'b_factor',\n", - " 'blank_4', 'segment_id', 'element_symbol', 'charge', 'line_idx'],\n", - " dtype='object')\n" - ] - }, { "data": { "text/plain": [ - "True" + "(tensor([[ 4.7697e-01, 2.8120e-01, -9.2805e-02, 1.3844e+00, -1.6826e-01,\n", + " 3.0760e-01, 2.9817e-01, -1.3989e+00, -6.9998e-03, 2.9667e+00,\n", + " 5.2389e-01, -1.1593e-01, -6.0801e-01, 1.3960e+00, -7.3188e-01,\n", + " 4.5726e-01, 2.4292e-01, -6.2117e-01, -3.2025e-01, -2.7432e-03,\n", + " 1.0857e+00, 7.2377e-01, -1.8836e-01, 1.2242e+00, -2.0520e-01,\n", + " -1.2772e-01, -5.5548e-01, 1.7935e-01, -4.5334e-02, 5.6413e-01,\n", + " 7.0476e-01, 2.2904e-01, 1.3629e-02, -3.3900e-01, 9.5354e-01,\n", + " 2.7879e-01, 1.5288e+00, 6.2959e-02, -9.8674e-01, 9.6503e-01,\n", + " 1.1230e+00, 7.0476e-01, -1.5055e+00, 3.4333e-01, 1.8991e+00,\n", + " 7.1499e-01, -3.8434e-01, -5.2042e-01, 2.5414e+00, 1.2656e+00,\n", + " -8.8348e-01, 2.6817e+00, -7.5841e-02, 2.7973e+00, -1.7446e-01],\n", + " [-7.1824e-01, -1.0838e-01, -3.0332e-01, 4.3212e-01, -2.2647e-01,\n", + " -6.0814e-01, -3.1514e-01, -9.7168e-01, 2.2807e-01, -4.4759e-01,\n", + " 5.1764e-01, -6.7759e-01, 5.7723e-01, 5.8746e-02, -9.1098e-01,\n", + " -1.7514e-01, -5.6852e-01, 3.0986e-01, -6.1122e-01, -2.3290e-01,\n", + " 3.0098e-01, 6.0075e-01, 6.4402e-01, 3.9483e-01, 3.4120e-01,\n", + " 1.3791e-01, 7.0088e-01, 2.0338e-01, 1.0337e+00, -2.1346e-01,\n", + " -2.6859e-02, 2.6972e-01, 4.1215e-01, 2.8373e-01, -5.9371e-01,\n", + " 4.7806e-01, 1.2857e-01, 3.1594e-02, -2.4400e-01, 1.0700e-01,\n", + " 5.0299e-02, -2.6863e-02, 6.8558e-01, -3.8530e-01, -2.0537e-01,\n", + " -1.3260e-01, -1.0593e+00, -1.3271e-03, -4.2777e-02, 2.2335e-01,\n", + " -3.1454e-01, -7.8193e-02, -2.5404e-02, 3.4191e-01, -1.7368e-01]]),\n", + " tensor([[-3.6763],\n", + " [-3.2601]]))" ] }, - "execution_count": 85, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "ppdb = PandasPdb().read_pdb(pdb_fn)\n", - "print(ppdb.df.keys())\n", - "print(ppdb.df['ATOM'].columns)\n", - "groups = ppdb.df['ATOM'].groupby('residue_number')\n", - "wt_seq = []\n", - "for group_name, group_data in groups:\n", - " wt_seq.append(seq1(group_data.iloc[0]['residue_name']))\n", - "''.join(wt_seq)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "encoded_variants = encoder.encode_variants(wt, variants)" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[11, 14, 20, 9, 10, 8, 10, 12, 6, 9, 17, 10, 9, 6, 4, 17,\n", - " 17, 13, 4, 1, 18, 3, 1, 1, 17, 1, 4, 9, 18, 5, 9, 14,\n", - " 20, 1, 12, 3, 12, 6, 18, 3, 6, 4, 19, 17, 20, 3, 3, 1,\n", - " 17, 9, 17, 5, 17, 18, 5, 4],\n", - " [11, 14, 20, 9, 10, 8, 10, 12, 6, 9, 17, 10, 9, 6, 4, 17,\n", - " 17, 17, 4, 1, 18, 3, 1, 1, 17, 1, 4, 9, 10, 5, 9, 14,\n", - " 20, 1, 12, 3, 12, 6, 18, 3, 6, 4, 19, 17, 20, 3, 3, 1,\n", - " 17, 9, 17, 1, 17, 18, 17, 4]])" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "encoded_variants" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0.8411, -0.2675, 0.4626, 0.8375, 0.1142, 0.3604, 0.7205, -0.7480,\n", - " -0.0744, 2.3942, 0.7609, 0.3872, -0.1821, 1.4607, -0.0923, 1.2541,\n", - " 0.5510, -0.7746, -0.4496, 0.4191, 0.8662, 0.8771, 0.5245, 1.4875,\n", - " 0.3771, 0.5041, -0.3007, 0.1502, -0.4211, 0.3817, 0.5023, 0.2472,\n", - " -0.0427, -0.1571, 0.9614, 0.6088, 1.4072, 0.0447, -0.7701, 1.0510,\n", - " 1.2042, 0.5023, -0.9959, 0.7970, 1.3743, 0.8026, -0.4153, -0.1811,\n", - " 1.9806, 0.7943, -0.8579, 2.2671, -0.0376, 1.6017, -0.1628],\n", - " [ 0.0143, 1.2544, 0.5280, -0.8936, -0.4795, -0.1708, -0.4251, -0.0408,\n", - " -0.1400, -0.3338, -0.3094, -0.9919, 1.1584, 0.1949, -0.6159, -0.1772,\n", - " -0.3620, -0.2166, -0.4934, -1.4132, -0.4597, -1.3537, -0.6285, -0.9735,\n", - " 0.3244, 0.2013, 0.5022, 0.3627, -0.0778, 0.2239, 0.5389, 0.2339,\n", - " -0.6618, 0.0958, -0.3920, -0.5966, -0.7410, 0.0436, -0.2070, -0.9669,\n", - " -1.0585, 0.5389, -1.2097, -0.5970, 0.1008, -0.1897, 0.0036, -0.2175,\n", - " 0.1025, -0.0158, -0.1901, 0.1305, -0.0367, 0.6697, -0.1522]])\n" - ] - } - ], "source": [ "with torch.no_grad():\n", - " logits = model(encoded_variants, pdb_fn = pdb_fn)\n", - "print(logits)" + " pred = model_needs_pdb(wt, variants, pdb_fn)\n", + " pred2 = model_no_pdb(wt, variants)\n", + "pred, pred2" ] } ],