Skip to content

Commit

Permalink
Fixes in jax backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 19, 2023
1 parent e3d6719 commit e5a6ab9
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions keras_core/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def append(
x2,
axis=None,
):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.append(x1, x2, axis=axis)


Expand Down Expand Up @@ -272,10 +274,14 @@ def full_like(x, fill_value, dtype=None):


def greater(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.greater(x1, x2)


def greater_equal(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.greater_equal(x1, x2)


Expand All @@ -292,6 +298,8 @@ def imag(x):


def isclose(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.isclose(x1, x2)


Expand All @@ -308,10 +316,14 @@ def isnan(x):


def less(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.less(x1, x2)


def less_equal(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.less_equal(x1, x2)


Expand Down Expand Up @@ -346,10 +358,14 @@ def log2(x):


def logaddexp(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.logaddexp(x1, x2)


def logical_and(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.logical_and(x1, x2)


Expand All @@ -358,6 +374,8 @@ def logical_not(x):


def logical_or(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.logical_or(x1, x2)


Expand All @@ -374,6 +392,8 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):


def maximum(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.maximum(x1, x2)


Expand All @@ -386,10 +406,14 @@ def min(x, axis=None, keepdims=False, initial=None):


def minimum(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.minimum(x1, x2)


def mod(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.mod(x1, x2)


Expand All @@ -410,6 +434,8 @@ def nonzero(x):


def not_equal(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.not_equal(x1, x2)


Expand Down Expand Up @@ -550,6 +576,8 @@ def vstack(xs):


def where(condition, x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.where(condition, x1, x2)


Expand All @@ -560,10 +588,14 @@ def divide(x1, x2):


def true_divide(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.true_divide(x1, x2)


def power(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.power(x1, x2)


Expand Down Expand Up @@ -609,8 +641,12 @@ def eye(N, M=None, k=0, dtype="float32"):


def floor_divide(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.floor_divide(x1, x2)


def logical_xor(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.logical_xor(x1, x2)

0 comments on commit e5a6ab9

Please sign in to comment.