Skip to content

Commit

Permalink
Initial validation wrapper completed
Browse files Browse the repository at this point in the history
  • Loading branch information
John-Peters-UW committed Aug 30, 2024
1 parent b11a075 commit ccccc69
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 126 deletions.
20 changes: 12 additions & 8 deletions metl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import metl.models as models
from metl.encode import DataEncoder, Encoding
from metl.model_encoder import ModelEncoder

UUID_URL_MAP = {
# global source models
Expand Down Expand Up @@ -106,34 +107,37 @@ def _get_data_encoding(hparams):
return encoding


# NOTE: this is a rendered diff with removed and added lines interleaved.
# The two-argument signature and the bare `return model, data_encoder` are the
# pre-commit lines; the five-argument signature and the `if raw:` branch are
# the lines that replace them.
#
# Builds the model class selected by hparams["model_name"], loads `state_dict`
# into it, and builds a DataEncoder from the encoding derived from `hparams`.
# New behaviour: when `raw` is truthy the (model, data_encoder) tuple is
# returned; otherwise both are wrapped in a ModelEncoder together with
# `strict` and `indexing`.
def load_model_and_data_encoder(state_dict, hparams):
def load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing):
model = models.Model[hparams["model_name"]].cls(**hparams)
model.load_state_dict(state_dict)

data_encoder = DataEncoder(_get_data_encoding(hparams))

return model, data_encoder
if raw:
return model, data_encoder
else:
return ModelEncoder(model, data_encoder, strict, indexing)


# NOTE: rendered diff -- the one-argument signature and the two-argument
# load_model_and_data_encoder call are the removed lines; the versions carrying
# strict/raw/indexing are the added lines.
#
# Looks `uuid` up in UUID_URL_MAP, downloads the matching checkpoint and hands
# the state dict and hyperparameters to load_model_and_data_encoder.
# Raises ValueError when the UUID is not a key of UUID_URL_MAP.
def get_from_uuid(uuid):
def get_from_uuid(uuid, strict=True, raw=False, indexing=0):
if uuid in UUID_URL_MAP:
state_dict, hparams = download_checkpoint(uuid)
return load_model_and_data_encoder(state_dict, hparams)
return load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing)
else:
raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP")


# NOTE: rendered diff -- the one-argument signature and the two-argument
# load_model_and_data_encoder call are the removed lines; the versions carrying
# strict/raw/indexing are the added lines.
#
# Lower-cases `ident`, maps it to a UUID through IDENT_UUID_MAP, downloads that
# checkpoint and hands it to load_model_and_data_encoder.
# Raises ValueError when the lower-cased identifier is not in IDENT_UUID_MAP.
def get_from_ident(ident):
def get_from_ident(ident, strict=True, raw=False, indexing=0):
# identifiers are matched case-insensitively
ident = ident.lower()
if ident in IDENT_UUID_MAP:
state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
return load_model_and_data_encoder(state_dict, hparams)
return load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing)
else:
raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP")


# NOTE: rendered diff -- the one-argument signature and the two-argument
# load_model_and_data_encoder call are the removed lines; the versions carrying
# strict/raw/indexing are the added lines.
#
# Loads a checkpoint file onto the CPU and passes its "state_dict" and
# "hyper_parameters" entries to load_model_and_data_encoder.
# Unlike get_from_uuid / get_from_ident, `strict` defaults to False here.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source -- consider weights_only=True if the checkpoint format allows.
def get_from_checkpoint(ckpt_fn):
def get_from_checkpoint(ckpt_fn, strict=False, raw=False, indexing=0):
ckpt = torch.load(ckpt_fn, map_location="cpu")
state_dict = ckpt["state_dict"]
hyper_parameters = ckpt["hyper_parameters"]
return load_model_and_data_encoder(state_dict, hyper_parameters)
return load_model_and_data_encoder(state_dict, hyper_parameters, strict, raw, indexing)
114 changes: 114 additions & 0 deletions metl/model_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
from typing import Literal
from biopandas.pdb import PandasPdb
import metl.relative_attention as ra
from Bio.SeqUtils import seq1
import os

class ModelEncoder(torch.nn.Module):
    """Wrap a model and its data encoder behind a single validated call.

    Calling the wrapper with a wild-type sequence and a list of variant
    strings validates the inputs, encodes the variants with ``encoder`` and
    runs ``model`` on the encoded tensor.

    A variant is a comma-separated list of mutations; a mutation is written
    ``<from><position><to>`` (for example ``"T17P,T54F"``), where ``<from>``
    and ``<to>`` are single characters and ``<position>`` is an integer.

    Args:
        model: the wrapped ``torch.nn.Module``.
        encoder: object exposing ``encode_variants(wt, variants)``.
        strict: when true, a supplied PDB file must spell exactly the
            wild-type sequence passed to ``forward``.
        indexing: ``0`` if mutation positions given by the caller are
            0-based, ``1`` if they are 1-based.

    Raises:
        ValueError: if ``indexing`` is neither 0 nor 1.
    """

    def __init__(self, model, encoder, strict=True, indexing: Literal[0, 1] = 0) -> None:
        super().__init__()

        if indexing not in (0, 1):
            raise ValueError("Indexing must be equal to 0 or to 1.")

        self.model = model
        self.encoder = encoder

        self.indexing = indexing
        self.strict = strict

        self.needs_pdb = self.check_if_pdb_needed(model)

    def check_if_pdb_needed(self, model):
        """Return True if ``model`` contains a relative-attention encoder.

        Every sub-module is inspected (``model.modules()`` walks the whole
        tree), so encoders nested below the first child are found as well and
        a model without children simply yields False.
        """
        return any(
            isinstance(module, ra.RelativeTransformerEncoder)
            for module in model.modules()
        )

    def validate_pdb(self, pdb_file, wt):
        """Check that ``pdb_file`` is readable and, in strict mode, matches ``wt``.

        Raises:
            Exception: if PandasPdb cannot read the file.
            AssertionError: in strict mode, if the sequence read from the PDB
                differs from ``wt``.
        """
        try:
            ppdb = PandasPdb().read_pdb(pdb_file)
        except Exception as e:
            raise Exception(
                f"{str(e)} \n\n PDB file could not be read by PandasPDB. It may be incorrectly formatted."
            ) from e

        if not self.strict:
            # Only readability is enforced outside strict mode.
            return

        # One residue per distinct residue_number, taken from its first ATOM row.
        # NOTE(review): residues are keyed by residue_number only, so files with
        # several chains or insertion codes would merge residues -- confirm the
        # PDB files used here are single-chain.
        groups = ppdb.df['ATOM'].groupby('residue_number')
        wildtype = ''.join(
            seq1(group_data.iloc[0]['residue_name']) for _, group_data in groups
        )

        if wildtype != wt:
            err_str = "Strict mode is on because a METL model that we trained was used. Wildtype and PDB sequeunces must match."
            err_str += " If this is expected behavior, pass strict=False to the load function you used."
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(err_str)

    @staticmethod
    def _parse_mutation(mutation):
        """Split ``"<from><position><to>"`` into ``(from, int(position), to)``.

        Raises:
            ValueError: if the string is shorter than three characters or its
                middle part is not an integer.
        """
        if len(mutation) >= 3:
            try:
                return mutation[0], int(mutation[1:-1]), mutation[-1]
            except ValueError:
                pass
        raise ValueError(
            f"Mutation '{mutation}' is malformed. Expected the form "
            "<from amino acid><position><to amino acid>, e.g. 'T17P'."
        )

    def validate_variants(self, variants, wt):
        """Validate 0-based ``variants`` against the wild-type sequence ``wt``.

        Every mutation must point inside ``wt`` (``0 <= position < len(wt)``)
        and name the residue that ``wt`` actually has at that position.
        Positions and variants quoted in error messages are converted back to
        the caller's indexing (``self.indexing``).

        Raises:
            ValueError: on the first malformed, out-of-range or mismatching
                mutation.
        """
        wt_len = len(wt)
        offset = self.indexing  # 0 or 1: shift used when reporting positions

        for variant in variants:
            for mutation in variant.split(','):
                from_amino_acid, location, to_amino_acid = self._parse_mutation(mutation)
                shown_location = location + offset

                errors = []
                if not 0 <= location < wt_len:
                    errors.append(
                        f"The position for the mutation is {shown_location} but it needs to be between "
                        f"{offset} and {wt_len - 1 + offset} ({offset}-based indexing)."
                    )
                elif wt[location] != from_amino_acid:
                    # Only index into wt once the position is known to be valid.
                    errors.append(
                        f"Wildtype at position {shown_location} is {wt[location]} but variant had "
                        f"{from_amino_acid}. Check the variant input."
                    )

                if errors:
                    shown_mutation = f"{from_amino_acid}{shown_location}{to_amino_acid}"
                    if self.indexing == 1:
                        shown_variant = self.change_indexing_to(1, [variant])[0]
                    else:
                        shown_variant = variant
                    raise ValueError(
                        f"Invalid mutation {shown_mutation} that is inside variant {shown_variant}. "
                        f"Errors: {', '.join(errors)}"
                    )

    def change_indexing_to(self, indexing, variants):
        """Return ``variants`` with every position shifted to ``indexing``.

        ``indexing == 0`` subtracts one from each position (1-based to
        0-based); any other value adds one (0-based to 1-based).
        """
        shift = -1 if indexing == 0 else 1

        changed_based_variants = []
        for variant in variants:
            variant_strings = []
            for mutation in variant.split(','):
                from_amino_acid, location, to_amino_acid = self._parse_mutation(mutation)
                variant_strings.append(f'{from_amino_acid}{location + shift}{to_amino_acid}')
            changed_based_variants.append(",".join(variant_strings))

        return changed_based_variants

    def forward(self, wt: str, variants: list[str], pdb_fn: str = None):
        """Validate the inputs, encode ``variants`` and run the wrapped model.

        Args:
            wt: wild-type sequence.
            variants: variant strings in the indexing chosen at construction.
            pdb_fn: path to a PDB file; required when the wrapped model
                contains a relative-attention encoder.

        Returns:
            Whatever the wrapped model returns for the encoded variants.

        Raises:
            Exception: if a PDB file is required but missing, or unreadable.
            AssertionError: in strict mode, if the PDB sequence differs from ``wt``.
            ValueError: if a variant is malformed or inconsistent with ``wt``.
        """
        # ``not pdb_fn`` also rejects an empty path, which would otherwise skip
        # validation and reach the model without its required argument.
        if self.needs_pdb and not pdb_fn:
            raise Exception("PDB path is required but it was not given. Do you have a PDB file?")

        if pdb_fn:
            pdb_fn = os.path.abspath(pdb_fn)
            self.validate_pdb(pdb_fn, wt)

        # Internally everything is 0-based.
        if self.indexing == 1:
            variants = self.change_indexing_to(0, variants)

        self.validate_variants(variants, wt)

        encoded_variants = self.encoder.encode_variants(wt, variants)

        if pdb_fn:
            return self.model(torch.Tensor(encoded_variants), pdb_fn=pdb_fn)
        return self.model(torch.Tensor(encoded_variants))
15 changes: 0 additions & 15 deletions metl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,12 @@ def forward(self, x, **kwargs):
for module in self:
if isinstance(module, ra.RelativeTransformerEncoder) or isinstance(module, SequentialWithArgs):
# for relative transformer encoders, pass in kwargs (pdb_fn)
if 'pdb_fn' not in kwargs:
raise Exception('The model loaded requires the PDB function kwarg.')

kwargs['pdb_fn'] = os.path.abspath(kwargs['pdb_fn'])

self.validate_pdb(kwargs['pdb_fn'])

# ppdb = PandasPdb().read_pdb()
x = module(x, **kwargs)
else:
# for all modules, don't pass in kwargs
x = module(x)
return x

def validate_pdb(self, pdb_fn):
ppdb = PandasPdb().read_pdb(pdb_fn)

def validate_protein_seq(self):
pass


class PositionalEncoding(nn.Module):
# originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# they have since updated their implementation, but it is functionally equivalent
Expand Down
144 changes: 41 additions & 103 deletions notebooks/test_validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,154 +2,92 @@
"cells": [
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import metl\n",
"import torch\n",
"from biopandas.pdb import PandasPdb\n",
"from Bio.SeqUtils import seq1\n"
"from Bio.SeqUtils import seq1\n",
"\n",
"import metl.relative_attention as ra"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"C:\\Users\\johng\\AppData\\Local\\Temp\\ipykernel_6520\\3440960159.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" model = torch.load(open('./model.pt', 'rb'))\n",
"C:\\Users\\johng\\AppData\\Local\\Temp\\ipykernel_6520\\3440960159.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" encoder = torch.load(open('./encoder.pt', 'rb'))\n"
"Initialized PDB bucket matrices in: 0.000\n",
"Initialized PDB bucket matrices in: 0.000\n"
]
}
],
"source": [
"# model, encoder = metl.get_from_ident(\"metl-l-2m-3d-gb1\")\n",
"model = torch.load(open('./model.pt', 'rb'))\n",
"encoder = torch.load(open('./encoder.pt', 'rb'))"
"model_needs_pdb = metl.get_from_ident(\"metl-l-2m-3d-gb1\")\n",
"model_no_pdb = metl.get_from_uuid(uuid=\"YoQkzoLD\") # no relative attention"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"wt = \"MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE\"\n",
"pdb_fn = '../pdbs/2qmt_p.pdb'\n",
"variants = [\"T17P,T54F\", \"V28L,F51A\"]"
"variants = [\"T17P,T54F\", \"V28L,D46A\"]"
]
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['ATOM', 'HETATM', 'ANISOU', 'OTHERS'])\n",
"Index(['record_name', 'atom_number', 'blank_1', 'atom_name', 'alt_loc',\n",
" 'residue_name', 'blank_2', 'chain_id', 'residue_number', 'insertion',\n",
" 'blank_3', 'x_coord', 'y_coord', 'z_coord', 'occupancy', 'b_factor',\n",
" 'blank_4', 'segment_id', 'element_symbol', 'charge', 'line_idx'],\n",
" dtype='object')\n"
]
},
{
"data": {
"text/plain": [
"True"
"(tensor([[ 4.7697e-01, 2.8120e-01, -9.2805e-02, 1.3844e+00, -1.6826e-01,\n",
" 3.0760e-01, 2.9817e-01, -1.3989e+00, -6.9998e-03, 2.9667e+00,\n",
" 5.2389e-01, -1.1593e-01, -6.0801e-01, 1.3960e+00, -7.3188e-01,\n",
" 4.5726e-01, 2.4292e-01, -6.2117e-01, -3.2025e-01, -2.7432e-03,\n",
" 1.0857e+00, 7.2377e-01, -1.8836e-01, 1.2242e+00, -2.0520e-01,\n",
" -1.2772e-01, -5.5548e-01, 1.7935e-01, -4.5334e-02, 5.6413e-01,\n",
" 7.0476e-01, 2.2904e-01, 1.3629e-02, -3.3900e-01, 9.5354e-01,\n",
" 2.7879e-01, 1.5288e+00, 6.2959e-02, -9.8674e-01, 9.6503e-01,\n",
" 1.1230e+00, 7.0476e-01, -1.5055e+00, 3.4333e-01, 1.8991e+00,\n",
" 7.1499e-01, -3.8434e-01, -5.2042e-01, 2.5414e+00, 1.2656e+00,\n",
" -8.8348e-01, 2.6817e+00, -7.5841e-02, 2.7973e+00, -1.7446e-01],\n",
" [-7.1824e-01, -1.0838e-01, -3.0332e-01, 4.3212e-01, -2.2647e-01,\n",
" -6.0814e-01, -3.1514e-01, -9.7168e-01, 2.2807e-01, -4.4759e-01,\n",
" 5.1764e-01, -6.7759e-01, 5.7723e-01, 5.8746e-02, -9.1098e-01,\n",
" -1.7514e-01, -5.6852e-01, 3.0986e-01, -6.1122e-01, -2.3290e-01,\n",
" 3.0098e-01, 6.0075e-01, 6.4402e-01, 3.9483e-01, 3.4120e-01,\n",
" 1.3791e-01, 7.0088e-01, 2.0338e-01, 1.0337e+00, -2.1346e-01,\n",
" -2.6859e-02, 2.6972e-01, 4.1215e-01, 2.8373e-01, -5.9371e-01,\n",
" 4.7806e-01, 1.2857e-01, 3.1594e-02, -2.4400e-01, 1.0700e-01,\n",
" 5.0299e-02, -2.6863e-02, 6.8558e-01, -3.8530e-01, -2.0537e-01,\n",
" -1.3260e-01, -1.0593e+00, -1.3271e-03, -4.2777e-02, 2.2335e-01,\n",
" -3.1454e-01, -7.8193e-02, -2.5404e-02, 3.4191e-01, -1.7368e-01]]),\n",
" tensor([[-3.6763],\n",
" [-3.2601]]))"
]
},
"execution_count": 85,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ppdb = PandasPdb().read_pdb(pdb_fn)\n",
"print(ppdb.df.keys())\n",
"print(ppdb.df['ATOM'].columns)\n",
"groups = ppdb.df['ATOM'].groupby('residue_number')\n",
"wt_seq = []\n",
"for group_name, group_data in groups:\n",
" wt_seq.append(seq1(group_data.iloc[0]['residue_name']))\n",
"''.join(wt_seq)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"encoded_variants = encoder.encode_variants(wt, variants)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[11, 14, 20, 9, 10, 8, 10, 12, 6, 9, 17, 10, 9, 6, 4, 17,\n",
" 17, 13, 4, 1, 18, 3, 1, 1, 17, 1, 4, 9, 18, 5, 9, 14,\n",
" 20, 1, 12, 3, 12, 6, 18, 3, 6, 4, 19, 17, 20, 3, 3, 1,\n",
" 17, 9, 17, 5, 17, 18, 5, 4],\n",
" [11, 14, 20, 9, 10, 8, 10, 12, 6, 9, 17, 10, 9, 6, 4, 17,\n",
" 17, 17, 4, 1, 18, 3, 1, 1, 17, 1, 4, 9, 10, 5, 9, 14,\n",
" 20, 1, 12, 3, 12, 6, 18, 3, 6, 4, 19, 17, 20, 3, 3, 1,\n",
" 17, 9, 17, 1, 17, 18, 17, 4]])"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoded_variants"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.8411, -0.2675, 0.4626, 0.8375, 0.1142, 0.3604, 0.7205, -0.7480,\n",
" -0.0744, 2.3942, 0.7609, 0.3872, -0.1821, 1.4607, -0.0923, 1.2541,\n",
" 0.5510, -0.7746, -0.4496, 0.4191, 0.8662, 0.8771, 0.5245, 1.4875,\n",
" 0.3771, 0.5041, -0.3007, 0.1502, -0.4211, 0.3817, 0.5023, 0.2472,\n",
" -0.0427, -0.1571, 0.9614, 0.6088, 1.4072, 0.0447, -0.7701, 1.0510,\n",
" 1.2042, 0.5023, -0.9959, 0.7970, 1.3743, 0.8026, -0.4153, -0.1811,\n",
" 1.9806, 0.7943, -0.8579, 2.2671, -0.0376, 1.6017, -0.1628],\n",
" [ 0.0143, 1.2544, 0.5280, -0.8936, -0.4795, -0.1708, -0.4251, -0.0408,\n",
" -0.1400, -0.3338, -0.3094, -0.9919, 1.1584, 0.1949, -0.6159, -0.1772,\n",
" -0.3620, -0.2166, -0.4934, -1.4132, -0.4597, -1.3537, -0.6285, -0.9735,\n",
" 0.3244, 0.2013, 0.5022, 0.3627, -0.0778, 0.2239, 0.5389, 0.2339,\n",
" -0.6618, 0.0958, -0.3920, -0.5966, -0.7410, 0.0436, -0.2070, -0.9669,\n",
" -1.0585, 0.5389, -1.2097, -0.5970, 0.1008, -0.1897, 0.0036, -0.2175,\n",
" 0.1025, -0.0158, -0.1901, 0.1305, -0.0367, 0.6697, -0.1522]])\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" logits = model(encoded_variants, pdb_fn = pdb_fn)\n",
"print(logits)"
" pred = model_needs_pdb(wt, variants, pdb_fn)\n",
" pred2 = model_no_pdb(wt, variants)\n",
"pred, pred2"
]
}
],
Expand Down

0 comments on commit ccccc69

Please sign in to comment.