Skip to content

Commit

Permalink
Add split test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato committed Jan 26, 2025
1 parent dbc95e4 commit eb3ff29
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
Expand All @@ -38,6 +38,11 @@
py_mode = Mode(linker="py", optimizer=None)


def set_test_value(x, v):
x.tag.test_value = v
return x


def compare_pytorch_and_py(
fgraph: FunctionGraph,
test_inputs: Iterable,
Expand Down Expand Up @@ -471,3 +476,66 @@ def test_ScalarLoop_Elemwise_multi_carries():
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
)


rng = np.random.default_rng(42849)


@pytest.mark.parametrize(
"n_splits, axis, values, sizes",
[
(
0,
0,
set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)),
set_test_value(pt.vector(dtype="int64"), []),
),
(
5,
0,
set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)
),
),
(
5,
0,
set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)
),
),
(
5,
-1,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)
),
),
(
5,
-2,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)
),
),
],
)
def test_Split(n_splits, axis, values, sizes):
g = pt.split(values, sizes, n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g)

compare_pytorch_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)

0 comments on commit eb3ff29

Please sign in to comment.