From e685ea953dd24367d3223b7ecfd00d08c30a277f Mon Sep 17 00:00:00 2001
From: Kveta Brazdilova <40542984+kvetab@users.noreply.github.com>
Date: Sun, 16 Jan 2022 17:30:51 +0100
Subject: [PATCH 1/2] Return embeddings from model

---
 biophi/humanization/methods/sapiens/roberta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/biophi/humanization/methods/sapiens/roberta.py b/biophi/humanization/methods/sapiens/roberta.py
index f634f76..598d02c 100644
--- a/biophi/humanization/methods/sapiens/roberta.py
+++ b/biophi/humanization/methods/sapiens/roberta.py
@@ -73,7 +73,7 @@ def predict_proba(self, seq, remove_special=True, return_all_hiddens=False):
         pred = pd.DataFrame(pred.numpy(), columns=self.interface.task.target_dictionary.symbols)
         if remove_special:
             pred.drop(['<s>', '<pad>', '</s>', '<unk>', '<mask>'], axis=1, inplace=True)
-        return pred
+        return pred, extra
 
     def _is_adding_bos(self):
         if isinstance(self.interface.task, SentencePredictionTask):

From cd36c6237f19955495cd1ca6f3fe168bee8a1438 Mon Sep 17 00:00:00 2001
From: Kveta Brazdilova <40542984+kvetab@users.noreply.github.com>
Date: Mon, 17 Jan 2022 09:41:15 +0100
Subject: [PATCH 2/2] Return embeddings only when specified

---
 biophi/humanization/methods/sapiens/roberta.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/biophi/humanization/methods/sapiens/roberta.py b/biophi/humanization/methods/sapiens/roberta.py
index 598d02c..13dbbc6 100644
--- a/biophi/humanization/methods/sapiens/roberta.py
+++ b/biophi/humanization/methods/sapiens/roberta.py
@@ -73,7 +73,10 @@ def predict_proba(self, seq, remove_special=True, return_all_hiddens=False):
         pred = pd.DataFrame(pred.numpy(), columns=self.interface.task.target_dictionary.symbols)
         if remove_special:
             pred.drop(['<s>', '<pad>', '</s>', '<unk>', '<mask>'], axis=1, inplace=True)
-        return pred, extra
+        if return_all_hiddens:
+            return pred, extra
+        else:
+            return pred
 
     def _is_adding_bos(self):
         if isinstance(self.interface.task, SentencePredictionTask):