diff --git a/biophi/humanization/methods/sapiens/roberta.py b/biophi/humanization/methods/sapiens/roberta.py index f634f76..13dbbc6 100644 --- a/biophi/humanization/methods/sapiens/roberta.py +++ b/biophi/humanization/methods/sapiens/roberta.py @@ -73,7 +73,10 @@ def predict_proba(self, seq, remove_special=True, return_all_hiddens=False): pred = pd.DataFrame(pred.numpy(), columns=self.interface.task.target_dictionary.symbols) if remove_special: pred.drop(['', '', '', '', ''], axis=1, inplace=True) - return pred + if return_all_hiddens: + return pred, extra + else: + return pred def _is_adding_bos(self): if isinstance(self.interface.task, SentencePredictionTask):