Skip to content

Commit

Permalink
Public function to register custom ops (tensorflow#1193)
Browse files Browse the repository at this point in the history
* Added public functions to register everything.

* Removed decorator

* Revert "Removed decorator"

This reverts commit ebea5bd.

* Added some tests.

* Added the two register.

* Removed unused variables.

* Private func.

* Explicit modules.

* FLake8

* Added documentation.

* Remove useless setup method.

* Black/

* Format BUILD.
  • Loading branch information
gabrieldemarmiesse authored Mar 7, 2020
1 parent 2029487 commit 90a5810
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 9 deletions.
14 changes: 14 additions & 0 deletions tensorflow_addons/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ py_library(
data = [
"__init__.py",
"options.py",
"register.py",
"version.py",
],
deps = [
Expand All @@ -25,5 +26,18 @@ py_library(
"//tensorflow_addons/rnn",
"//tensorflow_addons/seq2seq",
"//tensorflow_addons/text",
"//tensorflow_addons/utils",
],
)

py_test(
name = "register_test",
size = "small",
srcs = [
"register_test.py",
],
main = "register_test.py",
deps = [
":tensorflow_addons",
],
)
1 change: 1 addition & 0 deletions tensorflow_addons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
from tensorflow_addons import rnn
from tensorflow_addons import seq2seq
from tensorflow_addons import text
from tensorflow_addons.register import register_all

from tensorflow_addons.version import __version__
109 changes: 109 additions & 0 deletions tensorflow_addons/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import glob
import os
from pathlib import Path

import tensorflow as tf

from tensorflow_addons.utils.resource_loader import get_project_root


def register_all(keras_objects: bool = True, custom_kernels: bool = True) -> None:
"""Register TensorFlow Addons' objects in TensorFlow global dictionaries.
When loading a Keras model that has a TF Addons' function, it is needed
for this function to be known by the Keras deserialization process.
There are two ways to do this, either do
```python
tf.keras.models.load_model(
"my_model.tf",
custom_objects={"LAMB": tfa.image.optimizer.LAMB}
)
```
or you can do:
```python
tfa.register_all()
tf.tf.keras.models.load_model("my_model.tf")
```
If the model contains custom ops (compiled ops) of TensorFlow Addons,
and the graph is loaded with `tf.saved_model.load`, then custom ops need
to be registered before to avoid an error of the type:
```
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered
'...' in binary running on ... Make sure the Op and Kernel are
registered in the binary running in this process.
```
In this case, the only way to make sure that the ops are registered is to call
this function:
```python
tfa.register_all()
tf.saved_model.load("my_model.tf")
```
Note that you can call this function multiple times in the same process,
it only has an effect the first time. Afterward, it's just a no-op.
Args:
keras_objects: boolean, `True` by default. If `True`, register all
Keras objects
with `tf.keras.utils.register_keras_serializable(package="Addons")`
If set to False, doesn't register any Keras objects
of Addons in TensorFlow.
custom_kernels: boolean, `True` by default. If `True`, loads all
custom kernels of TensorFlow Addons with
`tf.load_op_library("path/to/so/file.so")`. Loading the SO files
register them automatically. If `False` doesn't load and register
the shared objects files. Not that it might be useful to turn it off
if your installation of Addons doesn't work well with custom ops.
Returns:
None
"""
if keras_objects:
register_keras_objects()
if custom_kernels:
register_custom_kernels()


def register_keras_objects() -> None:
# TODO: once layer_test is replaced by a public API
# and we can used unregistered objects with it
# we can remove all decorators.
# And register Keras objects here.
pass


def register_custom_kernels() -> None:
all_shared_objects = _get_all_shared_objects()
if not all_shared_objects:
raise FileNotFoundError(
"No shared objects files were found in the custom ops "
"directory in Tensorflow Addons, check your installation again,"
"or, if you don't need custom ops, call `tfa.register_all(custom_kernels=False)`"
" instead."
)
try:
for shared_object in all_shared_objects:
tf.load_op_library(shared_object)
except tf.errors.NotFoundError as e:
raise RuntimeError(
"One of the shared objects ({}) could not be loaded. This may be "
"due to a number of reasons (incompatible TensorFlow version, buiding from "
"source with different flags, broken install of TensorFlow Addons...). If you"
"wanted to register the shared objects because you needed them when loading your "
"model, you should fix your install of TensorFlow Addons. If you don't "
"use custom ops in your model, you can skip registering custom ops with "
"`tfa.register_all(custom_kernels=False)`".format(shared_object)
) from e


def _get_all_shared_objects():
custom_ops_dir = os.path.join(get_project_root(), "custom_ops")
all_shared_objects = glob.glob(custom_ops_dir + "/**/*.so", recursive=True)
all_shared_objects = [x for x in all_shared_objects if Path(x).is_file()]
return all_shared_objects
20 changes: 20 additions & 0 deletions tensorflow_addons/register_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
import tensorflow as tf
from tensorflow_addons.register import register_all, _get_all_shared_objects


class AssertRNNCellTest(unittest.TestCase):
def test_multiple_register(self):
register_all()
register_all()

def test_get_all_shared_objects(self):
all_shared_objects = _get_all_shared_objects()
self.assertTrue(len(all_shared_objects) >= 4)

for file in all_shared_objects:
tf.load_op_library(file)


if __name__ == "__main__":
unittest.main()
22 changes: 13 additions & 9 deletions tools/ci_build/verify/check_typing_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
# limitations under the License.
# ==============================================================================
#

from types import ModuleType

from typedapi import ensure_api_is_typed

import tensorflow_addons
import tensorflow_addons as tfa

TUTORIAL_URL = "https://docs.python.org/3/library/typing.html"
HELP_MESSAGE = (
Expand All @@ -30,11 +27,18 @@
EXCEPTION_LIST = []


modules_list = []
for attr_name in dir(tensorflow_addons):
attr = getattr(tensorflow_addons, attr_name)
if isinstance(attr, ModuleType) and attr is not tensorflow_addons.options:
modules_list.append(attr)
modules_list = [
tfa,
tfa.activations,
tfa.callbacks,
tfa.image,
tfa.losses,
tfa.metrics,
tfa.optimizers,
tfa.rnn,
tfa.seq2seq,
tfa.text,
]


if __name__ == "__main__":
Expand Down

0 comments on commit 90a5810

Please sign in to comment.