From e685ea953dd24367d3223b7ecfd00d08c30a277f Mon Sep 17 00:00:00 2001 From: Kveta Brazdilova <40542984+kvetab@users.noreply.github.com> Date: Sun, 16 Jan 2022 17:30:51 +0100 Subject: [PATCH 1/2] Return embeddings from model --- biophi/humanization/methods/sapiens/roberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/biophi/humanization/methods/sapiens/roberta.py b/biophi/humanization/methods/sapiens/roberta.py index f634f76..598d02c 100644 --- a/biophi/humanization/methods/sapiens/roberta.py +++ b/biophi/humanization/methods/sapiens/roberta.py @@ -73,7 +73,7 @@ def predict_proba(self, seq, remove_special=True, return_all_hiddens=False): pred = pd.DataFrame(pred.numpy(), columns=self.interface.task.target_dictionary.symbols) if remove_special: pred.drop(['<s>', '<pad>', '</s>', '<unk>', '<mask>'], axis=1, inplace=True) - return pred + return pred, extra def _is_adding_bos(self): if isinstance(self.interface.task, SentencePredictionTask): From cd36c6237f19955495cd1ca6f3fe168bee8a1438 Mon Sep 17 00:00:00 2001 From: Kveta Brazdilova <40542984+kvetab@users.noreply.github.com> Date: Mon, 17 Jan 2022 09:41:15 +0100 Subject: [PATCH 2/2] Return embeddings only when specified --- biophi/humanization/methods/sapiens/roberta.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/biophi/humanization/methods/sapiens/roberta.py b/biophi/humanization/methods/sapiens/roberta.py index 598d02c..13dbbc6 100644 --- a/biophi/humanization/methods/sapiens/roberta.py +++ b/biophi/humanization/methods/sapiens/roberta.py @@ -73,7 +73,10 @@ def predict_proba(self, seq, remove_special=True, return_all_hiddens=False): pred = pd.DataFrame(pred.numpy(), columns=self.interface.task.target_dictionary.symbols) if remove_special: pred.drop(['<s>', '<pad>', '</s>', '<unk>', '<mask>'], axis=1, inplace=True) - return pred, extra + if return_all_hiddens: + return pred, extra + else: + return pred def _is_adding_bos(self): if isinstance(self.interface.task, SentencePredictionTask):