Skip to content

Commit

Permalink
updates to consistency for params
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 8, 2024
1 parent 7613d39 commit 6269603
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 47 deletions.
24 changes: 12 additions & 12 deletions neurodsp/sim/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
###################################################################################################
###################################################################################################

def sig_yielder(sim_func, sim_params, n_sims):
def sig_yielder(sim_func, params, n_sims):
"""Generator to yield simulated signals from a given simulation function and parameters.
Parameters
----------
sim_func : callable
Function to create the simulated time series.
sim_params : dict
params : dict
The parameters for the simulated signal, passed into `sim_func`.
n_sims : int, optional
Number of simulations to set as the max.
Expand All @@ -27,42 +27,42 @@ def sig_yielder(sim_func, sim_params, n_sims):
"""

for _ in counter(n_sims):
yield sim_func(**sim_params)
yield sim_func(**params)


def sig_sampler(sim_func, sim_params, return_sim_params=False, n_sims=None):
def sig_sampler(sim_func, params, return_params=False, n_sims=None):
"""Generator to yield simulated signals from a parameter sampler.
Parameters
----------
sim_func : callable
Function to create the simulated time series.
sim_params : iterable
params : iterable
The parameters for the simulated signal, passed into `sim_func`.
return_sim_params : bool, optional, default: False
return_params : bool, optional, default: False
Whether to yield the simulation parameters as well as the simulated time series.
n_sims : int, optional
Number of simulations to set as the max.
If None, length is defined by the length of `sim_params`, and could be infinite.
If None, length is defined by the length of `params`, and could be infinite.
Yields
------
sig : 1d array
Simulated time series.
sample_params : dict
Simulation parameters for the yielded time series.
Only returned if `return_sim_params` is True.
Only returned if `return_params` is True.
"""

# If `sim_params` has a size, and `n_sims` is defined, check that they are compatible
# If `params` has a size, and `n_sims` is defined, check that they are compatible
# To do so, we first check if the iterable has a __len__ attr, and if so check values
if isinstance(sim_params, Sized) and len(sim_params) and n_sims and n_sims > len(sim_params):
if isinstance(params, Sized) and len(params) and n_sims and n_sims > len(params):
msg = 'Cannot simulate the requested number of sims with the given parameters.'
raise ValueError(msg)

Check warning on line 61 in neurodsp/sim/generators.py

View check run for this annotation

Codecov / codecov/patch

neurodsp/sim/generators.py#L60-L61

Added lines #L60 - L61 were not covered by tests

for ind, sample_params in zip(counter(n_sims), sim_params):
for ind, sample_params in zip(counter(n_sims), params):

if return_sim_params:
if return_params:
yield sim_func(**sample_params), sample_params
else:
yield sim_func(**sample_params)
Expand Down
42 changes: 21 additions & 21 deletions neurodsp/sim/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
###################################################################################################
###################################################################################################

def sim_multiple(sim_func, sim_params, n_sims):
def sim_multiple(sim_func, params, n_sims):
"""Simulate multiple samples of a specified simulation.
Parameters
----------
sim_func : callable
Function to create the simulated time series.
sim_params : dict
params : dict
The parameters for the simulated signal, passed into `sim_func`.
n_sims : int
Number of simulations to create.
Expand All @@ -34,21 +34,21 @@ def sim_multiple(sim_func, sim_params, n_sims):
>>> sims = sim_multiple(sim_powerlaw, params, n_sims=3)
"""

sims = Simulations(n_sims, sim_params, sim_func)
for ind, sig in enumerate(sig_yielder(sim_func, sim_params, n_sims)):
sims = Simulations(n_sims, params, sim_func)
for ind, sig in enumerate(sig_yielder(sim_func, params, n_sims)):
sims.add_signal(sig, index=ind)

return sims


def sim_across_values(sim_func, sim_params):
def sim_across_values(sim_func, params):
"""Simulate signals across different parameter values.
Parameters
----------
sim_func : callable
Function to create the simulated time series.
sim_params : ParamIter or iterable or list of dict
params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
Returns
Expand All @@ -74,24 +74,24 @@ def sim_across_values(sim_func, sim_params):
>>> sims = sim_across_values(sim_powerlaw, params)
"""

sims = VariableSimulations(len(sim_params), get_base_params(sim_params), sim_func,
update=getattr(sim_params, 'update', None),
component=getattr(sim_params, 'component', None))
sims = VariableSimulations(len(params), get_base_params(params), sim_func,
update=getattr(params, 'update', None),
component=getattr(params, 'component', None))

for ind, cur_sim_params in enumerate(sim_params):
sims.add_signal(sim_func(**cur_sim_params), cur_sim_params, index=ind)
for ind, cur_params in enumerate(params):
sims.add_signal(sim_func(**cur_params), cur_params, index=ind)

return sims


def sim_multi_across_values(sim_func, sim_params, n_sims):
def sim_multi_across_values(sim_func, params, n_sims):
"""Simulate multiple signals across different parameter values.
Parameters
----------
sim_func : callable
Function to create the simulated time series.
sim_params : ParamIter or iterable or list of dict
params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
n_sims : int
Number of simulations to create per parameter definition.
Expand Down Expand Up @@ -119,22 +119,22 @@ def sim_multi_across_values(sim_func, sim_params, n_sims):
>>> sims = sim_multi_across_values(sim_powerlaw, params, n_sims=2)
"""

sims = MultiSimulations(update=getattr(sim_params, 'update', None),
component=getattr(sim_params, 'component', None))
for cur_sim_params in sim_params:
sims.add_signals(sim_multiple(sim_func, cur_sim_params, n_sims))
sims = MultiSimulations(update=getattr(params, 'update', None),
component=getattr(params, 'component', None))
for cur_params in params:
sims.add_signals(sim_multiple(sim_func, cur_params, n_sims))

return sims


def sim_from_sampler(sim_func, sim_sampler, n_sims):
def sim_from_sampler(sim_func, sampler, n_sims):
"""Simulate a set of signals from a parameter sampler.
Parameters
----------
sim_func : callable
Function to create the simulated time series.
sim_sampler : ParamSampler
sampler : ParamSampler
Parameter definition to sample from.
n_sims : int
Number of simulations to create per parameter definition.
Expand All @@ -157,8 +157,8 @@ def sim_from_sampler(sim_func, sim_sampler, n_sims):
>>> sims = sim_from_sampler(sim_powerlaw, param_sampler, n_sims=2)
"""

sims = VariableSimulations(n_sims, get_base_params(sim_sampler), sim_func)
for ind, (sig, params) in enumerate(sig_sampler(sim_func, sim_sampler, True, n_sims)):
sims = VariableSimulations(n_sims, get_base_params(sampler), sim_func)
for ind, (sig, params) in enumerate(sig_sampler(sim_func, sampler, True, n_sims)):
sims.add_signal(sig, params, index=ind)

return sims
20 changes: 10 additions & 10 deletions neurodsp/sim/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ def create_updater(update, component=None):

## PARAM ITER

def param_iter_yielder(sim_params, updater, values):
def param_iter_yielder(params, updater, values):
"""Parameter yielder.
Parameters
----------
sim_params : dict
params : dict
Parameter definition.
updater : callable
Updater function to update parameter definition.
Expand All @@ -121,15 +121,15 @@ def param_iter_yielder(sim_params, updater, values):
Yields
------
sim_params : dict
params : dict
Simulation parameter definition.
"""

sim_params = deepcopy(sim_params)
params = deepcopy(params)

for value in values:
updater(sim_params, value)
yield deepcopy(sim_params)
updater(params, value)
yield deepcopy(params)


class ParamIter(BaseUpdater):
Expand Down Expand Up @@ -249,12 +249,12 @@ def create_sampler(values, probs=None, n_samples=None):
yield np.random.choice(values, p=probs)


def param_sample_yielder(sim_params, samplers, n_samples=None):
def param_sample_yielder(params, samplers, n_samples=None):
"""Generator to yield randomly sampled parameter definitions.
Parameters
----------
sim_params : dict
params : dict
The parameters for the simulated signal.
samplers : dict
Sampler definitions to update parameters with.
Expand All @@ -266,12 +266,12 @@ def param_sample_yielder(sim_params, samplers, n_samples=None):
Yields
------
sim_params : dict
params : dict
Simulation parameter definition.
"""

for _ in counter(n_samples):
out_params = deepcopy(sim_params)
out_params = deepcopy(params)
for updater, sampler in samplers.items():
updater(out_params, next(sampler))

Expand Down
8 changes: 4 additions & 4 deletions neurodsp/tests/sim/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ def test_create_updater():

def test_param_iter_yielder():

sim_params = {'n_seconds' : 5, 'fs' : 250, 'exponent' : None}
params = {'n_seconds' : 5, 'fs' : 250, 'exponent' : None}
updater = create_updater('exponent')
values = [-2, -1, 0]

iter_yielder = param_iter_yielder(sim_params, updater, values)
iter_yielder = param_iter_yielder(params, updater, values)
for ind, params in enumerate(iter_yielder):
assert isinstance(params, dict)
for el in ['n_seconds', 'fs', 'exponent']:
Expand All @@ -56,11 +56,11 @@ def test_param_iter_yielder():

def test_class_param_iter():

sim_params = {'n_seconds' : 5, 'fs' : 250, 'exponent' : None}
params = {'n_seconds' : 5, 'fs' : 250, 'exponent' : None}
update = 'exponent'
values = [-2, -1, 0]

piter = ParamIter(sim_params, update, values)
piter = ParamIter(params, update, values)
assert piter
for ind, params in enumerate(piter):
assert isinstance(params, dict)
Expand Down

0 comments on commit 6269603

Please sign in to comment.