diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 16e9bc5d7c9ca4..b92a237282b6cf 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -76,14 +76,14 @@ def infer_framework_from_repr(x): Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the frameworks in a smart order, without the need to import the frameworks). """ - representation = repr(x) - if representation.startswith("tensor"): + representation = str(type(x)) + if representation.startswith("