Skip to content

Commit

Permalink
Add nan_to_num conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jun 1, 2024
1 parent b7b309d commit f97146c
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 0 deletions.
50 changes: 50 additions & 0 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,56 @@ def c_code_cache_version(self):
isinf = IsInf()


class IsPosInf(FixedLogicalComparison):
    """Elementwise test for positive infinity (NumPy's ``isposinf``)."""

    nfunc_spec = ("isposinf", 1, 1)

    def impl(self, x):
        # Python-side fallback: defer directly to NumPy.
        return np.isposinf(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        in_type = node.inputs[0].type
        if in_type in complex_types:
            raise NotImplementedError()
        if in_type in discrete_types:
            # Discrete type can never be posinf
            return f"{z} = false;"
        # +inf is an infinity whose sign bit is clear.
        return f"{z} = isinf({x}) && !signbit({x});"

    def c_code_cache_version(self):
        # Extend the parent's version tuple so cached C code is invalidated
        # whenever either this Op or its base changes.
        return (*super().c_code_cache_version(), 4)


isposinf = IsPosInf()


class IsNegInf(FixedLogicalComparison):
    """Elementwise test for negative infinity (NumPy's ``isneginf``)."""

    nfunc_spec = ("isneginf", 1, 1)

    def impl(self, x):
        # Python-side fallback: defer directly to NumPy.
        return np.isneginf(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        in_type = node.inputs[0].type
        if in_type in complex_types:
            raise NotImplementedError()
        if in_type in discrete_types:
            # Discrete type can never be neginf
            return f"{z} = false;"
        # -inf is an infinity whose sign bit is set.
        return f"{z} = isinf({x}) && signbit({x});"

    def c_code_cache_version(self):
        # Extend the parent's version tuple so cached C code is invalidated
        # whenever either this Op or its base changes.
        return (*super().c_code_cache_version(), 4)


isneginf = IsNegInf()


class InRange(LogicalComparison):
nin = 3

Expand Down
100 changes: 100 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,46 @@ def isinf(a):
return isinf_(a)


@scalar_elemwise
def isposinf(a):
    """isposinf(a)"""


# Rename isposinf to isposinf_ so the wrapper below can bypass it when it
# is not needed: glibc 2.23 doesn't allow this check on int, so for
# discrete inputs we keep the Op out of the graph entirely.
isposinf_ = isposinf


def isposinf(a):
    """isposinf(a)"""
    # Discrete (integer/bool) values can never be +inf: return a
    # constant-False boolean tensor of the same shape instead of
    # inserting the elemwise Op.
    a = as_tensor_variable(a)
    if a.dtype in discrete_dtypes:
        return alloc(
            np.asarray(False, dtype="bool"), *[a.shape[i] for i in range(a.ndim)]
        )
    return isposinf_(a)


@scalar_elemwise
def isneginf(a):
    """isneginf(a)"""


# Rename isneginf to isneginf_ so the wrapper below can bypass it when it
# is not needed: glibc 2.23 doesn't allow this check on int, so for
# discrete inputs we keep the Op out of the graph entirely.
isneginf_ = isneginf


def isneginf(a):
    """isneginf(a)"""
    # Discrete (integer/bool) values can never be -inf: return a
    # constant-False boolean tensor of the same shape instead of
    # inserting the elemwise Op.
    a = as_tensor_variable(a)
    if a.dtype in discrete_dtypes:
        return alloc(
            np.asarray(False, dtype="bool"), *[a.shape[i] for i in range(a.ndim)]
        )
    return isneginf_(a)


def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
"""
Implement Numpy's ``allclose`` on tensors.
Expand Down Expand Up @@ -3043,6 +3083,65 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
return vectorize_node_fallback(op, node, batched_x, batched_y)


def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
    """
    Replace NaN with zero and infinity with large finite numbers (default
    behaviour) or with the numbers defined by the user using the `nan`,
    `posinf` and/or `neginf` keywords.

    NaN is replaced by zero or by the user defined value in
    `nan` keyword, infinity is replaced by the largest finite floating point
    values representable by ``x.dtype`` or by the user defined value in
    `posinf` keyword and -infinity is replaced by the most negative finite
    floating point values representable by ``x.dtype`` or by the user defined
    value in `neginf` keyword.

    Parameters
    ----------
    x : symbolic tensor
        Input array.
    nan
        The value to replace NaN's with in the tensor (default = 0).
    posinf
        The value to replace +INF with in the tensor (default max
        in range representable by ``x.dtype``).
    neginf
        The value to replace -INF with in the tensor (default min
        in range representable by ``x.dtype``).

    Returns
    -------
    out
        The tensor with NaN's, +INF, and -INF replaced with the
        specified and/or default substitutions.
    """
    # These masks are *symbolic*: their values are unknown until the graph
    # is evaluated, so we must always build the replacement graph.  (Python
    # `any()` on them would be invalid here — no data-dependent
    # short-circuiting is possible at graph-construction time.)
    is_nan = isnan(x)
    is_pos_inf = isposinf(x)
    is_neg_inf = isneginf(x)

    # Replace NaN's with the `nan` keyword value.
    x = switch(is_nan, nan, x)

    # Default replacements for +/-INF: the largest / most negative finite
    # values representable by x's (real) dtype.
    maxf = posinf if posinf is not None else np.finfo(x.real.dtype).max
    minf = neginf if neginf is not None else np.finfo(x.real.dtype).min

    # Replace +INF and -INF values.
    x = switch(is_pos_inf, maxf, x)
    x = switch(is_neg_inf, minf, x)

    return x


# NumPy logical aliases
square = sqr

Expand Down Expand Up @@ -3199,4 +3298,5 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
"logaddexp",
"logsumexp",
"hyp2f1",
"nan_to_num",
]
24 changes: 24 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
minimum,
mod,
mul,
nan_to_num,
neg,
neq,
outer,
Expand Down Expand Up @@ -3641,3 +3642,26 @@ def test_grad_n_undefined(self):
n = scalar(dtype="int64")
with pytest.raises(NullTypeGradError):
grad(polygamma(n, 0.5), wrt=n)


@pytest.mark.parametrize(
    ["nan", "posinf", "neginf"],
    [(0, None, None), (0, 0, 0), (0, None, 1000), (3, 1, -1)],
)
def test_nan_to_num(nan, posinf, neginf):
    """nan_to_num should agree with numpy.nan_to_num for each keyword combo."""
    x = tensor(shape=(7,))

    # Build the graph once and compile it; x is always used, so no
    # `on_unused_input` escape hatch is needed.
    out = nan_to_num(x, nan, posinf, neginf)
    f = function([x], out)

    # Cast the input to x's dtype so the test also passes when floatX is
    # float32.
    y = np.array([1, 2, np.nan, np.inf, -np.inf, 3, 4], dtype=x.dtype)
    res = f(y)

    # Resolve the default replacement values the same way nan_to_num does.
    posinf = np.finfo(x.real.dtype).max if posinf is None else posinf
    neginf = np.finfo(x.real.dtype).min if neginf is None else neginf

    np.testing.assert_allclose(
        res,
        np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
    )

0 comments on commit f97146c

Please sign in to comment.