
Commit

rename
felixdittrich92 committed Oct 11, 2023
1 parent bef5926 commit bc46f3a
Showing 9 changed files with 19 additions and 29 deletions.
@@ -15,12 +15,7 @@
from tensorflow.keras.applications import ResNet50

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_numpy_dtype_converter,
conv_sequence,
load_pretrained_params,
)
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params
from doctr.utils.repr import NestedObject

from ...classification import mobilenet_v3_large
@@ -246,7 +241,7 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = _bf16_numpy_dtype_converter(tf.math.sigmoid(logits))
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
9 changes: 2 additions & 7 deletions doctr/models/detection/linknet/tensorflow.py
@@ -15,12 +15,7 @@

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_numpy_dtype_converter,
conv_sequence,
load_pretrained_params,
)
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params
from doctr.utils.repr import NestedObject

from .base import LinkNetPostProcessor, _LinkNet
@@ -234,7 +229,7 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = _bf16_numpy_dtype_converter(tf.math.sigmoid(logits))
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
4 changes: 2 additions & 2 deletions doctr/models/recognition/crnn/tensorflow.py
@@ -13,7 +13,7 @@
from doctr.datasets import VOCABS

from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -199,7 +199,7 @@ def call(
w, h, c = transposed_feat.get_shape().as_list()[1:]
# B x W x H x C --> B x W x H * C
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
logits = _bf16_numpy_dtype_converter(self.decoder(features_seq, **kwargs))
logits = _bf16_to_numpy_dtype(self.decoder(features_seq, **kwargs))

out: Dict[str, tf.Tensor] = {}
if self.exportable:
4 changes: 2 additions & 2 deletions doctr/models/recognition/master/tensorflow.py
@@ -13,7 +13,7 @@
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
@@ -181,7 +181,7 @@ def call(
output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs)
logits = self.linear(output, **kwargs)
else:
logits = _bf16_numpy_dtype_converter(self.decode(encoded, **kwargs))
logits = _bf16_to_numpy_dtype(self.decode(encoded, **kwargs))

if self.exportable:
out["logits"] = logits
4 changes: 2 additions & 2 deletions doctr/models/recognition/parseq/tensorflow.py
@@ -16,7 +16,7 @@
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

from ...classification import vit_s
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
@@ -388,7 +388,7 @@ def call(
)
)
else:
logits = _bf16_numpy_dtype_converter(self.decode_autoregressive(features, **kwargs))
logits = _bf16_to_numpy_dtype(self.decode_autoregressive(features, **kwargs))

out: Dict[str, tf.Tensor] = {}
if self.exportable:
4 changes: 2 additions & 2 deletions doctr/models/recognition/sar/tensorflow.py
@@ -13,7 +13,7 @@
from doctr.utils.repr import NestedObject

from ...classification import resnet31
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]
@@ -316,7 +316,7 @@ def call(
if kwargs.get("training", False) and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")

decoded_features = _bf16_numpy_dtype_converter(
decoded_features = _bf16_to_numpy_dtype(
self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
)

4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/tensorflow.py
@@ -12,7 +12,7 @@
from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -131,7 +131,7 @@ def call(
logits = tf.reshape(
self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
) # (batch_size, max_length, vocab + 1)
decoded_features = _bf16_numpy_dtype_converter(logits[:, 1:]) # remove cls_token
decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token

out: Dict[str, tf.Tensor] = {}
if self.exportable:
4 changes: 2 additions & 2 deletions doctr/models/utils/tensorflow.py
@@ -23,15 +23,15 @@
"IntermediateLayerGetter",
"export_model_to_onnx",
"_copy_tensor",
"_bf16_numpy_dtype_converter",
"_bf16_to_numpy_dtype",
]


def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
return tf.identity(x)


def _bf16_numpy_dtype_converter(x: tf.Tensor) -> tf.Tensor:
def _bf16_to_numpy_dtype(x: tf.Tensor) -> tf.Tensor:
# Convert bfloat16 to float32 for numpy compatibility
return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x

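For context, a minimal usage sketch of the renamed helper (illustrative only, not part of this commit; the tensor shape is made up): models running under a bfloat16 mixed-precision policy emit bf16 tensors, and NumPy has no first-class bfloat16 dtype, so outputs are cast back to float32 before any .numpy() call in post-processing.

import tensorflow as tf

from doctr.models.utils import _bf16_to_numpy_dtype

# Simulate a model output produced under a bfloat16 mixed-precision policy
logits = tf.random.uniform(shape=(1, 32, 128), minval=-1, maxval=1, dtype=tf.bfloat16)

# Cast back to float32 so NumPy-based post-processing receives a supported dtype
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))
assert prob_map.dtype == tf.float32
np_map = prob_map.numpy()  # now a plain float32 NumPy array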
6 changes: 3 additions & 3 deletions tests/tensorflow/test_models_utils_tf.py
@@ -7,7 +7,7 @@

from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_numpy_dtype_converter,
_bf16_to_numpy_dtype,
_copy_tensor,
conv_sequence,
load_pretrained_params,
@@ -20,9 +20,9 @@ def test_copy_tensor():
assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x))


def test_bf16_numpy_dtype_converter():
def test_bf16_to_numpy_dtype():
x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16)
m = _bf16_numpy_dtype_converter(x)
m = _bf16_to_numpy_dtype(x)
assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32)))


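Assuming a standard development checkout, the renamed test can be run on its own via pytest's node-id selection, e.g. pytest tests/tensorflow/test_models_utils_tf.py::test_bf16_to_numpy_dtype.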
