Skip to content

Commit

Permalink
Make astrocyte grower RandomState rng consistent with global
Browse files Browse the repository at this point in the history
Change-Id: I3c2f7bc397420829db9e53141ef1d8f79ead9115
  • Loading branch information
eleftherioszisis committed Jun 2, 2021
1 parent b4b94e0 commit fa69154
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
'matplotlib>=1.3.1',
'tmd>=2.0.8',
'morphio>=2.7.1',
'neurom>=2,<3',
'neurom>=2,<2.2',
'scipy>=0.13.3',
'numpy>=1.15.0',
'jsonschema>=3.0.1',
Expand Down
35 changes: 24 additions & 11 deletions tests/astrocyte/test_grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,23 +193,36 @@ def _context():
}


def _astrocyte_grower():
def _global_rng():
np.random.seed(0)
return np.random


def _legacy_rng():
mt = np.random.MT19937()
mt._legacy_seeding(0) # Use legacy seeding to get the same result as with np.random.seed()
return np.random.RandomState(mt)


def test_grow__run():
from numpy.random import MT19937
from numpy.random import RandomState

parameters = _parameters()
distributions = _distributions()

context = _context()

return AstrocyteGrower(
input_distributions=distributions,
input_parameters=parameters,
context=context)
for rng in [_global_rng(), _legacy_rng()]:

def test_grow__run():
print("RNG: ", rng)

np.random.seed(0)
astro_grower = AstrocyteGrower(
input_distributions=distributions,
input_parameters=parameters,
context=context,
rng_or_seed=rng)

astro_grower = _astrocyte_grower()
astro_grower.grow()
difference = diff(astro_grower.neuron, _path / 'astrocyte.h5', atol=0.001)
assert not difference, difference.info
astro_grower.grow()
difference = diff(astro_grower.neuron, _path / 'astrocyte.h5', atol=0.001)
assert not difference, difference.info
2 changes: 1 addition & 1 deletion tns/astrocyte/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _neighbor_contribution(self, current_point):
if pcloud_direction is not None:
return pcloud_direction

return get_random_point()
return get_random_point(random_generator=self._rng)

def next_direction(self, current_point):
'''Given a starting point, find the new direction taking into account
Expand Down
2 changes: 1 addition & 1 deletion tns/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
""" tns version """
VERSION = "2.4.2"
VERSION = "2.4.3.dev0"

0 comments on commit fa69154

Please sign in to comment.