From 3dd8580d6ec9cd8813f10437db1b7a123b79aa2c Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 12 Oct 2023 08:22:29 +0200 Subject: [PATCH] unify to pt --- doctr/models/recognition/master/tensorflow.py | 4 +++- doctr/models/recognition/parseq/tensorflow.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 4985f1d981..bbae216f74 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -181,7 +181,9 @@ def call( output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs) logits = self.linear(output, **kwargs) else: - logits = _bf16_to_numpy_dtype(self.decode(encoded, **kwargs)) + logits = self.decode(encoded, **kwargs) + + logits = _bf16_to_numpy_dtype(logits) if self.exportable: out["logits"] = logits diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 95cc6d03b5..8ef77af4cd 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -388,7 +388,9 @@ def call( ) ) else: - logits = _bf16_to_numpy_dtype(self.decode_autoregressive(features, **kwargs)) + logits = self.decode_autoregressive(features, **kwargs) + + logits = _bf16_to_numpy_dtype(logits) out: Dict[str, tf.Tensor] = {} if self.exportable: