diff --git a/gui/qt_window.py b/gui/qt_window.py index ab1a34b..eae63b1 100644 --- a/gui/qt_window.py +++ b/gui/qt_window.py @@ -124,12 +124,21 @@ def __init__( self.button_dca_inference_gremlin = QtWidgets.QPushButton("MSA optimization (GREMLIN)") self.button_dca_inference_gremlin.setMinimumWidth(80) self.button_dca_inference_gremlin.setToolTip( - "Generating DCA parameters using GREMLIN (\"MSA optimization\"), " - "you have to provide an MSA in FASTA or A2M format" + "Generating DCA parameters using GREMLIN (\"MSA optimization\"); " + "requires an MSA in FASTA or A2M format" ) self.button_dca_inference_gremlin.clicked.connect(self.pypef_gremlin) self.button_dca_inference_gremlin.setStyleSheet(button_style) + self.button_dca_inference_gremlin_msa_info = QtWidgets.QPushButton("GREMLIN SSM prediction") + self.button_dca_inference_gremlin_msa_info.setMinimumWidth(80) + self.button_dca_inference_gremlin_msa_info.setToolTip( + "Generating DCA parameters using GREMLIN (\"MSA optimization\") and save plots of " + "visualized results; requires an MSA in FASTA or A2M format" + ) + self.button_dca_inference_gremlin_msa_info.clicked.connect(self.pypef_gremlin_msa_info) + self.button_dca_inference_gremlin_msa_info.setStyleSheet(button_style) + self.button_dca_test_dca = QtWidgets.QPushButton("Test (DCA)") self.button_dca_test_dca.setMinimumWidth(80) self.button_dca_test_dca.setToolTip( @@ -208,8 +217,9 @@ def __init__( layout.addWidget(self.dca_text, 3, 1, 1, 1) layout.addWidget(self.button_dca_inference_gremlin, 4, 1, 1, 1) - layout.addWidget(self.button_dca_test_dca, 5, 1, 1, 1) - layout.addWidget(self.button_dca_predict_dca, 6, 1, 1, 1) + layout.addWidget(self.button_dca_inference_gremlin_msa_info, 5, 1, 1, 1) + layout.addWidget(self.button_dca_test_dca, 6, 1, 1, 1) + layout.addWidget(self.button_dca_predict_dca, 7, 1, 1, 1) layout.addWidget(self.hybrid_text, 3, 2, 1, 1) layout.addWidget(self.button_hybrid_train_dca, 4, 2, 1, 1) @@ -222,7 +232,7 @@ def __init__( 
@QtCore.Slot()
def pypef_gremlin_msa_info(self):
    """
    Ask the user for a WT FASTA file and an MSA file (FASTA or A2M format)
    and pass the 'save_msa_info' command with both paths to self.exec_pypef.

    Nothing is executed if either file dialog yields an empty path.
    """
    wt_fasta_file = QtWidgets.QFileDialog.getOpenFileName(self, "Select WT FASTA File")[0]
    if not wt_fasta_file:
        # No WT file chosen: skip the second dialog, its result would be
        # discarded anyway because both paths are required
        return
    msa_file = QtWidgets.QFileDialog.getOpenFileName(
        self, "Select Multiple Sequence Alignment (MSA) file (in FASTA or A2M format)")[0]
    if not msa_file:
        return
    self.version_text.setText("Running GREMLIN (DCA) optimization on MSA...")
    self.exec_pypef(f'save_msa_info --wt {wt_fasta_file} --msa {msa_file}')
def plot_predicted_ssm(gremlin: GREMLIN):
    """
    Function to plot and save the predicted effects of all single amino
    acid substitutions at all WT/input sequence positions; e.g.:
    M1A, M1C, M1E, ..., D2A, D2C, D2E, ..., ..., T300V, T300W, T300Y

    The effect of a variant is its GREMLIN score minus the WT score.
    Every character of gremlin.char_alphabet except the gap ("-") is
    tried at every position, so the WT residue itself (e.g. M1M) is
    part of the landscape as well.

    The heat map is saved as 'SSM_landscape.png' and the underlying data
    (variant name, variant sequence, score difference) is saved as
    'SSM_landscape.csv'; both paths are relative, i.e., the files end up
    in the current working directory.
    """
    wt_sequence = gremlin.wt_seq
    wt_score = gremlin.get_wt_score()[0]
    aas = "".join(sorted(gremlin.char_alphabet.replace("-", "")))
    # Nested lists: outer index = sequence position, inner index = amino acid
    variantss, variant_sequencess, variant_scoress = [], [], []
    for i, aa_wt in enumerate(tqdm(wt_sequence)):
        variants, variant_sequences, variant_scores = [], [], []
        for aa_sub in aas:
            variant_sequence = wt_sequence[:i] + aa_sub + wt_sequence[i + 1:]
            variants.append(aa_wt + str(i + 1) + aa_sub)
            variant_sequences.append(variant_sequence)
            variant_scores.append(gremlin.get_score(variant_sequence)[0] - wt_score)
        variantss.append(variants)
        variant_sequencess.append(variant_sequences)
        variant_scoress.append(variant_scores)
    # Diagnostic only: goes to the logger instead of cluttering stdout
    logger.debug(f"Shape of SSM score matrix: {np.shape(variant_scoress)}")
    fig, ax = plt.subplots(figsize=(30, 3))
    try:
        # Transposed so that positions run along x and amino acids along y
        ax.imshow(np.array(variant_scoress).T)
        for i_vss, vss in enumerate(variant_scoress):
            for i_vs, vs in enumerate(vss):
                ax.text(
                    i_vss, i_vs,
                    f'{variantss[i_vss][i_vs]}\n{round(vs, 1)}',
                    size=2, va='center', ha='center'
                )
        ax.set_xticks(
            range(len(wt_sequence)),
            [f'{aa}{i + 1}' for i, aa in enumerate(wt_sequence)],
            size=6, rotation=90
        )
        ax.set_yticks(range(len(aas)), aas, size=6)
        fig.tight_layout()
        fig.savefig('SSM_landscape.png', dpi=300)
    finally:
        # Release the figure even if plotting/saving fails; pyplot keeps
        # figures alive until they are closed explicitly
        plt.close(fig)
    pd.DataFrame(
        {
            'Variant': np.array(variantss).flatten(),
            'Sequence': np.array(variant_sequencess).flatten(),
            'Variant_Score': np.array(variant_scoress).flatten()
        }
    ).to_csv('SSM_landscape.csv', sep=',')
    logger.info(f"Saved SSM landscape as {os.path.abspath('SSM_landscape.png')} "
                f"and CSV data as {os.path.abspath('SSM_landscape.csv')}...")
+697,7 @@ def save_model_to_dict_pickle( }, open(f'Pickles/{model_type}', 'wb') ) + logger.info(f'Saved model as Pickle file ({pkl_path})...') global_model = None diff --git a/pypef/ml/regression.py b/pypef/ml/regression.py index 81b9c4a..52adf7c 100644 --- a/pypef/ml/regression.py +++ b/pypef/ml/regression.py @@ -959,10 +959,10 @@ def save_model( if model_type in ['PLMC', 'GREMLIN'] and encoding not in ['aaidx', 'onehot']: name = 'ML' + model_type.lower() f_name = os.path.abspath(os.path.join(path, 'Pickles', name)) - logger.info(f'Saving model ({f_name})...') file = open(f_name, 'wb') pickle.dump(regressor_, file) file.close() + logger.info(f'Saved model as {f_name}...') except IndexError: raise IndexError diff --git a/pypef/utils/low_n_mutation_extrapolation.py b/pypef/utils/low_n_mutation_extrapolation.py index a286cbe..bd5e773 100644 --- a/pypef/utils/low_n_mutation_extrapolation.py +++ b/pypef/utils/low_n_mutation_extrapolation.py @@ -289,8 +289,8 @@ def performance_mutation_extrapolation( logger.info('Fitting regressor on lvl 1 substitution data...') regressor.fit(x_train, y_train) if save_model: - logger.info(f'Saving model as Pickle file: ML_LVL_1') pickle.dump(regressor, open(os.path.join('Pickles', 'ML_LVL_1'), 'wb')) + logger.info(f'Saved model as Pickle file: ML_LVL_1') for i, _ in enumerate(tqdm(collected_levels)): if i < len(collected_levels) - 1: # not last i else error, last entry is: lvl 1 --> all higher variants test_idx = collected_levels[i + 1] diff --git a/pypef/utils/plot.py b/pypef/utils/plot.py index aaf50ad..ebbc7a0 100644 --- a/pypef/utils/plot.py +++ b/pypef/utils/plot.py @@ -87,6 +87,6 @@ def plot_y_true_vs_y_pred( # i += 1 # iterate until finding an unused file name # file_name = f'DCA_Hybrid_Model_LS_TS_Performance({i}).png' plt.colorbar() - logger.info(f'Saving plot ({os.path.abspath(file_name)})...') plt.savefig(file_name, dpi=500) plt.close('all') + logger.info(f'Saved plot as {os.path.abspath(file_name)}...')