Skip to content

Commit

Permalink
fix: Add support for handling tf.bool tensors in tf.min backend
Browse files Browse the repository at this point in the history
  • Loading branch information
hmahmood24 committed Aug 31, 2024
1 parent 80ccd1a commit 759ef5d
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions ivy/functional/backends/tensorflow/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
# -------------------#


@with_unsupported_dtypes(
{"2.15.0 and below": ("complex", "bool", "uint64")}, backend_version
)
@with_unsupported_dtypes({"2.15.0 and below": ("complex", "uint64")}, backend_version)
def min(
x: Union[tf.Tensor, tf.Variable],
/,
Expand All @@ -26,6 +24,9 @@ def min(
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
axis = tuple(axis) if isinstance(axis, list) else axis
is_bool = tf.dtypes.as_dtype(x.dtype) == tf.bool
if is_bool:
x = tf.cast(x, tf.int32)
if where is not None:
max_val = (
ivy.iinfo(x.dtype).max
Expand All @@ -36,6 +37,8 @@ def min(
result = tf.math.reduce_min(x, axis=axis, keepdims=keepdims)
if initial is not None:
result = tf.minimum(result, initial)
if is_bool:
result = tf.cast(result, tf.bool)
return result


Expand Down

0 comments on commit 759ef5d

Please sign in to comment.