Add support for TensorFlow SparseTensors: core classes. (#839)
This adds:
- Support for specifying `sparse` in `KerasTensor` and `Input`.
- A boolean flag `backend.SUPPORTS_SPARSE_TENSORS`.
- Support for `tf.SparseTensor` in TensorFlow core ops.
hertschuh authored Sep 6, 2023
1 parent 49e5b06 commit 67722d7
Showing 13 changed files with 136 additions and 18 deletions.
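
Taken together, the changes enable usage along the lines of the sketch below. This is an illustrative sketch, not code from the commit; it only uses names introduced or exercised in the diffs that follow (`KerasTensor(sparse=...)`, `backend.SUPPORTS_SPARSE_TENSORS`).

# Sketch of the new public surface (illustration, not part of the commit).
from keras_core import backend
from keras_core.backend import KerasTensor

# Symbolic tensors can now be marked as sparse.
x = KerasTensor(shape=(3,), dtype="float32", sparse=True)
print(x)  # <KerasTensor shape=(3,), dtype=float32, sparse=True, name=...>

# Each backend advertises whether it supports sparse tensors:
# True for TensorFlow, False for JAX, NumPy and torch in this commit.
print(backend.SUPPORTS_SPARSE_TENSORS)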
12 changes: 10 additions & 2 deletions keras_core/backend/common/keras_tensor.py
@@ -28,11 +28,19 @@ class KerasTensor:
dtype is called "static shape inference".
"""

def __init__(self, shape, dtype="float32", record_history=True, name=None):
def __init__(
self,
shape,
dtype="float32",
sparse=False,
record_history=True,
name=None,
):
from keras_core import backend

self.shape = backend.standardize_shape(shape)
self.dtype = backend.standardize_dtype(dtype)
self.sparse = sparse
self.name = name or auto_name(self.__class__.__name__)
self.record_history = record_history

@@ -106,7 +114,7 @@ def __tf_tensor__(self, dtype=None, name=None):
def __repr__(self):
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
f"name={self.name}>"
f"sparse={self.sparse}, name={self.name}>"
)

def __iter__(self):
3 changes: 2 additions & 1 deletion keras_core/backend/common/keras_tensor_test.py
@@ -9,9 +9,10 @@

class KerasTensorTest(testing.TestCase):
def test_attributes(self):
x = keras_tensor.KerasTensor(shape=(3,), dtype="float32")
x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True)
self.assertEqual(x.dtype, "float32")
self.assertEqual(x.shape, (3,))
self.assertEqual(x.sparse, True)

def test_numpy_methods(self):
x = keras_tensor.KerasTensor(shape=(3, 2), dtype="float32")
1 change: 1 addition & 0 deletions keras_core/backend/jax/__init__.py
@@ -5,6 +5,7 @@
from keras_core.backend.jax import nn
from keras_core.backend.jax import numpy
from keras_core.backend.jax import random
from keras_core.backend.jax.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.jax.core import Variable
from keras_core.backend.jax.core import cast
from keras_core.backend.jax.core import compute_output_spec
6 changes: 5 additions & 1 deletion keras_core/backend/jax/core.py
@@ -14,6 +14,8 @@
from keras_core.backend.jax import distribution_lib
from keras_core.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False


class Variable(KerasVariable):
def _initialize(self, value):
@@ -44,7 +46,9 @@ def __jax_array__(self):
return self.value


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=False):
if sparse:
raise ValueError("`sparse=True` is not supported with jax backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
1 change: 1 addition & 0 deletions keras_core/backend/numpy/__init__.py
@@ -4,6 +4,7 @@
from keras_core.backend.numpy import nn
from keras_core.backend.numpy import numpy
from keras_core.backend.numpy import random
from keras_core.backend.numpy.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.numpy.core import Variable
from keras_core.backend.numpy.core import cast
from keras_core.backend.numpy.core import compute_output_spec
6 changes: 5 additions & 1 deletion keras_core/backend/numpy/core.py
@@ -7,6 +7,8 @@
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False


class Variable(KerasVariable):
def _initialize(self, value):
@@ -23,7 +25,9 @@ def __array__(self):
return self.value


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=False):
if sparse:
raise ValueError("`sparse=True` is not supported with numpy backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
1 change: 1 addition & 0 deletions keras_core/backend/tensorflow/__init__.py
@@ -5,6 +5,7 @@
from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random
from keras_core.backend.tensorflow import tensorboard
from keras_core.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.tensorflow.core import Variable
from keras_core.backend.tensorflow.core import cast
from keras_core.backend.tensorflow.core import compute_output_spec
26 changes: 21 additions & 5 deletions keras_core/backend/tensorflow/core.py
@@ -12,6 +12,8 @@
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.naming import auto_name

SUPPORTS_SPARSE_TENSORS = True


class Variable(
KerasVariable,
@@ -70,15 +72,28 @@ def _write_object_proto(self, proto, options):
return self.value._write_object_proto(proto, options)


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=True):
"""Convert to a TensorFlow tensor.
`sparse=True` means that `tf.SparseTensor`s are returned as-is, which is the
default with the TensorFlow backend. An explicit `sparse=False` densifies
`tf.SparseTensor`s.
"""
if isinstance(x, tf.SparseTensor) and not sparse:
x = tf.sparse.to_dense(x)
if dtype is not None:
dtype = standardize_dtype(dtype)
if tf.is_tensor(x):
return tf.cast(x, dtype=dtype)
return tf.convert_to_tensor(x, dtype=dtype)
if not tf.is_tensor(x):
return tf.convert_to_tensor(x, dtype=dtype)
elif dtype is not None:
return tf.cast(x, dtype=dtype)
else:
return x


def convert_to_numpy(x):
if isinstance(x, tf.SparseTensor):
x = tf.sparse.to_dense(x)
return np.array(x)


@@ -95,7 +110,8 @@ def shape(x):
tensor values when the shape is unknown (this is tf specific, as dynamic
shapes do not apply in other backends).
"""
x = tf.convert_to_tensor(x)
if not tf.is_tensor(x):
x = tf.convert_to_tensor(x)
dynamic = tf.shape(x)
if x.shape == tf.TensorShape(None):
raise ValueError(
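
For reference, the updated TensorFlow implementations of `convert_to_tensor` and `convert_to_numpy` behave as in the following sketch (illustrative only, not commit code; it mirrors the assertions in `core_test.py` further down):

# Behavior sketch for the TensorFlow backend (illustration, not part of the commit).
import numpy as np
import tensorflow as tf

from keras_core import backend
from keras_core import ops

st = tf.SparseTensor(
    indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
)

# sparse defaults to True for this backend, so the SparseTensor is returned as-is.
assert isinstance(backend.convert_to_tensor(st), tf.SparseTensor)

# An explicit sparse=False densifies it (missing entries become 0).
assert not isinstance(backend.convert_to_tensor(st, sparse=False), tf.SparseTensor)

# convert_to_numpy densifies before building the ndarray.
assert isinstance(ops.convert_to_numpy(st), np.ndarray)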
1 change: 1 addition & 0 deletions keras_core/backend/torch/__init__.py
@@ -20,6 +20,7 @@
from keras_core.backend.torch import nn
from keras_core.backend.torch import numpy
from keras_core.backend.torch import random
from keras_core.backend.torch.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.torch.core import Variable
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import compute_output_spec
6 changes: 5 additions & 1 deletion keras_core/backend/torch/core.py
@@ -12,6 +12,8 @@
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False

# Some operators such as 'aten::_foreach_mul_.Scalar'
# are not currently implemented for the MPS device.
# check https://github.com/pytorch/pytorch/issues/77764.
@@ -118,7 +120,9 @@ def __eq__(self, other):
return False


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=False):
if sparse:
raise ValueError("`sparse=True` is not supported with torch backend")
if is_tensor(x):
device = get_device()
if x.device != device:
20 changes: 18 additions & 2 deletions keras_core/layers/core/input_layer.py
@@ -13,12 +13,13 @@ def __init__(
shape=None,
batch_size=None,
dtype=None,
sparse=None,
batch_shape=None,
input_tensor=None,
name=None,
**kwargs,
):
# TODO: support for sparse, ragged.
# TODO: support for ragged.
super().__init__(name=name)
if "input_shape" in kwargs:
warnings.warn(
@@ -45,6 +46,13 @@ def __init__(
self.batch_shape = tuple(batch_shape)
self._dtype = backend.standardize_dtype(dtype)

self.sparse = bool(sparse)
if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:
raise ValueError(
"`sparse=True` is not supported with backend: "
f"{backend.backend()}"
)

if input_tensor is not None:
if not isinstance(input_tensor, backend.KerasTensor):
raise ValueError(
@@ -54,7 +62,7 @@ def __init__(
)
else:
input_tensor = backend.KerasTensor(
shape=batch_shape, dtype=dtype, name=name
shape=batch_shape, dtype=dtype, sparse=sparse, name=name
)
self._input_tensor = input_tensor
Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
@@ -71,6 +79,7 @@ def get_config(self):
return {
"batch_shape": self.batch_shape,
"dtype": self.dtype,
"sparse": self.sparse,
"name": self.name,
}

@@ -80,6 +89,7 @@ def Input(
shape=None,
batch_size=None,
dtype=None,
sparse=None,
batch_shape=None,
name=None,
tensor=None,
@@ -104,6 +114,11 @@
batch_size: Optional static batch size (integer).
dtype: The data type expected by the input, as a string
(e.g. `"float32"`, `"int32"`...)
sparse: A boolean specifying whether the expected input will be sparse
tensors. Note that, if `sparse` is `False`, sparse tensors can still
be passed into the input - they will be densified with a default
value of 0. This feature is only supported with the TensorFlow
backend. Defaults to `False`.
name: Optional name string for the layer.
Should be unique in a model (do not reuse the same name twice).
It will be autogenerated if it isn't provided.
@@ -127,6 +142,7 @@
shape=shape,
batch_size=batch_size,
dtype=dtype,
sparse=sparse,
batch_shape=batch_shape,
name=name,
input_tensor=tensor,
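
A short usage sketch for the new `sparse` argument on `InputLayer` / `Input` (illustrative, not part of the commit; only the TensorFlow backend accepts `sparse=True` here):

# Usage sketch (not commit code); requires a backend with sparse support.
from keras_core import backend
from keras_core.layers import InputLayer

if backend.SUPPORTS_SPARSE_TENSORS:
    layer = InputLayer(shape=(3,), batch_size=2, dtype="float32", sparse=True)
    print(layer.sparse)                  # True
    print(layer.output.sparse)           # True, propagated to the KerasTensor
    print(layer.get_config()["sparse"])  # True, round-trips through the config
else:
    # JAX, NumPy and torch raise:
    # ValueError: `sparse=True` is not supported with backend: ...
    pass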
32 changes: 27 additions & 5 deletions keras_core/layers/core/input_layer_test.py
@@ -1,29 +1,51 @@
import numpy as np
from absl.testing import parameterized

from keras_core import backend
from keras_core import testing
from keras_core.backend import KerasTensor
from keras_core.layers import InputLayer


class InputLayerTest(testing.TestCase):
class InputLayerTest(testing.TestCase, parameterized.TestCase):
# Testing happy path for layer without input tensor
def test_input_basic(self):
@parameterized.named_parameters(
[
{"testcase_name": "dense", "sparse": False},
{"testcase_name": "sparse", "sparse": True},
]
)
def test_input_basic(self, sparse):
input_shape = (2, 3)
batch_size = 4
dtype = "float32"
ndim = len(tuple((batch_size,) + input_shape))

values = InputLayer(
shape=input_shape, batch_size=batch_size, dtype=dtype
)
init_kwargs = {
"shape": input_shape,
"batch_size": batch_size,
"dtype": dtype,
"sparse": sparse,
}

if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
with self.assertRaisesRegex(
ValueError, "`sparse=True` is not supported"
):
InputLayer(**init_kwargs)
return

values = InputLayer(**init_kwargs)

self.assertEqual(values.dtype, dtype)
self.assertEqual(values.batch_shape[0], batch_size)
self.assertEqual(values.batch_shape[1:], input_shape)
self.assertEqual(values.sparse, sparse)
self.assertEqual(values.trainable, True)
self.assertIsInstance(values.output, KerasTensor)
self.assertEqual(values.output.ndim, ndim)
self.assertEqual(values.output.dtype, dtype)
self.assertEqual(values.output.sparse, sparse)

# Testing shape is not None and batch_shape is not None condition
def test_input_error1(self):
39 changes: 39 additions & 0 deletions keras_core/ops/core_test.py
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from keras_core import backend
from keras_core import layers
from keras_core import losses
from keras_core import models
@@ -270,6 +271,18 @@ def test_shape(self):
x = KerasTensor((None, 3, None, 1))
self.assertAllEqual(core.shape(x), (None, 3, None, 1))

@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_shape_sparse(self):
import tensorflow as tf

x = tf.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
)
self.assertAllEqual(core.shape(x), (2, 3))

def test_convert_to_tensor(self):
x = np.ones((2,))
x = ops.convert_to_tensor(x)
@@ -284,6 +297,32 @@ def test_convert_to_tensor(self):
with self.assertRaises(ValueError):
ops.convert_to_numpy(KerasTensor((2,)))

@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_convert_to_tensor_sparse(self):
import tensorflow as tf

x = tf.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
)

x_default = ops.convert_to_tensor(x)
self.assertIsInstance(x_default, tf.SparseTensor)
self.assertAllClose(x, x_default)
# Note that ops.convert_to_tensor does not expose the 'sparse' arg
x_sparse = backend.convert_to_tensor(x, sparse=True)
self.assertIsInstance(x_sparse, tf.SparseTensor)
self.assertAllClose(x, x_sparse)
x_dense = backend.convert_to_tensor(x, sparse=False)
self.assertNotIsInstance(x_dense, tf.SparseTensor)
self.assertAllClose(x, x_dense)

x_numpy = ops.convert_to_numpy(x)
self.assertIsInstance(x_numpy, np.ndarray)
self.assertAllClose(x_numpy, x_dense)

def test_cond(self):
t = ops.cond(True, lambda: 0, lambda: 1)
self.assertEqual(t, 0)
