diff --git a/lab/types.py b/lab/types.py index 52bcc01..f274a3a 100644 --- a/lab/types.py +++ b/lab/types.py @@ -68,6 +68,7 @@ def _module_attr(module, attr): # Define TensorFlow module types. _tf_tensor = ModuleType("tensorflow", "Tensor") _tf_indexedslices = ModuleType("tensorflow", "IndexedSlices") +_tf_kerastensor = ModuleType("keras", "KerasTensor") _tf_variable = ModuleType("tensorflow", "Variable") _tf_dtype = ModuleType("tensorflow", "DType") _tf_randomstate = ModuleType("tensorflow", "random.Generator") @@ -106,7 +107,7 @@ def _module_attr(module, attr): NPNumeric = set_union_alias(NPNumeric, "B.NPNumeric") AGNumeric = Union[_ag_tensor] AGNumeric = set_union_alias(AGNumeric, "B.AGNumeric") -TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices] +TFNumeric = Union[_tf_tensor, _tf_variable, _tf_indexedslices, _tf_kerastensor] TFNumeric = set_union_alias(TFNumeric, "B.TFNumeric") TorchNumeric = Union[_torch_tensor] TorchNumeric = set_union_alias(TorchNumeric, "B.TorchNumeric")