Skip to content

Commit

Permalink
Fix keras.KerasTensor availability issue
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 21, 2024
1 parent 96f370f commit e700958
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
12 changes: 5 additions & 7 deletions lab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def _module_attr(module, attr):
# Define TensorFlow module types.
_tf_tensor = ModuleType("tensorflow", "Tensor")
_tf_indexedslices = ModuleType("tensorflow", "IndexedSlices")
# On Python 3.9 and higher, we also need to support `keras.KerasTensor`.
if sys.version_info >= (3, 9):
_tf_kerastensor = ModuleType("keras", "KerasTensor")
# `keras.KerasTensor` is only available on newer versions of `keras`. Instead of
# determining exactly when it is available, we simply allow the retrieval to fail.
# TODO: Set `allow_fail=False` in the future.
_tf_kerastensor = ModuleType("keras", "KerasTensor", allow_fail=True)
_tf_variable = ModuleType("tensorflow", "Variable")
_tf_dtype = ModuleType("tensorflow", "DType")
_tf_randomstate = ModuleType("tensorflow", "random.Generator")
Expand Down Expand Up @@ -109,10 +110,7 @@ def _module_attr(module, attr):
NPNumeric = set_union_alias(NPNumeric, "B.NPNumeric")
AGNumeric = Union[_ag_tensor]
AGNumeric = set_union_alias(AGNumeric, "B.AGNumeric")
if sys.version_info >= (3, 9):
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor]
else:
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices]
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor]
TFNumeric = set_union_alias(TFNumeric, "B.TFNumeric")
TorchNumeric = Union[_torch_tensor]
TorchNumeric = set_union_alias(TorchNumeric, "B.TorchNumeric")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"numpy>=1.16",
"scipy>=1.3",
"fdm",
"plum-dispatch>=2.3.2",
"plum-dispatch>=2.3.5",
"opt-einsum",
]

Expand Down

0 comments on commit e700958

Please sign in to comment.