Improve the flexibility of standardize_dtype and fix pad in torch backend (#828)

* improve flexibility in dtype check for torch

* Update

* Fix bugs

* Update

* Fix padding for torch backend

* Update
james77777778 authored Sep 2, 2023
1 parent fa547ec commit de510e9
Showing 5 changed files with 83 additions and 52 deletions.
14 changes: 5 additions & 9 deletions keras_core/backend/common/variables.py
@@ -390,25 +390,21 @@ def initialize_all_variables():

 PYTHON_DTYPES_MAP = {
     bool: "bool",
-    int: "int",  # TBD by backend
+    int: "int64" if config.backend() == "tensorflow" else "int32",
     float: "float32",
     str: "string",
+    # special case for string value
+    "int": "int64" if config.backend() == "tensorflow" else "int32",
 }


 def standardize_dtype(dtype):
     if dtype is None:
         return config.floatx()
-    if dtype in PYTHON_DTYPES_MAP:
-        dtype = PYTHON_DTYPES_MAP.get(dtype)
-    if dtype == "int":
-        if config.backend() == "tensorflow":
-            dtype = "int64"
-        else:
-            dtype = "int32"
+    dtype = PYTHON_DTYPES_MAP.get(dtype, dtype)
     if hasattr(dtype, "name"):
         dtype = dtype.name
-    elif config.backend() == "torch":
+    elif hasattr(dtype, "__str__") and "torch" in str(dtype):
         dtype = str(dtype).split(".")[-1]

     if dtype not in ALLOWED_DTYPES:
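For context, a rough sketch of how the relaxed check behaves (assuming numpy and torch are installed; the value returned for the builtin int depends on the active backend, as the map above shows):

    import numpy as np
    import torch

    from keras_core import backend

    # Python builtins go through PYTHON_DTYPES_MAP.
    backend.standardize_dtype(float)  # -> "float32"
    backend.standardize_dtype(int)    # -> "int32" ("int64" on TensorFlow)

    # numpy dtypes expose a `.name` attribute.
    backend.standardize_dtype(np.dtype("float16"))  # -> "float16"

    # torch dtypes stringify as "torch.<name>", so the suffix is used,
    # no matter which backend is active.
    backend.standardize_dtype(torch.float32)  # -> "float32"
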
6 changes: 6 additions & 0 deletions keras_core/backend/common/variables_test.py
@@ -61,3 +61,9 @@ def test_autocasting(self):

         with AutocastScope("float16"):
             self.assertEqual(backend.standardize_dtype(v.value.dtype), "int32")
+
+    def test_standardize_dtype_with_torch_dtype(self):
+        import torch
+
+        x = torch.randn(4, 4)
+        backend.standardize_dtype(x.dtype)
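The new test only checks that the conversion does not raise. A stricter variant (a sketch, not part of the commit) could assert the canonical name directly:

    import torch

    from keras_core import backend

    # str(torch.float32) == "torch.float32", so the suffix becomes the name.
    assert backend.standardize_dtype(torch.float32) == "float32"
    assert backend.standardize_dtype(torch.randn(4, 4).dtype) == "float32"
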
2 changes: 1 addition & 1 deletion keras_core/backend/numpy/numpy.py
@@ -216,7 +216,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):


 def digitize(x, bins):
-    return np.digitize(x, bins)
+    return np.digitize(x, bins).astype(np.int32)


 def dot(x, y):
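For reference, np.digitize returns the platform integer type (typically int64 on 64-bit systems), so the explicit cast keeps the numpy backend's output dtype predictable; a quick sketch:

    import numpy as np

    x = np.array([0.2, 6.4, 3.0])
    bins = np.array([0.0, 1.0, 2.5, 4.0])

    np.digitize(x, bins).dtype                   # typically dtype('int64')
    np.digitize(x, bins).astype(np.int32).dtype  # dtype('int32')
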
44 changes: 36 additions & 8 deletions keras_core/backend/torch/numpy.py
@@ -444,10 +444,12 @@ def imag(x):

 def isclose(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
-    if torch.is_floating_point(x1) and not torch.is_floating_point(x2):
-        x2 = cast(x2, x1.dtype)
-    if torch.is_floating_point(x2) and not torch.is_floating_point(x1):
-        x1 = cast(x1, x2.dtype)
+    if x1.dtype != x2.dtype:
+        result_dtype = torch.result_type(x1, x2)
+        if x1.dtype != result_dtype:
+            x1 = cast(x1, result_dtype)
+        else:
+            x2 = cast(x2, result_dtype)
     return torch.isclose(x1, x2)

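For context, torch.result_type applies PyTorch's type-promotion rules, and only the input whose dtype differs from the promoted type is cast (torch.isclose expects both inputs to share a dtype). A small sketch of the promotion:

    import torch

    x1 = torch.tensor([1, 2, 3], dtype=torch.int32)
    x2 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)

    result_dtype = torch.result_type(x1, x2)  # torch.float64
    torch.isclose(x1.to(result_dtype), x2)    # tensor([True, True, True])
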
@@ -670,16 +672,42 @@ def pad(x, pad_width, mode="constant"):
     x = convert_to_tensor(x)
     pad_sum = []
     pad_width = list(pad_width)[::-1]  # torch uses reverse order
+    pad_width_sum = 0
+    for pad in pad_width:
+        pad_width_sum += pad[0] + pad[1]
     for pad in pad_width:
         pad_sum += pad
+        pad_width_sum -= pad[0] + pad[1]
+        if pad_width_sum == 0:  # early break when no padding in higher order
+            break
     if mode == "symmetric":
         mode = "replicate"
-    if mode != "constant" and x.ndim < 3:
+    if mode == "constant":
+        return torch.nn.functional.pad(x, pad=pad_sum, mode=mode)
+
+    # TODO: reflect and symmetric padding are implemented for padding the
+    # last 3 dimensions of a 4D or 5D input tensor, the last 2 dimensions of a
+    # 3D or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
+    # https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+    ori_dtype = x.dtype
+    ori_ndim = x.ndim
+    need_squeeze = False
+    if x.ndim < 3:
+        need_squeeze = True
         new_dims = [1] * (3 - x.ndim)
-        x = cast(x, torch.float32) if x.dtype == torch.int else x
         x = x.view(*new_dims, *x.shape)
-        return torch.nn.functional.pad(x, pad=pad_sum, mode=mode).squeeze()
-    return torch.nn.functional.pad(x, pad=pad_sum, mode=mode)
+    need_cast = False
+    if x.dtype not in (torch.float32, torch.float64):
+        # TODO: reflect and symmetric padding are only supported with float32/64
+        # https://github.com/pytorch/pytorch/issues/40763
+        need_cast = True
+        x = cast(x, torch.float32)
+    x = torch.nn.functional.pad(x, pad=pad_sum, mode=mode)
+    if need_cast:
+        x = cast(x, ori_dtype)
+    if need_squeeze:
+        x = torch.squeeze(x, dim=tuple(range(3 - ori_ndim)))
+    return x


 def prod(x, axis=None, keepdims=False, dtype=None):
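For context, a rough sketch of the workaround the new code applies when padding a 2D integer tensor with mode="reflect" (plain torch calls, not part of the commit):

    import torch
    import torch.nn.functional as F

    x = torch.arange(6, dtype=torch.int32).reshape(2, 3)

    # reflect/replicate padding needs a 3D+ input and a float dtype, so the
    # backend temporarily adds a leading dim and casts to float32 ...
    x3 = x.view(1, *x.shape).to(torch.float32)
    padded = F.pad(x3, pad=[1, 1, 1, 1], mode="reflect")
    # ... then undoes both afterwards.
    padded = padded.to(torch.int32).squeeze(0)
    print(padded.shape)  # torch.Size([4, 5])
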
69 changes: 35 additions & 34 deletions keras_core/ops/numpy_test.py
@@ -1,4 +1,5 @@
 import numpy as np
+from absl.testing import parameterized
 from tensorflow.python.ops.numpy_ops import np_config

 from keras_core import backend
@@ -2223,7 +2224,7 @@ def test_digitize(self):
         )


-class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
+class NumpyOneInputOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
     def test_mean(self):
         x = np.array([[1, 2, 3], [3, 2, 1]])
         self.assertAllClose(knp.mean(x), np.mean(x))
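Mixing in parameterized.TestCase is what enables the @parameterized.product decorator used below: it expands one test method into one case per combination of arguments. A minimal, self-contained sketch of the pattern (hypothetical test class, not from the diff):

    from absl.testing import parameterized


    class PadConfigTest(parameterized.TestCase):
        @parameterized.product(
            dtype=["float32", "int32"],
            mode=["constant", "reflect"],
        )
        def test_combinations(self, dtype, mode):
            # Runs 4 times: (float32, constant), (float32, reflect),
            # (int32, constant), (int32, reflect).
            self.assertIsInstance(dtype, str)
            self.assertIn(mode, ("constant", "reflect"))
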
@@ -2933,52 +2934,52 @@ def test_ones_like(self):
         self.assertAllClose(knp.ones_like(x), np.ones_like(x))
         self.assertAllClose(knp.OnesLike()(x), np.ones_like(x))

-    def test_pad(self):
-        x = np.array([[1, 2], [3, 4]])
-        self.assertAllClose(
-            knp.pad(x, ((1, 1), (1, 1))),
-            np.pad(x, ((1, 1), (1, 1))),
-        )
-        self.assertAllClose(
-            knp.pad(x, ((1, 1), (1, 1))),
-            np.pad(x, ((1, 1), (1, 1))),
-        )
-
+    @parameterized.product(
+        dtype=[
+            "float16",
+            "float32",
+            "float64",
+            "uint8",
+            "int8",
+            "int16",
+            "int32",
+        ],
+        mode=["constant", "reflect", "symmetric"],
+    )
+    def test_pad(self, dtype, mode):
+        # 2D
+        x = np.ones([2, 3], dtype=dtype)
+        pad_width = ((1, 1), (1, 1))
         self.assertAllClose(
-            knp.Pad(((1, 1), (1, 1)))(x),
-            np.pad(x, ((1, 1), (1, 1))),
+            knp.pad(x, pad_width, mode=mode), np.pad(x, pad_width, mode=mode)
         )
         self.assertAllClose(
-            knp.Pad(((1, 1), (1, 1)))(x),
-            np.pad(x, ((1, 1), (1, 1))),
+            knp.Pad(pad_width, mode=mode)(x), np.pad(x, pad_width, mode=mode)
         )

+        # 5D (pad last 3D)
+        x = np.ones([2, 3, 4, 5, 6], dtype=dtype)
+        pad_width = ((0, 0), (0, 0), (2, 3), (1, 1), (1, 1))
         self.assertAllClose(
-            knp.pad(x, ((1, 1), (1, 1)), mode="reflect"),
-            np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
+            knp.pad(x, pad_width, mode=mode), np.pad(x, pad_width, mode=mode)
         )
         self.assertAllClose(
-            knp.pad(x, ((1, 1), (1, 1)), mode="symmetric"),
-            np.pad(x, ((1, 1), (1, 1)), mode="symmetric"),
+            knp.Pad(pad_width, mode=mode)(x), np.pad(x, pad_width, mode=mode)
         )

+        # 5D (pad arbitrary dimensions)
+        if backend.backend() == "torch" and mode != "constant":
+            self.skipTest(
+                "reflect and symmetric padding for arbitary dimensions are not "
+                "supported by torch"
+            )
+        x = np.ones([2, 3, 4, 5, 6], dtype=dtype)
+        pad_width = ((1, 1), (2, 1), (3, 2), (4, 3), (5, 4))
-        self.assertAllClose(
-            knp.Pad(((1, 1), (1, 1)), mode="reflect")(x),
-            np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
-        )
-        self.assertAllClose(
-            knp.Pad(((1, 1), (1, 1)), mode="symmetric")(x),
-            np.pad(x, ((1, 1), (1, 1)), mode="symmetric"),
-        )
-
-        x = np.ones([2, 3, 4, 5])
         self.assertAllClose(
-            knp.pad(x, ((2, 3), (1, 1), (1, 1), (1, 1))),
-            np.pad(x, ((2, 3), (1, 1), (1, 1), (1, 1))),
+            knp.pad(x, pad_width, mode=mode), np.pad(x, pad_width, mode=mode)
         )
         self.assertAllClose(
-            knp.Pad(((2, 3), (1, 1), (1, 1), (1, 1)))(x),
-            np.pad(x, ((2, 3), (1, 1), (1, 1), (1, 1))),
+            knp.Pad(pad_width, mode=mode)(x), np.pad(x, pad_width, mode=mode)
         )

     def test_prod(self):
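As background on what the new test exercises, the three numpy padding modes differ as follows (a quick sketch; note the test pads np.ones, for which reflect and symmetric give identical results):

    import numpy as np

    a = np.array([1, 2, 3])
    np.pad(a, (2, 2), mode="constant")   # [0 0 1 2 3 0 0]
    np.pad(a, (2, 2), mode="reflect")    # [3 2 1 2 3 2 1]  (edge not repeated)
    np.pad(a, (2, 2), mode="symmetric")  # [2 1 1 2 3 3 2]  (edge repeated)
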
