Inline the deprecated tf.layers import.
In principle these utilities still exist in Keras, but they are a pain to import, so it seemed simpler to just inline the few functions.

PiperOrigin-RevId: 565273232
SiegeLordEx authored and tensorflower-gardener committed Sep 14, 2023
1 parent e5f85c6 commit f4836cc
Showing 2 changed files with 122 additions and 14 deletions.
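
The change replaces the private TensorFlow import with module-level copies of the helpers at the bottom of conv_variational.py. A minimal before/after sketch of one call site (illustrative, mirroring the diff below):

# Before: private import from core TensorFlow
from tensorflow.python.layers import utils as tf_layers_util
self.kernel_size = tf_layers_util.normalize_tuple(kernel_size, rank, 'kernel_size')

# After: inlined helper defined in this module
self.kernel_size = normalize_tuple(kernel_size, rank, 'kernel_size')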
127 changes: 118 additions & 9 deletions tensorflow_probability/python/layers/conv_variational.py
@@ -23,7 +23,6 @@
 from tensorflow_probability.python.internal import docstring_util
 from tensorflow_probability.python.layers import util as tfp_layers_util
 from tensorflow_probability.python.util.seed_stream import SeedStream
-from tensorflow.python.layers import utils as tf_layers_util  # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.ops import nn_ops  # pylint: disable=g-direct-tensorflow-import


@@ -149,12 +148,12 @@ def __init__(
         **kwargs)
     self.rank = rank
     self.filters = filters
-    self.kernel_size = tf_layers_util.normalize_tuple(
+    self.kernel_size = normalize_tuple(
         kernel_size, rank, 'kernel_size')
-    self.strides = tf_layers_util.normalize_tuple(strides, rank, 'strides')
-    self.padding = tf_layers_util.normalize_padding(padding)
-    self.data_format = tf_layers_util.normalize_data_format(data_format)
-    self.dilation_rate = tf_layers_util.normalize_tuple(
+    self.strides = normalize_tuple(strides, rank, 'strides')
+    self.padding = normalize_padding(padding)
+    self.data_format = normalize_data_format(data_format)
+    self.dilation_rate = normalize_tuple(
         dilation_rate, rank, 'dilation_rate')
     self.activation = tf.keras.activations.get(activation)
     self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2)
@@ -216,7 +215,7 @@ def build(self, input_shape):
         dilation_rate=self.dilation_rate,
         strides=self.strides,
         padding=self.padding.upper(),
-        data_format=tf_layers_util.convert_data_format(
+        data_format=convert_data_format(
             self.data_format, self.rank + 2))

     self.built = True
@@ -256,7 +255,7 @@ def compute_output_shape(self, input_shape):
       space = input_shape[1:-1]
       new_space = []
       for i in range(len(space)):
-        new_dim = tf_layers_util.conv_output_length(
+        new_dim = conv_output_length(
             space[i],
             self.kernel_size[i],
             padding=self.padding,
@@ -268,7 +267,7 @@ def compute_output_shape(self, input_shape):
       space = input_shape[2:]
       new_space = []
      for i in range(len(space)):
-        new_dim = tf_layers_util.conv_output_length(
+        new_dim = conv_output_length(
             space[i],
             self.kernel_size[i],
             padding=self.padding,
@@ -1581,3 +1580,113 @@ def __init__(
 Convolution1DFlipout = Conv1DFlipout
 Convolution2DFlipout = Conv2DFlipout
 Convolution3DFlipout = Conv3DFlipout
+
+
+def convert_data_format(data_format, ndim):  # pylint: disable=missing-function-docstring
+  if data_format == 'channels_last':
+    if ndim == 3:
+      return 'NWC'
+    elif ndim == 4:
+      return 'NHWC'
+    elif ndim == 5:
+      return 'NDHWC'
+    else:
+      raise ValueError(f'Input rank: {ndim} not supported. We only support '
+                       'input rank 3, 4 or 5.')
+  elif data_format == 'channels_first':
+    if ndim == 3:
+      return 'NCW'
+    elif ndim == 4:
+      return 'NCHW'
+    elif ndim == 5:
+      return 'NCDHW'
+    else:
+      raise ValueError(f'Input rank: {ndim} not supported. We only support '
+                       'input rank 3, 4 or 5.')
+  else:
+    raise ValueError(f'Invalid data_format: {data_format}. We only support '
+                     '"channels_first" or "channels_last"')
+
+
+def normalize_tuple(value, n, name):
+  """Transforms a single integer or iterable of integers into an integer tuple.
+
+  Args:
+    value: The value to validate and convert. Could be an int, or any iterable
+      of ints.
+    n: The size of the tuple to be returned.
+    name: The name of the argument being validated, e.g. "strides" or
+      "kernel_size". This is only used to format error messages.
+
+  Returns:
+    A tuple of n integers.
+
+  Raises:
+    ValueError: If something other than an int/long or iterable thereof was
+      passed.
+  """
+  if isinstance(value, int):
+    return (value,) * n
+  else:
+    try:
+      value_tuple = tuple(value)
+    except TypeError:
+      raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} '
+                       f'integers. Received: {str(value)}') from None
+    if len(value_tuple) != n:
+      raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} '
+                       f'integers. Received: {str(value)}')
+    for single_value in value_tuple:
+      try:
+        int(single_value)
+      except (ValueError, TypeError):
+        raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} '
+                         f'integers. Received: {str(value)} including element '
+                         f'{str(single_value)} of type '
+                         f'{str(type(single_value))}') from None
+    return value_tuple
+
+
+def normalize_data_format(value):
+  data_format = value.lower()
+  if data_format not in {'channels_first', 'channels_last'}:
+    raise ValueError('The `data_format` argument must be one of '
+                     '"channels_first", "channels_last". Received: '
+                     f'{str(value)}.')
+  return data_format
+
+
+def normalize_padding(value):
+  padding = value.lower()
+  if padding not in {'valid', 'same'}:
+    raise ValueError('The `padding` argument must be one of "valid", "same". '
+                     f'Received: {str(padding)}.')
+  return padding
+
+
+def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
+  """Determines output length of a convolution given input length.
+
+  Args:
+    input_length: integer.
+    filter_size: integer.
+    padding: one of "same", "valid", "full".
+    stride: integer.
+    dilation: dilation rate, integer.
+
+  Returns:
+    The output length (integer).
+  """
+  if input_length is None:
+    return None
+  assert padding in {'same', 'valid', 'full'}
+  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+  if padding == 'same':
+    output_length = input_length
+  elif padding == 'valid':
+    output_length = input_length - dilated_filter_size + 1
+  elif padding == 'full':
+    output_length = input_length + dilated_filter_size - 1
+  else:
+    raise ValueError(f'Invalid padding: {padding}')
+  return (output_length + stride - 1) // stride
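
For orientation, here is a small sketch of what the inlined helpers return; the calls and values are illustrative examples, not part of the commit:

normalize_tuple(3, 2, 'kernel_size')     # -> (3, 3)
normalize_padding('SAME')                # -> 'same'
normalize_data_format('Channels_Last')   # -> 'channels_last'
convert_data_format('channels_last', 4)  # -> 'NHWC'
conv_output_length(32, filter_size=3, padding='valid', stride=2)  # -> (32 - 3 + 1 + 1) // 2 = 15
conv_output_length(32, filter_size=3, padding='same', stride=2)   # -> (32 + 1) // 2 = 16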
9 changes: 4 additions & 5 deletions tensorflow_probability/python/layers/conv_variational_test.py
@@ -30,7 +30,6 @@
 from tensorflow_probability.python.layers import util
 from tensorflow_probability.python.random import random_ops
 from tensorflow_probability.python.util import seed_stream
-from tensorflow.python.layers import utils as tf_layers_util
 from tensorflow.python.ops import nn_ops


@@ -369,13 +368,13 @@ def _testConvReparameterization(self, layer_class):  # pylint: disable=invalid-name
         tf.TensorShape(inputs.shape),
         filter_shape=tf.TensorShape(kernel_shape),
         padding='SAME',
-        data_format=tf_layers_util.convert_data_format(
+        data_format=conv_variational.convert_data_format(
             self.data_format, inputs.shape.rank))
     expected_outputs = convolution_op(inputs, kernel_posterior.result_sample)
     expected_outputs = tf.nn.bias_add(
         expected_outputs,
         bias_posterior.result_sample,
-        data_format=tf_layers_util.convert_data_format(self.data_format, 4))
+        data_format=conv_variational.convert_data_format(self.data_format, 4))

     [
         expected_outputs_, actual_outputs_,
@@ -435,7 +434,7 @@ def _testConvFlipout(self, layer_class):  # pylint: disable=invalid-name
         tf.TensorShape(inputs.shape),
         filter_shape=tf.TensorShape(kernel_shape),
         padding='SAME',
-        data_format=tf_layers_util.convert_data_format(
+        data_format=conv_variational.convert_data_format(
             self.data_format, inputs.shape.rank))

     expected_kernel_posterior_affine = normal.Normal(
@@ -483,7 +482,7 @@ def _testConvFlipout(self, layer_class):  # pylint: disable=invalid-name
     expected_outputs = tf.nn.bias_add(
         expected_outputs,
         bias_posterior.result_sample,
-        data_format=tf_layers_util.convert_data_format(self.data_format, 4))
+        data_format=conv_variational.convert_data_format(self.data_format, 4))

     [
         expected_outputs_, actual_outputs_,
