Merge branch 'main' of github.com:keras-team/keras-core
Showing 12 changed files with 603 additions and 209 deletions.
@@ -1,155 +1,254 @@
import warnings


def _convert_conv_tranpose_padding_args_from_keras_to_jax(
    kernel_size, stride, dilation_rate, padding, output_padding
):
"""Computes output size of a transposed convolution given input size.""" | ||
assert padding in {"same", "valid"} | ||
if input_length is None: | ||
return None | ||
|
||
# Get the dilated kernel size | ||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) | ||
|
||
# Infer length if output padding is None, else compute the exact length | ||
if output_padding is None: | ||
if padding == "valid": | ||
length = input_length * stride + max(kernel_size - stride, 0) | ||
else: | ||
length = input_length * stride | ||
"""Convert the padding arguments from Keras to the ones used by JAX. | ||
JAX starts with an shape of size `(input-1) * stride - kernel_size + 2`, | ||
then adds `left_pad` on the left, and `right_pad` on the right. | ||
In Keras, the `padding` argument determines a base shape, to which | ||
`output_padding` is added on the right. If `output_padding` is None, it will | ||
be given a default value. | ||
""" | ||
    assert padding.lower() in {"valid", "same"}
    kernel_size = (kernel_size - 1) * dilation_rate + 1

    if padding.lower() == "valid":
        # If output_padding is None, we fill it so that the shape of the
        # output is `(input-1)*s + max(kernel_size, stride)`.
        output_padding = (
            max(kernel_size, stride) - kernel_size
            if output_padding is None
            else output_padding
        )
        left_pad = kernel_size - 1
        right_pad = kernel_size - 1 + output_padding
    else:
        if output_padding is None:
            # When output_padding is None, we want the shape of the output
            # to be `input * s`, therefore a total padding of
            # `stride + kernel_size - 2`.
            pad_len = stride + kernel_size - 2
        else:
            # When output_padding is given, we want the shape of the output
            # to be `(input-1)*stride + kernel_size%2 + output_padding`.
            pad_len = kernel_size + kernel_size % 2 - 2 + output_padding
        left_pad = min(pad_len // 2 + pad_len % 2, kernel_size - 1)
        right_pad = pad_len - left_pad

    return left_pad, right_pad

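As a quick sanity check of the `same` branch above, worked from the formulas in the code (an illustrative snippet, not part of the diff):

left_pad, right_pad = _convert_conv_tranpose_padding_args_from_keras_to_jax(
    kernel_size=3,
    stride=2,
    dilation_rate=1,
    padding="same",
    output_padding=None,
)
# pad_len = 2 + 3 - 2 = 3, hence left_pad = 2 and right_pad = 1
assert (left_pad, right_pad) == (2, 1)
# JAX output length: (input - 1) * 2 - 3 + 2 + 2 + 1 == input * 2, as intended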
def _convert_conv_tranpose_padding_args_from_keras_to_torch(
    kernel_size, stride, dilation_rate, padding, output_padding
):
    """Convert the padding arguments from Keras to the ones used by Torch.

    Torch starts with an output shape of `(input-1) * stride + kernel_size`,
    then removes `torch_padding` from both sides, and adds
    `torch_output_padding` on the right.

    Because in Torch the output_padding can only be added to the right,
    consistency with TensorFlow is not always possible. In particular, this
    is the case when both the Torch padding and output_padding values are
    strictly positive.
    """
    assert padding.lower() in {"valid", "same"}
    original_kernel_size = kernel_size
    kernel_size = (kernel_size - 1) * dilation_rate + 1
    if padding.lower() == "valid":
        # If output_padding is None, we fill it so that the shape of the
        # output is `(i-1)*s + max(k, s)`.
        output_padding = (
            max(kernel_size, stride) - kernel_size
            if output_padding is None
            else output_padding
        )
        torch_padding = 0
        torch_output_padding = output_padding
    else:
        # When output_padding is None, we want the shape of the output to
        # be `input * s`, otherwise we use the value provided.
        output_padding = (
            stride - kernel_size % 2
            if output_padding is None
            else output_padding
        )
        torch_padding = max(
            -((kernel_size % 2 - kernel_size + output_padding) // 2), 0
        )
        torch_output_padding = (
            2 * torch_padding + kernel_size % 2 - kernel_size + output_padding
        )

    if torch_padding > 0 and torch_output_padding > 0:
        warnings.warn(
            f"You might experience inconsistencies across backends when "
            f"calling conv transpose with kernel_size={original_kernel_size}, "
            f"stride={stride}, dilation_rate={dilation_rate}, "
            f"padding={padding}, output_padding={output_padding}."
        )

    if torch_output_padding >= stride:
        raise ValueError(
            f"The padding arguments (padding={padding} and "
            f"output_padding={output_padding}) lead to a Torch "
            f"output_padding ({torch_output_padding}) that is greater than "
            f"or equal to strides ({stride}). This is not supported. You can "
            f"change the padding arguments, kernel or stride, or run on "
            f"another backend."
        )

    return torch_padding, torch_output_padding

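A similar illustrative spot check for the Torch conversion; with an even kernel these defaults avoid both the warning and the ValueError (values are not from the diff):

torch_padding, torch_output_padding = (
    _convert_conv_tranpose_padding_args_from_keras_to_torch(
        kernel_size=4,
        stride=2,
        dilation_rate=1,
        padding="same",
        output_padding=None,
    )
)
# output_padding defaults to stride - kernel_size % 2 = 2,
# so torch_padding = 1 and torch_output_padding = 0
assert (torch_padding, torch_output_padding) == (1, 0)
# Torch output length: (input - 1) * 2 + 4 - 2 * 1 + 0 == input * 2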
def compute_conv_transpose_padding_args_for_jax(
    input_shape,
    kernel_shape,
    strides,
    padding,
    output_padding,
    dilation_rate,
):
    num_spatial_dims = len(input_shape) - 2
    kernel_spatial_shape = kernel_shape[:-2]

    jax_padding = []
    for i in range(num_spatial_dims):
        output_padding_i = (
            output_padding
            if output_padding is None or isinstance(output_padding, int)
            else output_padding[i]
        )
        strides_i = strides if isinstance(strides, int) else strides[i]
        dilation_rate_i = (
            dilation_rate
            if isinstance(dilation_rate, int)
            else dilation_rate[i]
        )
        (
            pad_left,
            pad_right,
        ) = _convert_conv_tranpose_padding_args_from_keras_to_jax(
            kernel_size=kernel_spatial_shape[i],
            stride=strides_i,
            dilation_rate=dilation_rate_i,
            padding=padding,
            output_padding=output_padding_i,
        )
        jax_padding.append((pad_left, pad_right))

    return jax_padding

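A hypothetical 2D call of the JAX helper. The shapes below are assumptions for illustration; the function only uses `len(input_shape) - 2` and `kernel_shape[:-2]`, i.e. the spatial dimensions:

jax_padding = compute_conv_transpose_padding_args_for_jax(
    input_shape=(1, 8, 8, 3),  # assumed (batch, height, width, channels)
    kernel_shape=(4, 4, 32, 3),  # assumed spatial dims first
    strides=2,
    padding="same",
    output_padding=None,
    dilation_rate=1,
)
# One (left, right) pair per spatial dimension
assert jax_padding == [(2, 2), (2, 2)]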
def compute_conv_transpose_padding_args_for_torch(
    input_shape,
    kernel_shape,
    strides,
    padding,
    output_padding,
    dilation_rate,
):
    num_spatial_dims = len(input_shape) - 2
    kernel_spatial_shape = kernel_shape[:-2]

    torch_paddings = []
    torch_output_paddings = []
    for i in range(num_spatial_dims):
        output_padding_i = (
            output_padding
            if output_padding is None or isinstance(output_padding, int)
            else output_padding[i]
        )
        strides_i = strides if isinstance(strides, int) else strides[i]
        dilation_rate_i = (
            dilation_rate
            if isinstance(dilation_rate, int)
            else dilation_rate[i]
        )
        (
            torch_padding,
            torch_output_padding,
        ) = _convert_conv_tranpose_padding_args_from_keras_to_torch(
            kernel_size=kernel_spatial_shape[i],
            stride=strides_i,
            dilation_rate=dilation_rate_i,
            padding=padding,
            output_padding=output_padding_i,
        )
        torch_paddings.append(torch_padding)
        torch_output_paddings.append(torch_output_padding)

    return torch_paddings, torch_output_paddings

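The same hypothetical setup through the Torch helper, which returns one padding and one output_padding per spatial dimension:

torch_paddings, torch_output_paddings = (
    compute_conv_transpose_padding_args_for_torch(
        input_shape=(1, 8, 8, 3),
        kernel_shape=(4, 4, 32, 3),
        strides=2,
        padding="same",
        output_padding=None,
        dilation_rate=1,
    )
)
assert torch_paddings == [1, 1]
assert torch_output_paddings == [0, 0]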
def _get_output_shape_given_tf_padding(
    input_size, kernel_size, strides, padding, output_padding, dilation_rate
):
    assert padding.lower() in {"valid", "same"}

    kernel_size = (kernel_size - 1) * dilation_rate + 1

    if padding.lower() == "valid":
        output_padding = (
            max(kernel_size, strides) - kernel_size
            if output_padding is None
            else output_padding
        )
        return (input_size - 1) * strides + kernel_size + output_padding
    else:
        if output_padding is None:
            return input_size * strides
        else:
            return (
                (input_size - 1) * strides + kernel_size % 2 + output_padding
            )

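A worked check of the `valid` branch, with illustrative numbers:

# output_padding defaults to max(3, 2) - 3 = 0,
# so the result is (8 - 1) * 2 + 3 + 0 = 17
assert (
    _get_output_shape_given_tf_padding(
        input_size=8,
        kernel_size=3,
        strides=2,
        padding="valid",
        output_padding=None,
        dilation_rate=1,
    )
    == 17
)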
def compute_conv_transpose_output_shape(
    input_shape,
    kernel_size,
    filters,
    strides,
    padding,
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(input_shape) - 2
    kernel_spatial_shape = kernel_size

    if isinstance(output_padding, int):
        output_padding = (output_padding,) * len(kernel_spatial_shape)
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims

    if data_format == "channels_last":
        input_spatial_shape = input_shape[1:-1]
    else:
        input_spatial_shape = input_shape[2:]

    output_shape = []
    for i in range(num_spatial_dims):
        current_output_padding = (
            None if output_padding is None else output_padding[i]
        )

        shape_i = _get_output_shape_given_tf_padding(
            input_size=input_spatial_shape[i],
            kernel_size=kernel_spatial_shape[i],
            strides=strides[i],
            padding=padding,
            output_padding=current_output_padding,
            dilation_rate=dilation_rate[i],
        )
        output_shape.append(shape_i)

    if data_format == "channels_last":
        output_shape = [input_shape[0]] + output_shape + [filters]
    else:
        output_shape = [input_shape[0], filters] + output_shape
    return output_shape
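A hypothetical end-to-end call with a channels_last input (shapes are assumptions for illustration):

output_shape = compute_conv_transpose_output_shape(
    input_shape=(1, 8, 8, 3),  # (batch, height, width, channels)
    kernel_size=(4, 4),
    filters=32,
    strides=2,
    padding="same",
)
# padding="same" with output_padding=None doubles each spatial dimension
assert output_shape == [1, 16, 16, 32]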