Skip to content

Commit

Permalink
Lazy loading of shared libraries. (#869)
Browse files Browse the repository at this point in the history
* Added lazy loading
  • Loading branch information
gabrieldemarmiesse authored and seanpmorgan committed Jan 14, 2020
1 parent 0a23fd2 commit 9e13d30
Show file tree
Hide file tree
Showing 18 changed files with 77 additions and 81 deletions.
9 changes: 4 additions & 5 deletions tensorflow_addons/activations/gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
_activation_so = LazySO("custom_ops/activations/_activation_ops.so")


@tf.keras.utils.register_keras_serializable(package='Addons')
Expand All @@ -44,10 +43,10 @@ def gelu(x, approximate=True):
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_gelu(x, approximate)
return _activation_so.ops.addons_gelu(x, approximate)


@tf.RegisterGradient("Addons>Gelu")
def _gelu_grad(op, grad):
return _activation_ops_so.addons_gelu_grad(grad, op.inputs[0],
return _activation_so.ops.addons_gelu_grad(grad, op.inputs[0],
op.get_attr("approximate"))
9 changes: 4 additions & 5 deletions tensorflow_addons/activations/hardshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
_activation_so = LazySO("custom_ops/activations/_activation_ops.so")


@tf.keras.utils.register_keras_serializable(package='Addons')
Expand All @@ -40,11 +39,11 @@ def hardshrink(x, lower=-0.5, upper=0.5):
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_hardshrink(x, lower, upper)
return _activation_so.ops.addons_hardshrink(x, lower, upper)


@tf.RegisterGradient("Addons>Hardshrink")
def _hardshrink_grad(op, grad):
return _activation_ops_so.addons_hardshrink_grad(grad, op.inputs[0],
return _activation_so.ops.addons_hardshrink_grad(grad, op.inputs[0],
op.get_attr("lower"),
op.get_attr("upper"))
9 changes: 4 additions & 5 deletions tensorflow_addons/activations/lisht.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
_activation_so = LazySO("custom_ops/activations/_activation_ops.so")


@tf.keras.utils.register_keras_serializable(package='Addons')
Expand All @@ -39,9 +38,9 @@ def lisht(x):
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_lisht(x)
return _activation_so.ops.addons_lisht(x)


@tf.RegisterGradient("Addons>Lisht")
def _lisht_grad(op, grad):
return _activation_ops_so.addons_lisht_grad(grad, op.inputs[0])
return _activation_so.ops.addons_lisht_grad(grad, op.inputs[0])
9 changes: 4 additions & 5 deletions tensorflow_addons/activations/mish.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
_activation_so = LazySO("custom_ops/activations/_activation_ops.so")


@tf.keras.utils.register_keras_serializable(package='Addons')
Expand All @@ -39,9 +38,9 @@ def mish(x):
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_mish(x)
return _activation_so.ops.addons_mish(x)


@tf.RegisterGradient("Addons>Mish")
def _mish_grad(op, grad):
return _activation_ops_so.addons_mish_grad(grad, op.inputs[0])
return _activation_so.ops.addons_mish_grad(grad, op.inputs[0])
9 changes: 4 additions & 5 deletions tensorflow_addons/activations/softshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
_activation_so = LazySO("custom_ops/activations/_activation_ops.so")


@tf.keras.utils.register_keras_serializable(package='Addons')
Expand All @@ -40,11 +39,11 @@ def softshrink(x, lower=-0.5, upper=0.5):
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_softshrink(x, lower, upper)
return _activation_so.ops.addons_softshrink(x, lower, upper)


@tf.RegisterGradient("Addons>Softshrink")
def _softshrink_grad(op, grad):
return _activation_ops_so.addons_softshrink_grad(grad, op.inputs[0],
return _activation_so.ops.addons_softshrink_grad(grad, op.inputs[0],
op.get_attr("lower"),
op.get_attr("upper"))
9 changes: 4 additions & 5 deletions tensorflow_addons/activations/tanhshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_activation_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
_activation_so = LazySO("custom_ops/activations/_activation_ops.so")


@tf.keras.utils.register_keras_serializable(package='Addons')
Expand All @@ -35,9 +34,9 @@ def tanhshrink(x):
A `Tensor`. Has the same type as `features`.
"""
x = tf.convert_to_tensor(x)
return _activation_ops_so.addons_tanhshrink(x)
return _activation_so.ops.addons_tanhshrink(x)


@tf.RegisterGradient("Addons>Tanhshrink")
def _tanhshrink_grad(op, grad):
return _activation_ops_so.addons_tanhshrink_grad(grad, op.inputs[0])
return _activation_so.ops.addons_tanhshrink_grad(grad, op.inputs[0])
7 changes: 3 additions & 4 deletions tensorflow_addons/image/connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@

import tensorflow as tf

from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_image_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/image/_image_ops.so"))
_image_so = LazySO("custom_ops/image/_image_ops.so")


@tf.function
Expand Down Expand Up @@ -62,7 +61,7 @@ def connected_components(images, name=None):
raise TypeError(
"images should have rank 2 (HW) or 3 (NHW). Static shape is %s"
% image_or_images.get_shape())
components = _image_ops_so.addons_image_connected_components(images)
components = _image_so.ops.addons_image_connected_components(images)

# TODO(ringwalt): Component id renaming should be done in the op,
# to avoid constructing multiple additional large tensors.
Expand Down
7 changes: 3 additions & 4 deletions tensorflow_addons/image/distance_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@

import tensorflow as tf
from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_image_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/image/_image_ops.so"))
_image_so = LazySO("custom_ops/image/_image_ops.so")

tf.no_gradient("Addons>EuclideanDistanceTransform")

Expand Down Expand Up @@ -64,6 +63,6 @@ def euclidean_dist_transform(images, dtype=tf.float32, name=None):
raise TypeError("`dtype` must be float16, float32 or float64")

images = tf.cast(images, dtype)
output = _image_ops_so.addons_euclidean_distance_transform(images)
output = _image_so.ops.addons_euclidean_distance_transform(images)

return img_utils.from_4D_image(output, original_ndims)
7 changes: 3 additions & 4 deletions tensorflow_addons/image/distort_image_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_distort_image_ops = tf.load_op_library(
get_path_to_datafile("custom_ops/image/_distort_image_ops.so"))
_distort_image_so = LazySO("custom_ops/image/_distort_image_ops.so")


# pylint: disable=invalid-name
Expand Down Expand Up @@ -141,7 +140,7 @@ def adjust_hsv_in_yiq(image,
orig_dtype = image.dtype
flt_image = tf.image.convert_image_dtype(image, tf.dtypes.float32)

rgb_altered = _distort_image_ops.addons_adjust_hsv_in_yiq(
rgb_altered = _distort_image_so.ops.addons_adjust_hsv_in_yiq(
flt_image, delta_hue, scale_saturation, scale_value)

return tf.image.convert_image_dtype(rgb_altered, orig_dtype)
10 changes: 5 additions & 5 deletions tensorflow_addons/image/resampler_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_resampler_ops = tf.load_op_library(
get_path_to_datafile("custom_ops/image/_resampler_ops.so"))
_resampler_so = LazySO("custom_ops/image/_resampler_ops.so")


@tf.function
Expand Down Expand Up @@ -52,14 +51,15 @@ def resampler(data, warp, name=None):
with tf.name_scope(name or "resampler"):
data_tensor = tf.convert_to_tensor(data, name="data")
warp_tensor = tf.convert_to_tensor(warp, name="warp")
return _resampler_ops.addons_resampler(data_tensor, warp_tensor)
return _resampler_so.ops.addons_resampler(data_tensor, warp_tensor)


@tf.RegisterGradient("Addons>Resampler")
def _resampler_grad(op, grad_output):
data, warp = op.inputs
grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output")
return _resampler_ops.addons_resampler_grad(data, warp, grad_output_tensor)
return _resampler_so.ops.addons_resampler_grad(data, warp,
grad_output_tensor)


tf.no_gradient("Addons>ResamplerGrad")
9 changes: 4 additions & 5 deletions tensorflow_addons/image/transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@

import tensorflow as tf
from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_image_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/image/_image_ops.so"))
_image_so = LazySO("custom_ops/image/_image_ops.so")

_IMAGE_DTYPES = set([
tf.dtypes.uint8, tf.dtypes.int32, tf.dtypes.int64, tf.dtypes.float16,
Expand Down Expand Up @@ -98,7 +97,7 @@ def transform(images,
"transforms should have rank 1 or 2, but got rank %d" % len(
transforms.get_shape()))

output = _image_ops_so.addons_image_projective_transform_v2(
output = _image_so.ops.addons_image_projective_transform_v2(
images,
output_shape=output_shape,
transforms=transforms,
Expand Down Expand Up @@ -270,7 +269,7 @@ def _image_projective_transform_grad(op, grad):
transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = tf.linalg.inv(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = _image_ops_so.addons_image_projective_transform_v2(
output = _image_so.ops.addons_image_projective_transform_v2(
images=grad,
transforms=transforms,
output_shape=tf.shape(image_or_images)[1:3],
Expand Down
10 changes: 5 additions & 5 deletions tensorflow_addons/layers/optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_correlation_cost_op_so = tf.load_op_library(
get_path_to_datafile("custom_ops/layers/_correlation_cost_ops.so"))
_correlation_cost_so = LazySO(
"custom_ops/layers/_correlation_cost_ops.so")


def _correlation_cost(input_a,
Expand Down Expand Up @@ -81,7 +81,7 @@ def _correlation_cost(input_a,
"""

with tf.name_scope(name or "correlation_cost"):
op_call = _correlation_cost_op_so.addons_correlation_cost
op_call = _correlation_cost_so.ops.addons_correlation_cost

if data_format == "channels_last":
op_data_format = "NHWC"
Expand Down Expand Up @@ -120,7 +120,7 @@ def _correlation_cost_grad(op, grad_output):
input_b = tf.convert_to_tensor(op.inputs[1], name="input_b")
grad_output_tensor = tf.convert_to_tensor(grad_output, name="grad_output")

op_call = _correlation_cost_op_so.addons_correlation_cost_grad
op_call = _correlation_cost_so.ops.addons_correlation_cost_grad
grads = op_call(
input_a,
input_b,
Expand Down
10 changes: 6 additions & 4 deletions tensorflow_addons/seq2seq/beam_search_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
from tensorflow_addons.seq2seq import attention_wrapper
from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_beam_search_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/seq2seq/_beam_search_ops.so"))
gather_tree = _beam_search_ops_so.addons_gather_tree
_beam_search_so = LazySO("custom_ops/seq2seq/_beam_search_ops.so")


def gather_tree(*args, **kwargs):
return _beam_search_so.ops.addons_gather_tree(*args, **kwargs)


class BeamSearchDecoderState(
Expand Down
7 changes: 1 addition & 6 deletions tensorflow_addons/seq2seq/beam_search_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@
import tensorflow as tf

from tensorflow_addons.seq2seq import attention_wrapper
from tensorflow_addons.seq2seq import beam_search_decoder
from tensorflow_addons.seq2seq import beam_search_decoder, gather_tree
from tensorflow_addons.utils import test_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_beam_search_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/seq2seq/_beam_search_ops.so"))
gather_tree = _beam_search_ops_so.addons_gather_tree


class TestGatherTree(tf.test.TestCase):
Expand Down
6 changes: 1 addition & 5 deletions tensorflow_addons/seq2seq/beam_search_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@
import numpy as np
import tensorflow as tf

from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_beam_search_ops_so = tf.load_op_library(
get_path_to_datafile("custom_ops/seq2seq/_beam_search_ops.so"))
gather_tree = _beam_search_ops_so.addons_gather_tree
from tensorflow_addons.seq2seq import gather_tree


def _transpose_batch_time(x):
Expand Down
9 changes: 4 additions & 5 deletions tensorflow_addons/text/parse_time_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@

import tensorflow as tf

from tensorflow_addons.utils.resource_loader import get_path_to_datafile
from tensorflow_addons.utils.resource_loader import LazySO

_parse_time_op = tf.load_op_library(
get_path_to_datafile("custom_ops/text/_parse_time_op.so"))
_parse_time_so = LazySO("custom_ops/text/_parse_time_op.so")

tf.no_gradient("Addons>ParseTime")

Expand Down Expand Up @@ -82,5 +81,5 @@ def parse_time(time_string, time_format, output_unit):
ValueError: If `output_unit` is not a valid value,
if parsing `time_string` according to `time_format` failed.
"""
return _parse_time_op.addons_parse_time(time_string, time_format,
output_unit)
return _parse_time_so.ops.addons_parse_time(time_string, time_format,
output_unit)
Loading

0 comments on commit 9e13d30

Please sign in to comment.