Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize make_vector #889

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,23 @@ def _get_vector_length_MakeVector(op, var):
return len(var.owner.inputs)


@_vectorize_node.register
def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
    """Vectorize `MakeVector` by stacking the batched inputs along a new last axis.

    The core `MakeVector` builds a vector out of scalars; its batched
    counterpart is simply a stack of the (broadcast-aligned) batch inputs
    with the core dimension placed last.
    """
    from pytensor.tensor.extra_ops import broadcast_arrays

    # Only broadcast when the batched inputs disagree on their pattern;
    # otherwise stacking them directly is already valid.
    reference = batch_inputs[0].type.broadcastable
    if any(inp.type.broadcastable != reference for inp in batch_inputs):
        batch_inputs = broadcast_arrays(*batch_inputs)

    # The core dimension created by MakeVector becomes the last axis.
    stacked = stack(batch_inputs, axis=-1)
    # `_vectorize_node` dispatch expects an Apply node, not the output variable.
    return stacked.owner
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why .owner here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vectorize_node returns a node. No idea why, could probably have made it like rewrites and return the output variables

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's right there in the name eh?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's silly and potentially too restrictive, opened an issue: #902



def transfer(var, target):
"""
Return a version of `var` transferred to `target`.
Expand Down Expand Up @@ -2690,6 +2707,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
# We can vectorize join as a shifted axis on the batch inputs if:
# 1. The batch axis is a constant and has not changed
# 2. All inputs are batched with the same broadcastable pattern

# TODO: We can relax the second condition by broadcasting the batch dimensions
# This can be done with `broadcast_arrays` if the tensors shape match at the axis or reduction
# Or otherwise by calling `broadcast_to` for each tensor that needs it
if (
original_axis.type.ndim == 0
and isinstance(original_axis, Constant)
Expand Down
40 changes: 40 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4577,6 +4577,46 @@ def core_np(x):
)


@pytest.mark.parametrize(
    "batch_shapes",
    [
        ((3,),),  # edge case of make_vector with a single input
        ((), (), ()),  # Useless
        ((3,), (3,), (3,)),  # No broadcasting needed
        ((3,), (5, 3), ()),  # Broadcasting needed
    ],
)
def test_vectorize_make_vector(batch_shapes):
    """Vectorized MakeVector lowers without Blockwise and matches np.vectorize."""
    n_inputs = len(batch_shapes)
    input_sig = ",".join(["()"] * n_inputs)
    signature = f"{input_sig}->({n_inputs})"  # Something like "(),(),()->(3)"

    def core_pt(*scalars):
        # Fixed: removed leftover debug call `out.dprint()`
        return stack(scalars)

    def core_np(*scalars):
        return np.stack(scalars)

    tensors = [tensor(shape=shape) for shape in batch_shapes]

    vectorize_pt = function(tensors, vectorize(core_pt, signature=signature)(*tensors))
    # The point of the vectorize_make_vector dispatch: no Blockwise fallback remains
    assert not any(
        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
    )

    # Loop variable renamed (`t`): the original shadowed the `tensor` factory above
    test_values = [
        np.random.normal(size=t.type.shape).astype(t.type.dtype) for t in tensors
    ]

    np.testing.assert_allclose(
        vectorize_pt(*test_values),
        np.vectorize(core_np, signature=signature)(*test_values),
    )


@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
@config.change_flags(cxx="") # C code not needed
Expand Down
Loading