Skip to content

Commit

Permalink
Remove false positive check for supported Subtensors operations in JAX
Browse files Browse the repository at this point in the history
The check was failing incorrectly for cases that are supported such as constant Boolean arrays.
Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
  • Loading branch information
ricardoV94 committed Jun 24, 2024
1 parent d3bd1f1 commit c27898a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 40 deletions.
20 changes: 2 additions & 18 deletions pytensor/link/jax/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,20 @@
"""


def subtensor_assert_indices_jax_compatible(node, idx_list):
from pytensor.graph.basic import Constant
from pytensor.tensor.variable import TensorVariable

ilist = indices_from_subtensor(node.inputs[1:], idx_list)
for idx in ilist:
if isinstance(idx, TensorVariable):
if idx.type.dtype == "bool":
raise NotImplementedError(BOOLEAN_MASK_ERROR)
elif isinstance(idx, slice):
for slice_arg in (idx.start, idx.stop, idx.step):
if slice_arg is not None and not isinstance(slice_arg, Constant):
raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)


@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
subtensor_assert_indices_jax_compatible(node, idx_list)

def subtensor_constant(x, *ilists):
def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1:
indices = indices[0]

return x.__getitem__(indices)

return subtensor_constant
return subtensor


@jax_funcify.register(IncSubtensor)
Expand Down
54 changes: 32 additions & 22 deletions tests/link/jax/test_subtensor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import pytest
from jax._src.errors import NonConcreteBooleanIndexError

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor import tensor
from pytensor.tensor.rewriting.jax import (
boolean_indexing_set_or_inc,
boolean_indexing_sum,
Expand All @@ -13,54 +15,62 @@


def test_jax_Subtensor_constant():
shape = (3, 4, 5)
x_pt = tensor("x", shape=shape, dtype="int")
x_np = np.arange(np.prod(shape)).reshape(shape)

# Basic indices
x_pt = pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

# Advanced indexing
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

# Advanced and basic indexing
out_pt = x_pt[[1, 2], :]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

# Flipping
out_pt = x_pt[::-1]
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])

# Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5))]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])


@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
Expand All @@ -73,16 +83,16 @@ def test_jax_Subtensor_dynamic():
compare_jax_and_py(out_fg, [1])


def test_jax_Subtensor_boolean_mask():
"""JAX does not support resizing arrays with boolean masks."""
def test_jax_Subtensor_dynmaic_boolean_mask():
"""JAX does not support resizing arrays with dynamic boolean masks."""
x_pt = pt.vector("x", dtype="float64")
out_pt = x_pt[x_pt < 0]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)

out_fg = FunctionGraph([x_pt], [out_pt])

x_pt_test = np.arange(-5, 5)
with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
with pytest.raises(NonConcreteBooleanIndexError):
compare_jax_and_py(out_fg, [x_pt_test])


Expand Down

0 comments on commit c27898a

Please sign in to comment.