Skip to content

Commit

Permalink
Fix expand change_dist_size of SymbolicRandomVariables with size=None
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 27, 2025
1 parent 355b475 commit 9f52e3d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->

params = op.dist_params(rv.owner)

if expand:
if expand and not rv_size_is_none(size):
new_size = tuple(new_size) + tuple(size)

return op.rv_op(*params, size=new_size)
Expand Down
26 changes: 24 additions & 2 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import scipy.stats as st

from pytensor import shared
from pytensor.tensor import TensorVariable
from pytensor.tensor import NoneConst, TensorVariable
from pytensor.tensor.random.utils import normalize_size_param

import pymc as pm

Expand All @@ -43,7 +44,7 @@
)
from pymc.distributions.shape_utils import change_dist_size
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import compile
from pymc.pytensorf import compile, normalize_rng_param
from pymc.testing import (
BaseTestDistributionRandom,
I,
Expand Down Expand Up @@ -210,6 +211,27 @@ def test_recreate_with_different_rng_inputs(self):
new_next_rng, new_x = x.owner.op(*inputs)
assert op.update(new_x.owner) == {new_rng: new_next_rng}

def test_change_dist_size_none(self):
class TestRV(SymbolicRandomVariable):
extended_signature = "[rng],[size]->[rng],(n)"

@classmethod
def rv_op(cls, size=None, rng=None):
rng = normalize_rng_param(rng)
size = normalize_size_param(size)
next_rng, draws = Normal.dist(size=size, rng=rng).owner.outputs
return cls(inputs=[rng, size], outputs=[next_rng, draws])(rng, size)

size = NoneConst
rv = TestRV.rv_op(size=size)
assert rv.type.shape == ()

resized_rv = change_dist_size(rv, new_size=5)
assert resized_rv.type.shape == (5,)

resized_rv = change_dist_size(rv, new_size=5, expand=True)
assert resized_rv.type.shape == (5,)


def test_tag_future_warning_dist():
# Test no unexpected warnings
Expand Down
11 changes: 11 additions & 0 deletions tests/distributions/test_shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,17 @@ def test_change_rv_size():
assert tuple(rv_newer.shape.eval()) == (2,)


def test_change_rv_size_expand_none_size():
x = pt.random.normal()
size = x.owner.op.size_param(x.owner)
assert rv_size_is_none(size)
new_x = change_dist_size(x, new_size=(2,), expand=True)
new_size = new_x.owner.op.size_param(new_x.owner)
assert not rv_size_is_none(new_size)
assert new_size.data == [2]
assert new_x.type.shape == (2,)


def test_change_rv_size_default_update():
rng = pytensor.shared(np.random.default_rng(0))
x = normal(rng=rng)
Expand Down

0 comments on commit 9f52e3d

Please sign in to comment.