Support consecutive integer vector indexing in Numba backend #1106

Merged (1 commit) on Nov 28, 2024
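As an illustration of the feature named in the title, here is a minimal sketch of consecutive integer vector indexing compiled with the Numba backend; the variable names, shapes, and values are invented for this example and are not part of the PR:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
rows = pt.lvector("rows")
cols = pt.lvector("cols")

# Two consecutive integer vector indices, picking out (row, col) pairs.
out = x[rows, cols]

fn = pytensor.function([x, rows, cols], out, mode="NUMBA")
print(fn(np.arange(12.0).reshape(3, 4), np.array([0, 2]), np.array([1, 3])))  # [ 1. 11.]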
Changes from all commits
151 changes: 145 additions & 6 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -5,6 +5,7 @@
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -13,6 +14,7 @@
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import NoneTypeT, SliceType


@numba_funcify.register(Subtensor)
@@ -104,18 +106,73 @@
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
    if isinstance(op, AdvancedSubtensor):
        x, y, idxs = node.inputs[0], None, node.inputs[1:]
    else:
        x, y, *idxs = node.inputs

    basic_idxs = [
        idx
        for idx in idxs
        if (
            isinstance(idx.type, NoneTypeT)
            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
        )
    ]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]

    # Special case for consecutive vector indices
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
        # Returns True if x would need to be broadcast to match to_bcast, based on broadcastable info
if len(x_bcast) < len(to_bcast):
return True
for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
if x_bcast_dim and not to_bcast_dim:
return True

return False

# Special implementation for consecutive integer vector indices
if (
not basic_idxs
and len(adv_idxs) >= 2
# Must be integer vectors
# Todo: we could allow shape=(1,) if this is the shape of x
and all(
(adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_contiguous_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
or not broadcasted_to(
y.type.broadcastable,
(
x.type.broadcastable[: adv_idxs[0]["axis"]]
+ x.type.broadcastable[adv_idxs[-1]["axis"] :]
),
)
)
):
return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)

# Other cases not natively supported by Numba (fallback to obj-mode)
if (
        # Numba does not support indexes with more than one dimension
        any(idx["ndim"] > 1 for idx in adv_idxs)
        # Nor multiple vector indexes
        or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
        # The default PyTensor implementation does not handle duplicate indices correctly
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
@@ -124,9 +181,91 @@
):
return generate_fallback_impl(op, node, **kwargs)

# What's left should all be supported natively by numba
return numba_funcify_default_subtensor(op, node, **kwargs)


def numba_funcify_multiple_integer_vector_indexing(
op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if isinstance(op, AdvancedSubtensor):
y, idxs = None, node.inputs[1:]
else:
y, *idxs = node.inputs[1:]

first_axis = next(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
)
try:
after_last_axis = next(
i
for i, idx in enumerate(idxs[first_axis:], start=first_axis)
if not isinstance(idx.type, TensorType)
)
except StopIteration:
after_last_axis = len(idxs)

if isinstance(op, AdvancedSubtensor):

@numba_njit
def advanced_subtensor_multiple_vector(x, *idxs):
none_slices = idxs[:first_axis]
vec_idxs = idxs[first_axis:after_last_axis]

x_shape = x.shape
idx_shape = vec_idxs[0].shape
shape_bef = x_shape[:first_axis]
shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft)
out_buffer = np.empty(out_shape, dtype=x.dtype)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
return out_buffer

return advanced_subtensor_multiple_vector

elif op.set_instead_of_inc:
inplace = op.inplace

@numba_njit
def advanced_set_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape

if inplace:
out = x
else:
out = x.copy()

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out

return advanced_set_subtensor_multiple_vector

else:
inplace = op.inplace

@numba_njit
def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
vec_idxs = idxs[first_axis:after_last_axis]
x_shape = x.shape

if inplace:
out = x
else:
out = x.copy()

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out

return advanced_inc_subtensor_multiple_vector


@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
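For reference, a plain NumPy sketch of what the vector-index loop in advanced_subtensor_multiple_vector above computes; the arrays are made up for the example:

import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
idx0 = np.array([0, 2, 1])  # integer vector index for axis 1
idx1 = np.array([3, 0, 1])  # integer vector index for axis 2

# NumPy's result for consecutive vector indices preceded by a full slice ...
expected = x[:, idx0, idx1]

# ... matches the explicit loop over paired scalar indices used in the Numba path:
out = np.empty((x.shape[0], idx0.shape[0]), dtype=x.dtype)
for i, (a, b) in enumerate(zip(idx0, idx1)):
    out[:, i] = x[:, a, b]

assert np.array_equal(out, expected)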
25 changes: 25 additions & 0 deletions pytensor/tensor/subtensor.py
@@ -2937,6 +2937,31 @@ def grad(self, inpt, output_gradients):
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()() for _ in idxs]

@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).

This function checks if the advanced indexing is non-contiguous,
in which case the advanced index dimensions are placed on the left of the
        output array, regardless of their original position.

See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing


Parameters
----------
node : Apply
The node of the AdvancedSubtensor operation.

Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
"""
_, _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)


advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
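The docstring of non_contiguous_adv_indexing above refers to NumPy's placement rule for advanced indices; a small NumPy illustration, with arrays invented for the example:

import numpy as np

x = np.zeros((2, 3, 4))
v = np.array([0, 1])

# Contiguous advanced indexing: the vector indices sit on adjacent axes (1 and 2),
# so the resulting advanced dimension stays in that position.
print(x[:, v, v].shape)  # (2, 2)

# Non-contiguous: a slice separates the vector indices (axes 0 and 2), so NumPy
# moves the advanced dimension to the front of the result.
print(x[v, :, v].shape)  # (2, 3)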
17 changes: 12 additions & 5 deletions tests/link/numba/test_basic.py
@@ -228,9 +228,11 @@ def compare_numba_and_py(
fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
inputs: Sequence["TensorLike"],
assert_fn: Callable | None = None,
*,
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
inplace: bool = False,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality
@@ -276,7 +278,14 @@ def assert_fn(x, y):
pytensor_py_fn = function(
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
)

test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
py_res = pytensor_py_fn(*test_inputs)

# Get some coverage (and catch errors in python mode before unreadable numba ones)
if eval_obj_mode:
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)

pytensor_numba_fn = function(
fn_inputs,
@@ -285,11 +294,9 @@ def assert_fn(x, y):
accept_inplace=True,
updates=updates,
)

test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
numba_res = pytensor_numba_fn(*test_inputs)

if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res, strict=True):
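A sketch of how a test might call the updated helper with the new keyword-only inplace flag; the graph and input values below are invented for illustration and are not taken from the PR's tests:

import numpy as np
import pytensor.tensor as pt
from tests.link.numba.test_basic import compare_numba_and_py

x = pt.vector("x")
idx = pt.lvector("idx")
y = pt.vector("y")
out = pt.set_subtensor(x[idx], y)

# inplace=True makes the helper copy the inputs before each evaluation, so an
# in-place op cannot corrupt them between the Python, object-mode, and Numba runs.
compare_numba_and_py(
    ([x, idx, y], [out]),
    [np.zeros(5), np.array([1, 3]), np.array([1.0, 2.0])],
    inplace=True,
)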