Skip to content

Commit

Permalink
Support keras.KerasTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 21, 2024
1 parent 952a8a9 commit f16beb7
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion lab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _module_attr(module, attr):
# Define TensorFlow module types.
_tf_tensor = ModuleType("tensorflow", "Tensor")
_tf_indexedslices = ModuleType("tensorflow", "IndexedSlices")
_tf_kerastensor = ModuleType("keras", "KerasTensor")
_tf_variable = ModuleType("tensorflow", "Variable")
_tf_dtype = ModuleType("tensorflow", "DType")
_tf_randomstate = ModuleType("tensorflow", "random.Generator")
Expand Down Expand Up @@ -106,7 +107,7 @@ def _module_attr(module, attr):
NPNumeric = set_union_alias(NPNumeric, "B.NPNumeric")
AGNumeric = Union[_ag_tensor]
AGNumeric = set_union_alias(AGNumeric, "B.AGNumeric")
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices]
TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor]
TFNumeric = set_union_alias(TFNumeric, "B.TFNumeric")
TorchNumeric = Union[_torch_tensor]
TorchNumeric = set_union_alias(TorchNumeric, "B.TorchNumeric")
Expand Down

0 comments on commit f16beb7

Please sign in to comment.