Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Public function to register custom ops #1193

Merged
merged 15 commits into from
Mar 7, 2020
14 changes: 14 additions & 0 deletions tensorflow_addons/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ py_library(
name = "tensorflow_addons",
data = [
"__init__.py",
"register.py",
"version.py",
],
deps = [
Expand All @@ -24,5 +25,18 @@ py_library(
"//tensorflow_addons/rnn",
"//tensorflow_addons/seq2seq",
"//tensorflow_addons/text",
"//tensorflow_addons/utils",
],
)

py_test(
name = "register_test",
size = "small",
srcs = [
"register_test.py",
],
main = "register_test.py",
deps = [
":tensorflow_addons",
],
)
1 change: 1 addition & 0 deletions tensorflow_addons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
from tensorflow_addons import rnn
from tensorflow_addons import seq2seq
from tensorflow_addons import text
from tensorflow_addons.register import register_all

from tensorflow_addons.version import __version__
109 changes: 109 additions & 0 deletions tensorflow_addons/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import glob
import os
from pathlib import Path

import tensorflow as tf

from tensorflow_addons.utils.resource_loader import get_project_root


def register_all(keras_objects: bool = True, custom_kernels: bool = True) -> None:
gabrieldemarmiesse marked this conversation as resolved.
Show resolved Hide resolved
"""Register TensorFlow Addons' objects in TensorFlow global dictionaries.
When loading a Keras model that has a TF Addons' function, it is needed
for this function to be known by the Keras deserialization process.
There are two ways to do this, either do
```python
tf.keras.models.load_model(
"my_model.tf",
custom_objects={"LAMB": tfa.image.optimizer.LAMB}
)
```
or you can do:
```python
tfa.register_all()
tf.tf.keras.models.load_model("my_model.tf")
```
If the model contains custom ops (compiled ops) of TensorFlow Addons,
and the graph is loaded with `tf.saved_model.load`, then custom ops need
to be registered before to avoid an error of the type:
```
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered
'...' in binary running on ... Make sure the Op and Kernel are
registered in the binary running in this process.
```
In this case, the only way to make sure that the ops are registered is to call
this function:
```python
tfa.register_all()
tf.saved_model.load("my_model.tf")
```
Note that you can call this function multiple times in the same process,
it only has an effect the first time. Afterward, it's just a no-op.
Args:
keras_objects: boolean, `True` by default. If `True`, register all
Keras objects
with `tf.keras.utils.register_keras_serializable(package="Addons")`
If set to False, doesn't register any Keras objects
of Addons in TensorFlow.
custom_kernels: boolean, `True` by default. If `True`, loads all
custom kernels of TensorFlow Addons with
`tf.load_op_library("path/to/so/file.so")`. Loading the SO files
register them automatically. If `False` doesn't load and register
the shared objects files. Not that it might be useful to turn it off
if your installation of Addons doesn't work well with custom ops.
Returns:
None
"""
if keras_objects:
register_keras_objects()
if custom_kernels:
register_custom_kernels()


def register_keras_objects() -> None:
# TODO: once layer_test is replaced by a public API
# and we can used unregistered objects with it
# we can remove all decorators.
# And register Keras objects here.
pass


def register_custom_kernels() -> None:
all_shared_objects = _get_all_shared_objects()
if not all_shared_objects:
raise FileNotFoundError(
"No shared objects files were found in the custom ops "
"directory in Tensorflow Addons, check your installation again,"
"or, if you don't need custom ops, call `tfa.register_all(custom_kernels=False)`"
" instead."
)
try:
for shared_object in all_shared_objects:
tf.load_op_library(shared_object)
except tf.errors.NotFoundError as e:
raise RuntimeError(
"One of the shared objects ({}) could not be loaded. This may be "
"due to a number of reasons (incompatible TensorFlow version, buiding from "
"source with different flags, broken install of TensorFlow Addons...). If you"
"wanted to register the shared objects because you needed them when loading your "
"model, you should fix your install of TensorFlow Addons. If you don't "
"use custom ops in your model, you can skip registering custom ops with "
"`tfa.register_all(custom_kernels=False)`".format(shared_object)
) from e


def _get_all_shared_objects():
custom_ops_dir = os.path.join(get_project_root(), "custom_ops")
all_shared_objects = glob.glob(custom_ops_dir + "/**/*.so", recursive=True)
all_shared_objects = [x for x in all_shared_objects if Path(x).is_file()]
return all_shared_objects
23 changes: 23 additions & 0 deletions tensorflow_addons/register_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest
import tensorflow as tf
from tensorflow_addons.register import register_all, _get_all_shared_objects


class AssertRNNCellTest(unittest.TestCase):
def setUp(self):
pass

def test_multiple_register(self):
register_all()
register_all()

def test_get_all_shared_objects(self):
all_shared_objects = _get_all_shared_objects()
self.assertTrue(len(all_shared_objects) >= 4)

for file in all_shared_objects:
tf.load_op_library(file)


if __name__ == "__main__":
unittest.main()
22 changes: 13 additions & 9 deletions tools/ci_build/verify/check_typing_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
# limitations under the License.
# ==============================================================================
#

from types import ModuleType

from typedapi import ensure_api_is_typed

import tensorflow_addons
import tensorflow_addons as tfa

TUTORIAL_URL = "https://docs.python.org/3/library/typing.html"
HELP_MESSAGE = (
Expand All @@ -30,11 +27,18 @@
EXCEPTION_LIST = []


modules_list = []
for attr_name in dir(tensorflow_addons):
attr = getattr(tensorflow_addons, attr_name)
if isinstance(attr, ModuleType):
modules_list.append(attr)
modules_list = [
tfa,
tfa.activations,
tfa.callbacks,
tfa.image,
tfa.losses,
tfa.metrics,
tfa.optimizers,
tfa.rnn,
tfa.seq2seq,
tfa.text,
]


if __name__ == "__main__":
Expand Down